Your Name fed361b9b3 feat: 平台规则从硬编码改为品牌方上传文档 + AI 解析
- 新增 PlatformRule 模型 (draft/active/inactive 状态流转)
- 新增文档解析服务 (PDF/Word/Excel → 纯文本)
- 新增 4 个 API: 解析/确认/查询/删除平台规则
- 脚本审核优先从 DB 读取 active 规则,硬编码兜底
- 视频审核合并平台规则违禁词到检测列表
- Alembic 迁移 006: platform_rules 表

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 13:23:11 +08:00

861 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
规则管理 API
违禁词库、白名单、竞品库、平台规则
"""
import json
import logging
import uuid
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
from pydantic import BaseModel, Field
from typing import Optional
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.tenant import Tenant
from app.models.rule import ForbiddenWord, WhitelistItem, Competitor, PlatformRule, RuleStatus
from app.schemas.rules import (
PlatformRuleParseRequest,
PlatformRuleParseResponse,
PlatformRuleConfirmRequest,
PlatformRuleResponse as PlatformRuleDBResponse,
PlatformRuleListResponse as PlatformRuleDBListResponse,
ParsedRulesData,
)
from app.services.document_parser import DocumentParser
from app.services.ai_service import AIServiceFactory
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
# Every endpoint below is mounted under the /rules prefix and tagged "rules".
router = APIRouter(tags=["rules"], prefix="/rules")
# ==================== 请求/响应模型 ====================
# Request body for POST /rules/forbidden-words.
class ForbiddenWordCreate(BaseModel):
    # The word/phrase to forbid; the endpoint rejects a duplicate within the same tenant (HTTP 409).
    word: str
    # Grouping label; the list endpoint can filter on it by exact match.
    category: str
    # Stored as given; this module does not validate or interpret its value.
    severity: str


# One forbidden-word record as returned by the API.
class ForbiddenWordResponse(BaseModel):
    id: str  # server-generated, "fw-" + 8 hex chars
    word: str
    category: str
    severity: str


# Response for GET /rules/forbidden-words.
class ForbiddenWordListResponse(BaseModel):
    items: list[ForbiddenWordResponse]
    # Number of entries in `items` (no pagination is applied).
    total: int
# Request body for POST /rules/whitelist.
class WhitelistCreate(BaseModel):
    # Term that is allowed for the given brand.
    term: str
    # Free-text justification, stored as given.
    reason: str
    # Brand that owns this whitelist entry.
    brand_id: str


# One whitelist record as returned by the API.
class WhitelistResponse(BaseModel):
    id: str  # server-generated, "wl-" + 8 hex chars
    term: str
    reason: str
    brand_id: str


# Response for GET /rules/whitelist.
class WhitelistListResponse(BaseModel):
    items: list[WhitelistResponse]
    # Number of entries in `items` (no pagination is applied).
    total: int
# Request body for POST /rules/competitors.
class CompetitorCreate(BaseModel):
    # Competitor display name.
    name: str
    # Brand for which this is a competitor.
    brand_id: str
    # Optional logo location; stored as given.
    logo_url: Optional[str] = None
    # Extra keywords associated with the competitor; defaults to a new empty list per instance.
    keywords: list[str] = Field(default_factory=list)


# One competitor record as returned by the API.
class CompetitorResponse(BaseModel):
    id: str  # server-generated, "comp-" + 8 hex chars
    name: str
    brand_id: str
    logo_url: Optional[str] = None
    keywords: list[str] = Field(default_factory=list)


# Response for GET /rules/competitors.
class CompetitorListResponse(BaseModel):
    items: list[CompetitorResponse]
    # Number of entries in `items` (no pagination is applied).
    total: int
# One built-in (hard-coded) platform rule set, as served by /rules/platforms.
class PlatformRuleResponse(BaseModel):
    # Platform identifier, e.g. "douyin".
    platform: str
    # Rule entries; each built-in entry carries a "type" key ("forbidden_word", "duration", ...).
    rules: list[dict]
    # Version tag of the rule set, e.g. "2024.01".
    version: str
    # Last-updated timestamp as an ISO 8601 string.
    updated_at: str


# Response for GET /rules/platforms.
class PlatformListResponse(BaseModel):
    items: list[PlatformRuleResponse]
    # Number of platforms returned.
    total: int
# Request body for POST /rules/validate (Brief vs. platform rule conflict check).
class RuleValidateRequest(BaseModel):
    brand_id: str
    # Platform identifier whose rules the Brief is checked against.
    platform: str
    # Free-form Brief rules; the validator reads its "required_phrases" list.
    brief_rules: dict


# One detected conflict between a Brief requirement and a platform rule.
class RuleConflict(BaseModel):
    # Human-readable description of the Brief requirement involved.
    brief_rule: str
    # Human-readable description of the conflicting platform rule.
    platform_rule: str
    # Suggested remedy shown to the user.
    suggestion: str


# Response for POST /rules/validate; an empty list means no conflicts were found.
class RuleValidateResponse(BaseModel):
    conflicts: list[RuleConflict]
# ==================== 预置平台规则 ====================
# Built-in platform rule sets, keyed by platform id, served by the
# /rules/platforms endpoints and used by the Brief conflict check.
# Every entry shares the same version tag; only the rule list and the
# last-updated timestamp (ISO 8601, UTC) differ per platform.
_platform_rules = {
    _name: {
        "platform": _name,
        "rules": _rules,
        "version": "2024.01",
        "updated_at": _updated_at,
    }
    for _name, _rules, _updated_at in (
        (
            "douyin",
            [
                {"type": "forbidden_word", "words": ["最好", "第一", "最佳", "绝对", "100%"]},
                {"type": "duration", "min_seconds": 7},
            ],
            "2024-01-15T00:00:00Z",
        ),
        (
            "xiaohongshu",
            [
                {"type": "forbidden_word", "words": ["最好", "绝对", "100%"]},
            ],
            "2024-01-10T00:00:00Z",
        ),
        (
            "bilibili",
            [
                {"type": "forbidden_word", "words": ["最好", "第一"]},
            ],
            "2024-01-12T00:00:00Z",
        ),
    )
}
# ==================== 辅助函数 ====================
async def _ensure_tenant_exists(tenant_id: str, db: AsyncSession) -> Tenant:
    """Return the Tenant with ``tenant_id``, creating a placeholder row if absent.

    The placeholder is named ``租户-<tenant_id>``. A newly created row is added
    and flushed (not committed here) so later statements in the same session
    can reference it.
    """
    lookup = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    existing = lookup.scalar_one_or_none()
    if existing:
        return existing

    created = Tenant(id=tenant_id, name=f"租户-{tenant_id}")
    db.add(created)
    await db.flush()
    return created
# ==================== 违禁词库 ====================
@router.get("/forbidden-words", response_model=ForbiddenWordListResponse)
async def list_forbidden_words(
    category: str = None,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ForbiddenWordListResponse:
    """List the tenant's forbidden words, optionally filtered by exact ``category``."""
    filters = [ForbiddenWord.tenant_id == x_tenant_id]
    if category:
        filters.append(ForbiddenWord.category == category)

    rows = (await db.execute(select(ForbiddenWord).where(*filters))).scalars().all()
    payload = [
        ForbiddenWordResponse(
            id=row.id,
            word=row.word,
            category=row.category,
            severity=row.severity,
        )
        for row in rows
    ]
    return ForbiddenWordListResponse(items=payload, total=len(payload))
