""" 规则管理 API 违禁词库、白名单、竞品库、平台规则 """ import uuid from fastapi import APIRouter, Depends, Header, HTTPException, status from pydantic import BaseModel, Field from typing import Optional from sqlalchemy import select, and_ from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db from app.models.tenant import Tenant from app.models.rule import ForbiddenWord, WhitelistItem, Competitor router = APIRouter(prefix="/rules", tags=["rules"]) # ==================== 请求/响应模型 ==================== class ForbiddenWordCreate(BaseModel): word: str category: str severity: str class ForbiddenWordResponse(BaseModel): id: str word: str category: str severity: str class ForbiddenWordListResponse(BaseModel): items: list[ForbiddenWordResponse] total: int class WhitelistCreate(BaseModel): term: str reason: str brand_id: str class WhitelistResponse(BaseModel): id: str term: str reason: str brand_id: str class WhitelistListResponse(BaseModel): items: list[WhitelistResponse] total: int class CompetitorCreate(BaseModel): name: str brand_id: str logo_url: Optional[str] = None keywords: list[str] = Field(default_factory=list) class CompetitorResponse(BaseModel): id: str name: str brand_id: str logo_url: Optional[str] = None keywords: list[str] = Field(default_factory=list) class CompetitorListResponse(BaseModel): items: list[CompetitorResponse] total: int class PlatformRuleResponse(BaseModel): platform: str rules: list[dict] version: str updated_at: str class PlatformListResponse(BaseModel): items: list[PlatformRuleResponse] total: int class RuleValidateRequest(BaseModel): brand_id: str platform: str brief_rules: dict class RuleConflict(BaseModel): brief_rule: str platform_rule: str suggestion: str class RuleValidateResponse(BaseModel): conflicts: list[RuleConflict] # ==================== 预置平台规则 ==================== _platform_rules = { "douyin": { "platform": "douyin", "rules": [ {"type": "forbidden_word", "words": ["最好", "第一", "最佳", "绝对", "100%"]}, {"type": "duration", "min_seconds": 7}, ], "version": "2024.01", "updated_at": "2024-01-15T00:00:00Z", }, "xiaohongshu": { "platform": "xiaohongshu", "rules": [ {"type": "forbidden_word", "words": ["最好", "绝对", "100%"]}, ], "version": "2024.01", "updated_at": "2024-01-10T00:00:00Z", }, "bilibili": { "platform": "bilibili", "rules": [ {"type": "forbidden_word", "words": ["最好", "第一"]}, ], "version": "2024.01", "updated_at": "2024-01-12T00:00:00Z", }, } # ==================== 辅助函数 ==================== async def _ensure_tenant_exists(tenant_id: str, db: AsyncSession) -> Tenant: """确保租户存在,不存在则自动创建""" result = await db.execute( select(Tenant).where(Tenant.id == tenant_id) ) tenant = result.scalar_one_or_none() if not tenant: tenant = Tenant(id=tenant_id, name=f"租户-{tenant_id}") db.add(tenant) await db.flush() return tenant # ==================== 违禁词库 ==================== @router.get("/forbidden-words", response_model=ForbiddenWordListResponse) async def list_forbidden_words( category: str = None, x_tenant_id: str = Header(..., alias="X-Tenant-ID"), db: AsyncSession = Depends(get_db), ) -> ForbiddenWordListResponse: """查询违禁词列表""" query = select(ForbiddenWord).where(ForbiddenWord.tenant_id == x_tenant_id) if category: query = query.where(ForbiddenWord.category == category) result = await db.execute(query) words = result.scalars().all() return ForbiddenWordListResponse( items=[ ForbiddenWordResponse( id=w.id, word=w.word, category=w.category, severity=w.severity, ) for w in words ], total=len(words), ) @router.post( "/forbidden-words", response_model=ForbiddenWordResponse, status_code=status.HTTP_201_CREATED, ) async def add_forbidden_word( request: ForbiddenWordCreate, x_tenant_id: str = Header(..., alias="X-Tenant-ID"), db: AsyncSession = Depends(get_db), ) -> ForbiddenWordResponse: """添加违禁词""" # 确保租户存在 await _ensure_tenant_exists(x_tenant_id, db) # 检查重复 result = await db.execute( select(ForbiddenWord).where( and_( ForbiddenWord.tenant_id == x_tenant_id, ForbiddenWord.word == request.word, ) ) ) existing = result.scalar_one_or_none() if existing: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail=f"违禁词已存在: {request.word}", ) word_id = f"fw-{uuid.uuid4().hex[:8]}" word = ForbiddenWord( id=word_id, tenant_id=x_tenant_id, word=request.word, category=request.category, severity=request.severity, ) db.add(word) await db.flush() return ForbiddenWordResponse( id=word.id, word=word.word, category=word.category, severity=word.severity, ) @router.delete("/forbidden-words/{word_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_forbidden_word( word_id: str, x_tenant_id: str = Header(..., alias="X-Tenant-ID"), db: AsyncSession = Depends(get_db), ): """删除违禁词""" result = await db.execute( select(ForbiddenWord).where( and_( ForbiddenWord.id == word_id, ForbiddenWord.tenant_id == x_tenant_id, ) ) ) word = result.scalar_one_or_none() if not word: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"违禁词不存在: {word_id}", ) await db.delete(word) await db.flush() # ==================== 白名单 ==================== @router.get("/whitelist", response_model=WhitelistListResponse) async def list_whitelist( brand_id: str = None, x_tenant_id: str = Header(..., alias="X-Tenant-ID"), db: AsyncSession = Depends(get_db), ) -> WhitelistListResponse: """查询白名单""" query = select(WhitelistItem).where(WhitelistItem.tenant_id == x_tenant_id) if brand_id: query = query.where(WhitelistItem.brand_id == brand_id) result = await db.execute(query) items = result.scalars().all() return WhitelistListResponse( items=[ WhitelistResponse( id=item.id, term=item.term, reason=item.reason, brand_id=item.brand_id, ) for item in items ], total=len(items), ) @router.post( "/whitelist", response_model=WhitelistResponse, status_code=status.HTTP_201_CREATED, ) async def add_to_whitelist( request: WhitelistCreate, x_tenant_id: str = Header(..., alias="X-Tenant-ID"), db: AsyncSession = Depends(get_db), ) -> WhitelistResponse: """添加白名单""" # 确保租户存在 await _ensure_tenant_exists(x_tenant_id, db) item_id = f"wl-{uuid.uuid4().hex[:8]}" item = WhitelistItem( id=item_id, tenant_id=x_tenant_id, brand_id=request.brand_id, term=request.term, reason=request.reason, ) db.add(item) await db.flush() return WhitelistResponse( id=item.id, term=item.term, reason=item.reason, brand_id=item.brand_id, ) # ==================== 竞品库 ==================== @router.get("/competitors", response_model=CompetitorListResponse) async def list_competitors( brand_id: str = None, x_tenant_id: str = Header(..., alias="X-Tenant-ID"), db: AsyncSession = Depends(get_db), ) -> CompetitorListResponse: """查询竞品列表""" query = select(Competitor).where(Competitor.tenant_id == x_tenant_id) if brand_id: query = query.where(Competitor.brand_id == brand_id) result = await db.execute(query) competitors = result.scalars().all() return CompetitorListResponse( items=[ CompetitorResponse( id=c.id, name=c.name, brand_id=c.brand_id, logo_url=c.logo_url, keywords=c.keywords or [], ) for c in competitors ], total=len(competitors), ) @router.post( "/competitors", response_model=CompetitorResponse, status_code=status.HTTP_201_CREATED, ) async def add_competitor( request: CompetitorCreate, x_tenant_id: str = Header(..., alias="X-Tenant-ID"), db: AsyncSession = Depends(get_db), ) -> CompetitorResponse: """添加竞品""" # 确保租户存在 await _ensure_tenant_exists(x_tenant_id, db) comp_id = f"comp-{uuid.uuid4().hex[:8]}" competitor = Competitor( id=comp_id, tenant_id=x_tenant_id, brand_id=request.brand_id, name=request.name, logo_url=request.logo_url, keywords=request.keywords, ) db.add(competitor) await db.flush() return CompetitorResponse( id=competitor.id, name=competitor.name, brand_id=competitor.brand_id, logo_url=competitor.logo_url, keywords=competitor.keywords or [], ) @router.delete("/competitors/{competitor_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_competitor( competitor_id: str, x_tenant_id: str = Header(..., alias="X-Tenant-ID"), db: AsyncSession = Depends(get_db), ): """删除竞品""" result = await db.execute( select(Competitor).where( and_( Competitor.id == competitor_id, Competitor.tenant_id == x_tenant_id, ) ) ) competitor = result.scalar_one_or_none() if not competitor: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"竞品不存在: {competitor_id}", ) await db.delete(competitor) await db.flush() # ==================== 平台规则 ==================== @router.get("/platforms", response_model=PlatformListResponse) async def list_platform_rules() -> PlatformListResponse: """查询所有平台规则""" return PlatformListResponse( items=[PlatformRuleResponse(**r) for r in _platform_rules.values()], total=len(_platform_rules), ) @router.get("/platforms/{platform}", response_model=PlatformRuleResponse) async def get_platform_rules(platform: str) -> PlatformRuleResponse: """查询指定平台规则""" if platform not in _platform_rules: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"平台不存在: {platform}", ) return PlatformRuleResponse(**_platform_rules[platform]) # ==================== 规则冲突检测 ==================== @router.post("/validate", response_model=RuleValidateResponse) async def validate_rules(request: RuleValidateRequest) -> RuleValidateResponse: """检测 Brief 与平台规则冲突""" conflicts = [] platform_rule = _platform_rules.get(request.platform) if not platform_rule: return RuleValidateResponse(conflicts=[]) # 检查 required_phrases 是否包含违禁词 required_phrases = request.brief_rules.get("required_phrases", []) platform_forbidden = [] for rule in platform_rule.get("rules", []): if rule.get("type") == "forbidden_word": platform_forbidden.extend(rule.get("words", [])) for phrase in required_phrases: for word in platform_forbidden: if word in phrase: conflicts.append(RuleConflict( brief_rule=f"要求使用:{phrase}", platform_rule=f"平台禁止:{word}", suggestion=f"Brief 要求的 '{phrase}' 包含平台违禁词 '{word}',建议修改", )) return RuleValidateResponse(conflicts=conflicts) # ==================== 辅助函数(供其他模块调用) ==================== async def get_whitelist_for_brand( tenant_id: str, brand_id: str, db: AsyncSession, ) -> list[str]: """获取品牌白名单词汇""" result = await db.execute( select(WhitelistItem).where( and_( WhitelistItem.tenant_id == tenant_id, WhitelistItem.brand_id == brand_id, ) ) ) items = result.scalars().all() return [item.term for item in items] async def get_other_brands_whitelist_terms( tenant_id: str, brand_id: str, db: AsyncSession, ) -> list[tuple[str, str]]: """ 获取其他品牌的白名单词汇(用于品牌安全检测) Returns: list of (term, owner_brand_id) """ result = await db.execute( select(WhitelistItem).where( and_( WhitelistItem.tenant_id == tenant_id, WhitelistItem.brand_id != brand_id, ) ) ) items = result.scalars().all() return [(item.term, item.brand_id) for item in items] async def get_forbidden_words_for_tenant( tenant_id: str, db: AsyncSession, category: str = None, ) -> list[dict]: """获取租户的违禁词列表""" query = select(ForbiddenWord).where(ForbiddenWord.tenant_id == tenant_id) if category: query = query.where(ForbiddenWord.category == category) result = await db.execute(query) words = result.scalars().all() return [ { "id": w.id, "word": w.word, "category": w.category, "severity": w.severity, } for w in words ]