@router.post(
    "/forbidden-words",
    response_model=ForbiddenWordResponse,
    status_code=status.HTTP_201_CREATED,
)
async def add_forbidden_word(
    request: ForbiddenWordCreate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ForbiddenWordResponse:
    """Create a forbidden word for the tenant.

    The tenant row is auto-created when missing. Raises HTTP 409 when the
    tenant already has exactly the same word.
    """
    await _ensure_tenant_exists(x_tenant_id, db)

    duplicate_check = await db.execute(
        select(ForbiddenWord).where(
            ForbiddenWord.tenant_id == x_tenant_id,
            ForbiddenWord.word == request.word,
        )
    )
    if duplicate_check.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"违禁词已存在: {request.word}",
        )

    new_word = ForbiddenWord(
        id=f"fw-{uuid.uuid4().hex[:8]}",
        tenant_id=x_tenant_id,
        word=request.word,
        category=request.category,
        severity=request.severity,
    )
    db.add(new_word)
    await db.flush()

    return ForbiddenWordResponse(
        id=new_word.id,
        word=new_word.word,
        category=new_word.category,
        severity=new_word.severity,
    )
@router.delete("/forbidden-words/{word_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_forbidden_word(
    word_id: str,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the tenant's forbidden words; HTTP 404 if it does not exist for this tenant."""
    lookup = await db.execute(
        select(ForbiddenWord)
        .where(ForbiddenWord.id == word_id)
        .where(ForbiddenWord.tenant_id == x_tenant_id)
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"违禁词不存在: {word_id}",
        )

    await db.delete(target)
    await db.flush()
# ==================== 白名单 ====================
@router.get("/whitelist", response_model=WhitelistListResponse)
async def list_whitelist(
    brand_id: str = None,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> WhitelistListResponse:
    """List the tenant's whitelist entries, optionally restricted to one ``brand_id``."""
    stmt = select(WhitelistItem).where(WhitelistItem.tenant_id == x_tenant_id)
    if brand_id:
        stmt = stmt.where(WhitelistItem.brand_id == brand_id)

    entries = [
        WhitelistResponse(
            id=entry.id,
            term=entry.term,
            reason=entry.reason,
            brand_id=entry.brand_id,
        )
        for entry in (await db.execute(stmt)).scalars().all()
    ]
    return WhitelistListResponse(items=entries, total=len(entries))
@router.post(
    "/whitelist",
    response_model=WhitelistResponse,
    status_code=status.HTTP_201_CREATED,
)
async def add_to_whitelist(
    request: WhitelistCreate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> WhitelistResponse:
    """Create a whitelist entry for a brand (the tenant row is auto-created if missing).

    No duplicate check is performed here.
    """
    await _ensure_tenant_exists(x_tenant_id, db)

    entry = WhitelistItem(
        id=f"wl-{uuid.uuid4().hex[:8]}",
        tenant_id=x_tenant_id,
        brand_id=request.brand_id,
        term=request.term,
        reason=request.reason,
    )
    db.add(entry)
    await db.flush()

    return WhitelistResponse(
        id=entry.id,
        term=entry.term,
        reason=entry.reason,
        brand_id=entry.brand_id,
    )
# ==================== 竞品库 ====================
@router.get("/competitors", response_model=CompetitorListResponse)
async def list_competitors(
    brand_id: str = None,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> CompetitorListResponse:
    """List the tenant's competitors, optionally restricted to one ``brand_id``.

    A NULL ``keywords`` column is returned as an empty list.
    """
    filters = [Competitor.tenant_id == x_tenant_id]
    if brand_id:
        filters.append(Competitor.brand_id == brand_id)

    rows = (await db.execute(select(Competitor).where(*filters))).scalars().all()
    payload = [
        CompetitorResponse(
            id=row.id,
            name=row.name,
            brand_id=row.brand_id,
            logo_url=row.logo_url,
            keywords=row.keywords or [],
        )
        for row in rows
    ]
    return CompetitorListResponse(items=payload, total=len(payload))
@router.post(
    "/competitors",
    response_model=CompetitorResponse,
    status_code=status.HTTP_201_CREATED,
)
async def add_competitor(
    request: CompetitorCreate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> CompetitorResponse:
    """Create a competitor entry for a brand (the tenant row is auto-created if missing)."""
    await _ensure_tenant_exists(x_tenant_id, db)

    new_competitor = Competitor(
        id=f"comp-{uuid.uuid4().hex[:8]}",
        tenant_id=x_tenant_id,
        brand_id=request.brand_id,
        name=request.name,
        logo_url=request.logo_url,
        keywords=request.keywords,
    )
    db.add(new_competitor)
    await db.flush()

    stored_keywords = new_competitor.keywords or []
    return CompetitorResponse(
        id=new_competitor.id,
        name=new_competitor.name,
        brand_id=new_competitor.brand_id,
        logo_url=new_competitor.logo_url,
        keywords=stored_keywords,
    )
@router.delete("/competitors/{competitor_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_competitor(
    competitor_id: str,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the tenant's competitors; HTTP 404 if it does not exist for this tenant."""
    lookup = await db.execute(
        select(Competitor).where(
            Competitor.id == competitor_id,
            Competitor.tenant_id == x_tenant_id,
        )
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"竞品不存在: {competitor_id}",
        )

    await db.delete(target)
    await db.flush()
# ==================== 平台规则 ====================
@router.get("/platforms", response_model=PlatformListResponse)
async def list_platform_rules() -> PlatformListResponse:
    """Return every built-in platform rule set from the module-level table."""
    entries = [PlatformRuleResponse(**entry) for entry in _platform_rules.values()]
    return PlatformListResponse(items=entries, total=len(entries))
@router.get("/platforms/{platform}", response_model=PlatformRuleResponse)
async def get_platform_rules(platform: str) -> PlatformRuleResponse:
    """Return the built-in rule set for ``platform``; HTTP 404 for unknown platforms."""
    entry = _platform_rules.get(platform)
    if entry is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform}",
        )
    return PlatformRuleResponse(**entry)
# ==================== 规则冲突检测 ====================
@router.post("/validate", response_model=RuleValidateResponse)
async def validate_rules(
    request: RuleValidateRequest,
    x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> RuleValidateResponse:
    """Detect conflicts between a Brief's required phrases and platform forbidden words.

    Rule source:
    * If an ``X-Tenant-ID`` header is sent and the brand (``request.brand_id``)
      has an active uploaded rule set for ``request.platform``, that rule set's
      ``forbidden_words`` are used (see ``get_active_platform_rules``).
    * Otherwise the built-in ``_platform_rules`` table is used. If the platform
      is not in that table either, no conflicts are reported.

    A conflict is reported for every (phrase, word) pair in which the forbidden
    word occurs as a substring of a required phrase.
    """
    forbidden_source: Optional[list] = None

    # The header is optional so callers that only send a body keep the
    # built-in behaviour. The isinstance() check also covers direct Python
    # calls, where the default is FastAPI's Header() marker, not a string.
    if isinstance(x_tenant_id, str) and x_tenant_id:
        uploaded = await get_active_platform_rules(
            x_tenant_id, request.brand_id, request.platform, db,
        )
        if uploaded:
            forbidden_source = uploaded.get("forbidden_words") or []

    if forbidden_source is None:
        platform_rule = _platform_rules.get(request.platform)
        if not platform_rule:
            return RuleValidateResponse(conflicts=[])
        forbidden_source = [
            word
            for rule in platform_rule.get("rules", [])
            if rule.get("type") == "forbidden_word"
            for word in rule.get("words", [])
        ]

    # Drop blank/non-string words (an empty string is a substring of every
    # phrase) and duplicates, keeping first-seen order.
    platform_forbidden = list(dict.fromkeys(
        word for word in forbidden_source if isinstance(word, str) and word
    ))

    # brief_rules is free-form client input: tolerate a missing/null/non-list
    # value and skip non-string entries instead of failing with a 500.
    required_phrases = request.brief_rules.get("required_phrases") or []
    if not isinstance(required_phrases, (list, tuple)):
        required_phrases = []

    conflicts = [
        RuleConflict(
            brief_rule=f"要求使用:{phrase}",
            platform_rule=f"平台禁止:{word}",
            suggestion=f"Brief 要求的 '{phrase}' 包含平台违禁词 '{word}',建议修改",
        )
        for phrase in required_phrases
        if isinstance(phrase, str)
        for word in platform_forbidden
        if word in phrase
    ]
    return RuleValidateResponse(conflicts=conflicts)
# ==================== 品牌方平台规则(文档上传 + AI 解析) ====================
def _format_platform_rule(rule: PlatformRule) -> PlatformRuleDBResponse:
    """Convert a PlatformRule ORM row into the API response schema.

    A NULL ``parsed_rules`` is treated as an empty mapping. Missing timestamps
    are rendered as empty strings; present ones as ISO 8601 strings.
    """
    stored_rules = rule.parsed_rules or {}
    created = rule.created_at
    updated = rule.updated_at
    return PlatformRuleDBResponse(
        id=rule.id,
        platform=rule.platform,
        brand_id=rule.brand_id,
        document_url=rule.document_url,
        document_name=rule.document_name,
        parsed_rules=ParsedRulesData(**stored_rules),
        status=rule.status,
        created_at=created.isoformat() if created else "",
        updated_at=updated.isoformat() if updated else "",
    )
@router.post(
    "/platform-rules/parse",
    response_model=PlatformRuleParseResponse,
    status_code=status.HTTP_201_CREATED,
)
async def parse_platform_rule_document(
    request: PlatformRuleParseRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleParseResponse:
    """Download a rule document, extract structured rules with AI, and store a draft.

    Steps:
    1. Download the document and extract its plain text (``DocumentParser``).
    2. Ask the tenant's AI service to turn the text into structured rules.
    3. Validate the result against ``ParsedRulesData``; if it does not fit the
       schema, fall back to the empty rule structure (manual editing).
    4. Persist it as a ``draft`` PlatformRule and return it for the brand to
       review and confirm.

    Raises:
        HTTPException(400): the document cannot be downloaded/parsed or has no text.
    """
    await _ensure_tenant_exists(x_tenant_id, db)

    # NOTE(review): document_url is client-supplied and fetched server-side by
    # DocumentParser; confirm that it restricts schemes/hosts (SSRF).
    try:
        document_text = await DocumentParser.download_and_parse(
            request.document_url, request.document_name,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Keep the traceback in the logs; the client only sees the summary.
        logger.exception("文档解析失败: %s", e)
        raise HTTPException(status_code=400, detail=f"文档下载或解析失败: {e}") from e

    if not document_text.strip():
        raise HTTPException(status_code=400, detail="文档内容为空,无法解析")

    parsed_rules = await _ai_parse_platform_rules(x_tenant_id, request.platform, document_text, db)

    # Validate before persisting. A draft that does not fit the schema would
    # otherwise be flushed and then fail here with a 500, and any such row
    # would later break the list endpoint. Pydantic's ValidationError
    # subclasses ValueError; TypeError covers a non-mapping result.
    try:
        parsed_data = ParsedRulesData(**parsed_rules)
    except (TypeError, ValueError):
        logger.warning("AI 解析结果不符合规则结构, 降级为空规则")
        parsed_data = ParsedRulesData(**_empty_parsed_rules())

    rule = PlatformRule(
        id=f"pr-{uuid.uuid4().hex[:8]}",
        tenant_id=x_tenant_id,
        brand_id=request.brand_id,
        platform=request.platform,
        document_url=request.document_url,
        document_name=request.document_name,
        # Store the validated form, as confirm_platform_rule does.
        parsed_rules=parsed_data.model_dump(),
        status=RuleStatus.DRAFT.value,
    )
    db.add(rule)
    await db.flush()

    return PlatformRuleParseResponse(
        id=rule.id,
        platform=rule.platform,
        brand_id=rule.brand_id,
        document_url=rule.document_url,
        document_name=rule.document_name,
        parsed_rules=parsed_data,
        status=rule.status,
    )
@router.put(
    "/platform-rules/{rule_id}/confirm",
    response_model=PlatformRuleDBResponse,
)
async def confirm_platform_rule(
    rule_id: str,
    request: PlatformRuleConfirmRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleDBResponse:
    """Confirm (and optionally edit) a parsed platform rule, making it active.

    The rule's ``parsed_rules`` are replaced with the submitted data and its
    status becomes ``active``. Every other active rule for the same
    (tenant_id, brand_id, platform) is switched to ``inactive``. The rule's
    current status is not checked, so an inactive rule can be re-activated.
    Raises HTTP 404 if the rule does not exist for this tenant.
    """
    target = (
        await db.execute(
            select(PlatformRule).where(
                PlatformRule.id == rule_id,
                PlatformRule.tenant_id == x_tenant_id,
            )
        )
    ).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail=f"规则不存在: {rule_id}")

    previously_active = (
        await db.execute(
            select(PlatformRule).where(
                PlatformRule.tenant_id == x_tenant_id,
                PlatformRule.brand_id == target.brand_id,
                PlatformRule.platform == target.platform,
                PlatformRule.status == RuleStatus.ACTIVE.value,
                PlatformRule.id != rule_id,
            )
        )
    ).scalars().all()
    for stale in previously_active:
        stale.status = RuleStatus.INACTIVE.value

    target.parsed_rules = request.parsed_rules.model_dump()
    target.status = RuleStatus.ACTIVE.value
    await db.flush()
    return _format_platform_rule(target)
@router.get(
    "/platform-rules",
    response_model=PlatformRuleDBListResponse,
)
async def list_brand_platform_rules(
    brand_id: Optional[str] = Query(None),
    platform: Optional[str] = Query(None),
    rule_status: Optional[str] = Query(None, alias="status"),
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleDBListResponse:
    """List the tenant's uploaded platform rules, newest first.

    ``brand_id``, ``platform`` and ``status`` are optional exact-match filters;
    empty values are ignored.
    """
    conditions = [PlatformRule.tenant_id == x_tenant_id]
    for column, wanted in (
        (PlatformRule.brand_id, brand_id),
        (PlatformRule.platform, platform),
        (PlatformRule.status, rule_status),
    ):
        if wanted:
            conditions.append(column == wanted)

    stmt = select(PlatformRule).where(*conditions).order_by(PlatformRule.created_at.desc())
    rows = (await db.execute(stmt)).scalars().all()
    formatted = [_format_platform_rule(row) for row in rows]
    return PlatformRuleDBListResponse(items=formatted, total=len(formatted))
@router.delete(
    "/platform-rules/{rule_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_platform_rule(
    rule_id: str,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the tenant's platform rules, whatever its status; HTTP 404 if not found."""
    lookup = await db.execute(
        select(PlatformRule)
        .where(PlatformRule.id == rule_id)
        .where(PlatformRule.tenant_id == x_tenant_id)
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail=f"规则不存在: {rule_id}")

    await db.delete(target)
    await db.flush()
async def _ai_parse_platform_rules(
    tenant_id: str,
    platform: str,
    document_text: str,
    db: AsyncSession,
) -> dict:
    """Use the tenant's AI service to turn document text into structured platform rules.

    Returns a dict with the keys ``forbidden_words``, ``restricted_words``,
    ``duration``, ``content_requirements`` and ``other_rules``. Whenever AI is
    unavailable or its reply cannot be used, the empty structure from
    ``_empty_parsed_rules`` is returned instead. The brand then fills the
    rules in manually.
    """
    try:
        ai_client = await AIServiceFactory.get_client(tenant_id, db)
        if not ai_client:
            logger.warning(f"租户 {tenant_id} 未配置 AI 服务,返回空规则")
            return _empty_parsed_rules()

        config = await AIServiceFactory.get_config(tenant_id, db)
        if not config:
            return _empty_parsed_rules()
        text_model = config.models.get("text", "gpt-4o")

        response = await ai_client.chat_completion(
            messages=[{"role": "user", "content": _build_platform_rule_prompt(platform, document_text)}],
            model=text_model,
            temperature=0.2,
            max_tokens=2000,
        )
        parsed = json.loads(_extract_json_text(response.content))
        return _normalize_parsed_rules(parsed)
    except json.JSONDecodeError:
        logger.warning("AI 返回内容非 JSON降级为空规则")
        return _empty_parsed_rules()
    except Exception as e:
        # Best-effort by design: degrade to manual editing, but keep the traceback.
        logger.exception("AI 解析平台规则失败: %s", e)
        return _empty_parsed_rules()


def _build_platform_rule_prompt(platform: str, document_text: str, max_chars: int = 15000) -> str:
    """Build the rule-extraction prompt, truncating the document to ``max_chars`` characters."""
    # Truncate long documents so the prompt stays within the model's token budget.
    if len(document_text) > max_chars:
        document_text = document_text[:max_chars] + "\n...(文档内容已截断)"
    return f"""你是平台广告合规规则分析专家。请从以下 {platform} 平台规则文档中提取结构化规则。
文档内容:
{document_text}
请以 JSON 格式返回,不要包含其他内容:
{{
"forbidden_words": ["违禁词1", "违禁词2"],
"restricted_words": [{{"word": "xx", "condition": "使用条件", "suggestion": "替换建议"}}],
"duration": {{"min_seconds": 7, "max_seconds": null}},
"content_requirements": ["必须展示产品正面", "需要口播品牌名"],
"other_rules": [{{"rule": "规则名称", "description": "详细说明"}}]
}}
注意:
- forbidden_words: 明确禁止使用的词语
- restricted_words: 有条件限制的词语
- duration: 视频时长要求,如果文档未提及则为 null
- content_requirements: 内容上的硬性要求
- other_rules: 不属于以上分类的其他规则
- 如果某项没有提取到内容,使用空数组或 null"""


def _extract_json_text(content: object) -> str:
    """Return the outermost ``{...}`` span of an AI reply, or the stripped reply if none.

    Drops Markdown code fences (single- or multi-line) and any surrounding
    prose the model adds around the JSON object.
    """
    text = (content or "").strip() if isinstance(content, str) or content is None else str(content).strip()
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        return text[start:end + 1]
    return text


def _coerce_str_list(value: object) -> list[str]:
    """Return the non-blank strings in ``value`` (stripped, de-duplicated, order kept); [] if not a list."""
    if not isinstance(value, list):
        return []
    stripped = (item.strip() for item in value if isinstance(item, str))
    # Blank entries are dropped: an empty forbidden word would match every text.
    return list(dict.fromkeys(item for item in stripped if item))


def _coerce_dict_list(value: object) -> list[dict]:
    """Return the dict items of ``value``; [] if ``value`` is not a list."""
    if not isinstance(value, list):
        return []
    return [item for item in value if isinstance(item, dict)]


def _normalize_parsed_rules(parsed: object) -> dict:
    """Coerce decoded AI JSON into the fixed parsed-rules shape.

    The prompt allows ``null`` for empty sections, so null or non-list
    sections become ``[]`` and a non-object ``duration`` becomes None. A
    non-object payload yields the empty structure.
    """
    if not isinstance(parsed, dict):
        return _empty_parsed_rules()
    duration = parsed.get("duration")
    return {
        "forbidden_words": _coerce_str_list(parsed.get("forbidden_words")),
        "restricted_words": _coerce_dict_list(parsed.get("restricted_words")),
        "duration": duration if isinstance(duration, dict) else None,
        "content_requirements": _coerce_str_list(parsed.get("content_requirements")),
        "other_rules": _coerce_dict_list(parsed.get("other_rules")),
    }
def _empty_parsed_rules() -> dict:
    """Build a fresh parsed-rules dict in which every section is empty.

    ``duration`` is None and every other section is a new empty list, so
    callers may mutate the result without affecting later calls. Used as the
    fallback whenever AI parsing is unavailable or fails.
    """
    section_names = (
        "forbidden_words",
        "restricted_words",
        "duration",
        "content_requirements",
        "other_rules",
    )
    return {name: (None if name == "duration" else []) for name in section_names}
# ==================== 辅助函数(供其他模块调用) ====================
async def get_whitelist_for_brand(
    tenant_id: str,
    brand_id: str,
    db: AsyncSession,
) -> list[str]:
    """Return the whitelist terms registered for ``brand_id`` within ``tenant_id``."""
    stmt = select(WhitelistItem).where(
        WhitelistItem.tenant_id == tenant_id,
        WhitelistItem.brand_id == brand_id,
    )
    rows = (await db.execute(stmt)).scalars().all()
    return [row.term for row in rows]
async def get_other_brands_whitelist_terms(
    tenant_id: str,
    brand_id: str,
    db: AsyncSession,
) -> list[tuple[str, str]]:
    """Return whitelist terms owned by *other* brands of the tenant (for brand-safety checks).

    Returns:
        A list of ``(term, owner_brand_id)`` pairs. Rows whose brand_id is
        NULL are not matched, because SQL ``!=`` never matches NULL.
    """
    stmt = select(WhitelistItem).where(
        WhitelistItem.tenant_id == tenant_id,
        WhitelistItem.brand_id != brand_id,
    )
    rows = (await db.execute(stmt)).scalars().all()
    pairs: list[tuple[str, str]] = []
    for row in rows:
        pairs.append((row.term, row.brand_id))
    return pairs
async def get_forbidden_words_for_tenant(
    tenant_id: str,
    db: AsyncSession,
    category: str = None,
) -> list[dict]:
    """Return the tenant's forbidden words as plain dicts, optionally filtered by ``category``.

    Each dict has the keys ``id``, ``word``, ``category`` and ``severity``.
    """
    filters = [ForbiddenWord.tenant_id == tenant_id]
    if category:
        filters.append(ForbiddenWord.category == category)

    rows = (await db.execute(select(ForbiddenWord).where(*filters))).scalars().all()
    return [
        dict(id=row.id, word=row.word, category=row.category, severity=row.severity)
        for row in rows
    ]
async def get_active_platform_rules(
    tenant_id: str,
    brand_id: str,
    platform: str,
    db: AsyncSession,
) -> Optional[dict]:
    """Return the brand's active (confirmed) parsed rules for ``platform``.

    Returns:
        The ``parsed_rules`` of the active PlatformRule, or None when the brand
        has no active rule for this platform.

    ``confirm_platform_rule`` deactivates older active rows, but no uniqueness
    constraint is visible in this module, so concurrent confirms can leave two
    active rows. ``scalar_one_or_none()`` would then raise MultipleResultsFound
    for every caller. Instead, the most recently created active rule is used.
    """
    result = await db.execute(
        select(PlatformRule)
        .where(
            and_(
                PlatformRule.tenant_id == tenant_id,
                PlatformRule.brand_id == brand_id,
                PlatformRule.platform == platform,
                PlatformRule.status == RuleStatus.ACTIVE.value,
            )
        )
        .order_by(PlatformRule.created_at.desc())
        .limit(1)
    )
    rule = result.scalars().first()
    if rule is None:
        return None
    return rule.parsed_rules