- Brief 支持代理商附件上传 (迁移 007) - 项目新增 platform 字段 (迁移 008),前端创建/展示平台信息 - 修复 AI 规则解析:处理中文引号导致 JSON 解析失败的问题 - 修复消息中心崩溃:补全后端消息类型映射 + fallback 保护 - 项目创建时自动发送消息通知 - .gitignore 排除 backend/data/ 数据库文件 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1041 lines
31 KiB
Python
1041 lines
31 KiB
Python
"""
|
||
规则管理 API
|
||
违禁词库、白名单、竞品库、平台规则
|
||
"""
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
|
||
from pydantic import BaseModel, Field
|
||
from typing import Optional
|
||
from sqlalchemy import select, and_
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.database import get_db
|
||
from app.models.tenant import Tenant
|
||
from app.models.rule import ForbiddenWord, WhitelistItem, Competitor, PlatformRule, RuleStatus
|
||
from app.schemas.rules import (
|
||
PlatformRuleParseRequest,
|
||
PlatformRuleParseResponse,
|
||
PlatformRuleConfirmRequest,
|
||
PlatformRuleResponse as PlatformRuleDBResponse,
|
||
PlatformRuleListResponse as PlatformRuleDBListResponse,
|
||
ParsedRulesData,
|
||
)
|
||
from app.services.document_parser import DocumentParser
|
||
from app.services.ai_service import AIServiceFactory
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router collecting every endpoint defined in this module under the /rules prefix.
router = APIRouter(prefix="/rules", tags=["rules"])
|
||
|
||
|
||
# ==================== 请求/响应模型 ====================
|
||
|
||
class ForbiddenWordCreate(BaseModel):
    """Request body used to register a new forbidden word."""

    # The forbidden term itself.
    word: str
    # Free-form category label for the term.
    category: str
    # Free-form severity label for the term.
    severity: str
|
||
|
||
|
||
class ForbiddenWordResponse(BaseModel):
    """A single forbidden word as returned by the API."""

    # Identifier of the stored record.
    id: str
    # The forbidden term itself.
    word: str
    # Category label stored with the term.
    category: str
    # Severity label stored with the term.
    severity: str
|
||
|
||
|
||
class ForbiddenWordListResponse(BaseModel):
    """Collection wrapper returned when listing forbidden words."""

    # The forbidden words in this response.
    items: list[ForbiddenWordResponse]
    # Number of entries, as reported by the endpoint.
    total: int
|
||
|
||
|
||
class WhitelistCreate(BaseModel):
    """Request body used to add a term to a brand's whitelist."""

    # The whitelisted term.
    term: str
    # Why the term is whitelisted.
    reason: str
    # Brand the whitelist entry belongs to.
    brand_id: str
|
||
|
||
|
||
class WhitelistResponse(BaseModel):
    """A single whitelist entry as returned by the API."""

    # Identifier of the stored record.
    id: str
    # The whitelisted term.
    term: str
    # Why the term is whitelisted.
    reason: str
    # Brand the whitelist entry belongs to.
    brand_id: str
|
||
|
||
|
||
class WhitelistListResponse(BaseModel):
    """Collection wrapper returned when listing whitelist entries."""

    # The whitelist entries in this response.
    items: list[WhitelistResponse]
    # Number of entries, as reported by the endpoint.
    total: int
|
||
|
||
|
||
class CompetitorCreate(BaseModel):
    """Request body used to register a competitor for a brand."""

    # Competitor name.
    name: str
    # Brand the competitor entry belongs to.
    brand_id: str
    # Optional logo URL; omitted when not provided.
    logo_url: Optional[str] = None
    # Keywords associated with the competitor; defaults to a fresh empty list.
    keywords: list[str] = Field(default_factory=list)
|
||
|
||
|
||
class CompetitorResponse(BaseModel):
    """A single competitor entry as returned by the API."""

    # Identifier of the stored record.
    id: str
    # Competitor name.
    name: str
    # Brand the competitor entry belongs to.
    brand_id: str
    # Optional logo URL; None when not stored.
    logo_url: Optional[str] = None
    # Keywords associated with the competitor; defaults to a fresh empty list.
    keywords: list[str] = Field(default_factory=list)
|
||
|
||
|
||
class CompetitorListResponse(BaseModel):
    """Collection wrapper returned when listing competitors."""

    # The competitor entries in this response.
    items: list[CompetitorResponse]
    # Number of entries, as reported by the endpoint.
    total: int
|
||
|
||
|
||
class PlatformRuleResponse(BaseModel):
    """Built-in rule set of one platform, as served by the /platforms endpoints."""

    # Platform identifier, e.g. "douyin".
    platform: str
    # Rule entries; each is a free-form dict carrying a "type" key in the built-in data.
    rules: list[dict]
    # Version label of the rule set.
    version: str
    # Timestamp string of the last update.
    updated_at: str
|
||
|
||
|
||
class PlatformListResponse(BaseModel):
    """Collection wrapper returned when listing built-in platform rules."""

    # One entry per platform.
    items: list[PlatformRuleResponse]
    # Number of entries, as reported by the endpoint.
    total: int
|
||
|
||
|
||
class RuleValidateRequest(BaseModel):
    """Request body for checking a brief against platform rules."""

    # Brand whose uploaded platform rules should be consulted.
    brand_id: str
    # Platform identifier the brief targets.
    platform: str
    # Brief rules as a free-form dict; the validator reads the keys
    # "required_phrases", "selling_points", "min_duration" and "max_duration".
    brief_rules: dict
|
||
|
||
|
||
class RuleConflict(BaseModel):
    """One detected conflict between a brief rule and a platform rule."""

    # Human-readable description of the offending brief rule.
    brief_rule: str
    # Human-readable description of the platform rule it violates.
    platform_rule: str
    # Human-readable suggestion for resolving the conflict.
    suggestion: str
|
||
|
||
|
||
class RuleValidateResponse(BaseModel):
    """Result of a brief validation: the list of detected conflicts."""

    # Empty when no conflict was found.
    conflicts: list[RuleConflict]
|
||
|
||
|
||
# ==================== 预置平台规则 ====================
|
||
|
||
_platform_rules = {
|
||
"douyin": {
|
||
"platform": "douyin",
|
||
"rules": [
|
||
{"type": "forbidden_word", "words": ["最好", "第一", "最佳", "绝对", "100%"]},
|
||
{"type": "duration", "min_seconds": 7},
|
||
],
|
||
"version": "2024.01",
|
||
"updated_at": "2024-01-15T00:00:00Z",
|
||
},
|
||
"xiaohongshu": {
|
||
"platform": "xiaohongshu",
|
||
"rules": [
|
||
{"type": "forbidden_word", "words": ["最好", "绝对", "100%"]},
|
||
],
|
||
"version": "2024.01",
|
||
"updated_at": "2024-01-10T00:00:00Z",
|
||
},
|
||
"bilibili": {
|
||
"platform": "bilibili",
|
||
"rules": [
|
||
{"type": "forbidden_word", "words": ["最好", "第一"]},
|
||
],
|
||
"version": "2024.01",
|
||
"updated_at": "2024-01-12T00:00:00Z",
|
||
},
|
||
}
|
||
|
||
|
||
# ==================== 辅助函数 ====================
|
||
|
||
async def _ensure_tenant_exists(tenant_id: str, db: AsyncSession) -> Tenant:
    """Return the tenant row for ``tenant_id``, creating it when it is missing.

    A newly created tenant is added to the session and flushed; committing is
    left to whoever owns the session.
    """
    lookup = select(Tenant).where(Tenant.id == tenant_id)
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return found

    # No such tenant yet: create it on the fly with a generated display name.
    created = Tenant(id=tenant_id, name=f"租户-{tenant_id}")
    db.add(created)
    await db.flush()
    return created
|
||
|
||
|
||
# ==================== 违禁词库 ====================
|
||
|
||
@router.get("/forbidden-words", response_model=ForbiddenWordListResponse)
async def list_forbidden_words(
    category: Optional[str] = None,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ForbiddenWordListResponse:
    """List the forbidden words that belong to the calling tenant.

    Args:
        category: When given and non-empty, only words in this category are returned.
        x_tenant_id: Tenant identifier taken from the ``X-Tenant-ID`` header.
        db: Async database session injected by FastAPI.

    Returns:
        The matching words together with their count.
    """
    query = select(ForbiddenWord).where(ForbiddenWord.tenant_id == x_tenant_id)
    if category:
        query = query.where(ForbiddenWord.category == category)

    words = (await db.execute(query)).scalars().all()
    items = [
        ForbiddenWordResponse(
            id=w.id,
            word=w.word,
            category=w.category,
            severity=w.severity,
        )
        for w in words
    ]
    return ForbiddenWordListResponse(items=items, total=len(words))
|
||
|
||
|
||
@router.post(
    "/forbidden-words",
    response_model=ForbiddenWordResponse,
    status_code=status.HTTP_201_CREATED,
)
async def add_forbidden_word(
    request: ForbiddenWordCreate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ForbiddenWordResponse:
    """Add a forbidden word for the calling tenant.

    Raises:
        HTTPException: 409 when the tenant already has the same word.
    """
    # The tenant row is created on demand so the new word has an owner.
    await _ensure_tenant_exists(x_tenant_id, db)

    # Reject duplicates within the same tenant.
    duplicate_query = select(ForbiddenWord).where(
        and_(
            ForbiddenWord.tenant_id == x_tenant_id,
            ForbiddenWord.word == request.word,
        )
    )
    duplicate = (await db.execute(duplicate_query)).scalar_one_or_none()
    if duplicate:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"违禁词已存在: {request.word}",
        )

    # Identifier: "fw-" followed by the first 8 hex digits of a random UUID.
    record = ForbiddenWord(
        id=f"fw-{uuid.uuid4().hex[:8]}",
        tenant_id=x_tenant_id,
        word=request.word,
        category=request.category,
        severity=request.severity,
    )
    db.add(record)
    await db.flush()

    return ForbiddenWordResponse(
        id=record.id,
        word=record.word,
        category=record.category,
        severity=record.severity,
    )
|
||
|
||
|
||
@router.delete("/forbidden-words/{word_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_forbidden_word(
    word_id: str,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the calling tenant's forbidden words.

    Raises:
        HTTPException: 404 when the tenant has no word with this id.
    """
    # The lookup is scoped to the tenant, so other tenants' words are never matched.
    lookup = select(ForbiddenWord).where(
        and_(
            ForbiddenWord.id == word_id,
            ForbiddenWord.tenant_id == x_tenant_id,
        )
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"违禁词不存在: {word_id}",
        )

    await db.delete(target)
    await db.flush()
|
||
|
||
|
||
# ==================== 白名单 ====================
|
||
|
||
@router.get("/whitelist", response_model=WhitelistListResponse)
async def list_whitelist(
    brand_id: Optional[str] = None,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> WhitelistListResponse:
    """List the whitelist entries that belong to the calling tenant.

    Args:
        brand_id: When given and non-empty, only entries of this brand are returned.
        x_tenant_id: Tenant identifier taken from the ``X-Tenant-ID`` header.
        db: Async database session injected by FastAPI.

    Returns:
        The matching entries together with their count.
    """
    query = select(WhitelistItem).where(WhitelistItem.tenant_id == x_tenant_id)
    if brand_id:
        query = query.where(WhitelistItem.brand_id == brand_id)

    entries = (await db.execute(query)).scalars().all()
    items = [
        WhitelistResponse(
            id=entry.id,
            term=entry.term,
            reason=entry.reason,
            brand_id=entry.brand_id,
        )
        for entry in entries
    ]
    return WhitelistListResponse(items=items, total=len(entries))
|
||
|
||
|
||
@router.post(
    "/whitelist",
    response_model=WhitelistResponse,
    status_code=status.HTTP_201_CREATED,
)
async def add_to_whitelist(
    request: WhitelistCreate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> WhitelistResponse:
    """Add a term to a brand's whitelist for the calling tenant."""
    # The tenant row is created on demand so the new entry has an owner.
    await _ensure_tenant_exists(x_tenant_id, db)

    # Identifier: "wl-" followed by the first 8 hex digits of a random UUID.
    entry = WhitelistItem(
        id=f"wl-{uuid.uuid4().hex[:8]}",
        tenant_id=x_tenant_id,
        brand_id=request.brand_id,
        term=request.term,
        reason=request.reason,
    )
    db.add(entry)
    await db.flush()

    return WhitelistResponse(
        id=entry.id,
        term=entry.term,
        reason=entry.reason,
        brand_id=entry.brand_id,
    )
|
||
|
||
|
||
# ==================== 竞品库 ====================
|
||
|
||
@router.get("/competitors", response_model=CompetitorListResponse)
async def list_competitors(
    brand_id: Optional[str] = None,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> CompetitorListResponse:
    """List the competitors that belong to the calling tenant.

    Args:
        brand_id: When given and non-empty, only competitors of this brand are returned.
        x_tenant_id: Tenant identifier taken from the ``X-Tenant-ID`` header.
        db: Async database session injected by FastAPI.

    Returns:
        The matching competitors together with their count.
    """
    query = select(Competitor).where(Competitor.tenant_id == x_tenant_id)
    if brand_id:
        query = query.where(Competitor.brand_id == brand_id)

    competitors = (await db.execute(query)).scalars().all()
    items = [
        CompetitorResponse(
            id=c.id,
            name=c.name,
            brand_id=c.brand_id,
            logo_url=c.logo_url,
            # A stored null keyword list is reported as an empty list.
            keywords=c.keywords or [],
        )
        for c in competitors
    ]
    return CompetitorListResponse(items=items, total=len(competitors))
|
||
|
||
|
||
@router.post(
    "/competitors",
    response_model=CompetitorResponse,
    status_code=status.HTTP_201_CREATED,
)
async def add_competitor(
    request: CompetitorCreate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> CompetitorResponse:
    """Register a competitor for one of the calling tenant's brands."""
    # The tenant row is created on demand so the new competitor has an owner.
    await _ensure_tenant_exists(x_tenant_id, db)

    # Identifier: "comp-" followed by the first 8 hex digits of a random UUID.
    record = Competitor(
        id=f"comp-{uuid.uuid4().hex[:8]}",
        tenant_id=x_tenant_id,
        brand_id=request.brand_id,
        name=request.name,
        logo_url=request.logo_url,
        keywords=request.keywords,
    )
    db.add(record)
    await db.flush()

    return CompetitorResponse(
        id=record.id,
        name=record.name,
        brand_id=record.brand_id,
        logo_url=record.logo_url,
        # A null keyword list is reported as an empty list.
        keywords=record.keywords or [],
    )
|
||
|
||
|
||
@router.delete("/competitors/{competitor_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_competitor(
    competitor_id: str,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the calling tenant's competitors.

    Raises:
        HTTPException: 404 when the tenant has no competitor with this id.
    """
    # The lookup is scoped to the tenant, so other tenants' rows are never matched.
    lookup = select(Competitor).where(
        and_(
            Competitor.id == competitor_id,
            Competitor.tenant_id == x_tenant_id,
        )
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"竞品不存在: {competitor_id}",
        )

    await db.delete(target)
    await db.flush()
|
||
|
||
|
||
# ==================== 平台规则 ====================
|
||
|
||
@router.get("/platforms", response_model=PlatformListResponse)
async def list_platform_rules() -> PlatformListResponse:
    """Return the built-in rule sets of every known platform."""
    entries = [PlatformRuleResponse(**spec) for spec in _platform_rules.values()]
    return PlatformListResponse(items=entries, total=len(_platform_rules))
|
||
|
||
|
||
@router.get("/platforms/{platform}", response_model=PlatformRuleResponse)
async def get_platform_rules(platform: str) -> PlatformRuleResponse:
    """Return the built-in rule set of one platform.

    Raises:
        HTTPException: 404 when the platform has no built-in rule set.
    """
    spec = _platform_rules.get(platform)
    if spec is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform}",
        )
    return PlatformRuleResponse(**spec)
|
||
|
||
|
||
# ==================== 规则冲突检测 ====================
|
||
|
||
@router.post("/validate", response_model=RuleValidateResponse)
async def validate_rules(
    request: RuleValidateRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> RuleValidateResponse:
    """Detect conflicts between a brief and the platform rules.

    The brand's active uploaded rules (from the database) are merged with the
    built-in rules: forbidden words are unioned, while each duration bound is
    taken from the database first and from the built-in rules only when the
    database does not provide that bound.

    Args:
        request: Brand, platform and the brief rules to check.
        x_tenant_id: Tenant identifier taken from the ``X-Tenant-ID`` header.
        db: Async database session injected by FastAPI.

    Returns:
        The detected conflicts (empty when there are none).
    """
    conflicts: list[RuleConflict] = []

    # 1. Collect forbidden words and duration bounds: DB rules first, built-ins as fallback.
    db_rules = await get_active_platform_rules(
        x_tenant_id, request.brand_id, request.platform, db
    )
    forbidden_words: set[str] = set()
    min_seconds: Optional[int] = None
    max_seconds: Optional[int] = None

    if db_rules:
        # Stored rules may hold null instead of a list, and AI-extracted entries
        # are not guaranteed to be non-empty strings. An empty word would match
        # every phrase, and a non-string would break the substring test below.
        forbidden_words.update(
            w for w in (db_rules.get("forbidden_words") or [])
            if isinstance(w, str) and w
        )
        duration = db_rules.get("duration") or {}
        min_seconds = duration.get("min_seconds")
        max_seconds = duration.get("max_seconds")

    hardcoded = _platform_rules.get(request.platform, {})
    for rule in hardcoded.get("rules", []):
        rule_type = rule.get("type")
        if rule_type == "forbidden_word":
            forbidden_words.update(rule.get("words", []))
        elif rule_type == "duration":
            # Each bound falls back independently, so a DB-supplied minimum no
            # longer suppresses the built-in maximum (and vice versa).
            if min_seconds is None:
                min_seconds = rule.get("min_seconds")
            if max_seconds is None:
                max_seconds = rule.get("max_seconds")

    # 2. Check required phrases / selling points against the forbidden words.
    phrases = list(request.brief_rules.get("required_phrases") or [])
    phrases += list(request.brief_rules.get("selling_points") or [])
    # Sorted so the reported conflict order is stable between requests.
    ordered_words = sorted(forbidden_words)
    for phrase in phrases:
        phrase_text = str(phrase)
        for word in ordered_words:
            if word in phrase_text:
                conflicts.append(RuleConflict(
                    brief_rule=f"卖点包含:{phrase}",
                    platform_rule=f"{request.platform} 禁止使用:{word}",
                    suggestion=f"卖点 '{phrase}' 包含违禁词 '{word}',建议修改表述",
                ))

    # 3. Check the brief's duration range against the platform bounds.
    brief_min = request.brief_rules.get("min_duration")
    brief_max = request.brief_rules.get("max_duration")
    if min_seconds and brief_max and brief_max < min_seconds:
        conflicts.append(RuleConflict(
            brief_rule=f"Brief 最长时长:{brief_max}秒",
            platform_rule=f"{request.platform} 最短要求:{min_seconds}秒",
            suggestion=f"Brief 最长 {brief_max}s 低于平台最短要求 {min_seconds}s,视频可能不达标",
        ))
    if max_seconds and brief_min and brief_min > max_seconds:
        conflicts.append(RuleConflict(
            brief_rule=f"Brief 最短时长:{brief_min}秒",
            platform_rule=f"{request.platform} 最长限制:{max_seconds}秒",
            suggestion=f"Brief 最短 {brief_min}s 超过平台最长限制 {max_seconds}s,建议调整",
        ))

    return RuleValidateResponse(conflicts=conflicts)
|
||
|
||
|
||
# ==================== 品牌方平台规则(文档上传 + AI 解析) ====================
|
||
|
||
def _format_platform_rule(rule: PlatformRule) -> PlatformRuleDBResponse:
    """Convert a ``PlatformRule`` ORM row into its response schema.

    Missing ``parsed_rules`` become an empty ``ParsedRulesData``; missing
    timestamps become empty strings, otherwise they are ISO-formatted.
    """
    created = rule.created_at.isoformat() if rule.created_at else ""
    updated = rule.updated_at.isoformat() if rule.updated_at else ""
    parsed = ParsedRulesData(**(rule.parsed_rules or {}))

    return PlatformRuleDBResponse(
        id=rule.id,
        platform=rule.platform,
        brand_id=rule.brand_id,
        document_url=rule.document_url,
        document_name=rule.document_name,
        parsed_rules=parsed,
        status=rule.status,
        created_at=created,
        updated_at=updated,
    )
|
||
|
||
|
||
@router.post(
    "/platform-rules/parse",
    response_model=PlatformRuleParseResponse,
    status_code=status.HTTP_201_CREATED,
)
async def parse_platform_rule_document(
    request: PlatformRuleParseRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleParseResponse:
    """Parse an uploaded platform-rule document with AI and store it as a draft.

    Flow:
        1. Try to obtain page images (image-based PDF); otherwise extract plain text.
        2. Let the AI turn the images or the text into structured rules.
        3. Store the result in the database with status ``draft``.
        4. Return the parsed rules so the brand can review and confirm them.

    Raises:
        HTTPException: 400 when the document cannot be downloaded or parsed,
            or when the extracted text is empty.
    """
    await _ensure_tenant_exists(x_tenant_id, db)

    # 1. Image-based PDFs go through the vision path; a failed detection is
    #    only logged and the text path is used instead.
    page_images: list[str] = []
    try:
        page_images = await DocumentParser.download_and_get_images(
            request.document_url, request.document_name,
        ) or []
    except Exception as e:
        logger.warning(f"图片 PDF 检测失败,回退文本模式: {e}")

    # 2. AI parsing, in image mode or text mode.
    if page_images:
        parsed_rules = await _ai_parse_platform_rules_vision(
            x_tenant_id, request.platform, page_images, db,
        )
    else:
        try:
            document_text = await DocumentParser.download_and_parse(
                request.document_url, request.document_name,
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            logger.error(f"文档解析失败: {e}")
            raise HTTPException(status_code=400, detail=f"文档下载或解析失败: {e}")

        if not document_text.strip():
            raise HTTPException(status_code=400, detail="文档内容为空,无法解析")

        parsed_rules = await _ai_parse_platform_rules(
            x_tenant_id, request.platform, document_text, db
        )

    # 3. Persist as a draft; the id is "pr-" plus 8 hex digits of a random UUID.
    draft = PlatformRule(
        id=f"pr-{uuid.uuid4().hex[:8]}",
        tenant_id=x_tenant_id,
        brand_id=request.brand_id,
        platform=request.platform,
        document_url=request.document_url,
        document_name=request.document_name,
        parsed_rules=parsed_rules,
        status=RuleStatus.DRAFT.value,
    )
    db.add(draft)
    await db.flush()

    return PlatformRuleParseResponse(
        id=draft.id,
        platform=draft.platform,
        brand_id=draft.brand_id,
        document_url=draft.document_url,
        document_name=draft.document_name,
        parsed_rules=ParsedRulesData(**parsed_rules),
        status=draft.status,
    )
|
||
|
||
|
||
@router.put(
    "/platform-rules/{rule_id}/confirm",
    response_model=PlatformRuleDBResponse,
)
async def confirm_platform_rule(
    rule_id: str,
    request: PlatformRuleConfirmRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleDBResponse:
    """Confirm (and optionally edit) a parsed platform rule.

    The rule is set to ``active`` with the parsed rules from the request, and
    every other active rule for the same (tenant, brand, platform) is set to
    ``inactive``.

    Raises:
        HTTPException: 404 when the tenant has no rule with this id.
    """
    target_query = select(PlatformRule).where(
        and_(
            PlatformRule.id == rule_id,
            PlatformRule.tenant_id == x_tenant_id,
        )
    )
    rule = (await db.execute(target_query)).scalar_one_or_none()
    if not rule:
        raise HTTPException(status_code=404, detail=f"规则不存在: {rule_id}")

    # Deactivate the other active rules of the same (tenant, brand, platform).
    superseded_query = select(PlatformRule).where(
        and_(
            PlatformRule.tenant_id == x_tenant_id,
            PlatformRule.brand_id == rule.brand_id,
            PlatformRule.platform == rule.platform,
            PlatformRule.status == RuleStatus.ACTIVE.value,
            PlatformRule.id != rule_id,
        )
    )
    superseded = (await db.execute(superseded_query)).scalars().all()
    for previous in superseded:
        previous.status = RuleStatus.INACTIVE.value

    # Apply the (possibly edited) rules and activate this rule.
    rule.parsed_rules = request.parsed_rules.model_dump()
    rule.status = RuleStatus.ACTIVE.value
    await db.flush()
    await db.refresh(rule)

    return _format_platform_rule(rule)
|
||
|
||
|
||
@router.get(
    "/platform-rules",
    response_model=PlatformRuleDBListResponse,
)
async def list_brand_platform_rules(
    brand_id: Optional[str] = Query(None),
    platform: Optional[str] = Query(None),
    rule_status: Optional[str] = Query(None, alias="status"),
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleDBListResponse:
    """List the calling tenant's uploaded platform rules, newest first.

    Each of ``brand_id``, ``platform`` and ``status`` narrows the result when
    it is given and non-empty.
    """
    query = select(PlatformRule).where(PlatformRule.tenant_id == x_tenant_id)

    # Apply only the filters the caller actually supplied.
    optional_filters = (
        (brand_id, PlatformRule.brand_id),
        (platform, PlatformRule.platform),
        (rule_status, PlatformRule.status),
    )
    for wanted, column in optional_filters:
        if wanted:
            query = query.where(column == wanted)

    newest_first = query.order_by(PlatformRule.created_at.desc())
    rules = (await db.execute(newest_first)).scalars().all()

    return PlatformRuleDBListResponse(
        items=[_format_platform_rule(row) for row in rules],
        total=len(rules),
    )
|
||
|
||
|
||
@router.delete(
    "/platform-rules/{rule_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_platform_rule(
    rule_id: str,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the calling tenant's uploaded platform rules.

    Raises:
        HTTPException: 404 when the tenant has no rule with this id.
    """
    # The lookup is scoped to the tenant, so other tenants' rules are never matched.
    lookup = select(PlatformRule).where(
        and_(
            PlatformRule.id == rule_id,
            PlatformRule.tenant_id == x_tenant_id,
        )
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail=f"规则不存在: {rule_id}")

    await db.delete(target)
    await db.flush()
|
||
|
||
|
||
async def _ai_parse_platform_rules(
    tenant_id: str,
    platform: str,
    document_text: str,
    db: AsyncSession,
) -> dict:
    """Use the tenant's AI text model to turn document text into structured rules.

    Any failure (no AI client or config, non-JSON answer, any other exception)
    degrades to the empty rule structure so the user can fill it in manually.

    Args:
        tenant_id: Tenant whose AI client and configuration are used.
        platform: Platform name inserted into the prompt.
        document_text: Plain text of the rule document; truncated to 15000 characters.
        db: Async database session passed to the AI service factory.

    Returns:
        A dict with the keys ``forbidden_words``, ``restricted_words``,
        ``duration``, ``content_requirements`` and ``other_rules``. The list
        fields are always lists; ``duration`` may be ``None``.
    """
    try:
        ai_client = await AIServiceFactory.get_client(tenant_id, db)
        if not ai_client:
            logger.warning(f"租户 {tenant_id} 未配置 AI 服务,返回空规则")
            return _empty_parsed_rules()

        config = await AIServiceFactory.get_config(tenant_id, db)
        if not config:
            return _empty_parsed_rules()

        text_model = config.models.get("text", "gpt-4o")

        # Truncate overly long text to stay within the model's token limit.
        max_chars = 15000
        if len(document_text) > max_chars:
            document_text = document_text[:max_chars] + "\n...(文档内容已截断)"

        # The last prompt line names the Chinese curly quotes via \u escapes,
        # matching the vision prompt.
        prompt = f"""你是平台广告合规规则分析专家。请从以下 {platform} 平台规则文档中提取结构化规则。

文档内容:
{document_text}

请以 JSON 格式返回,不要包含其他内容:
{{
"forbidden_words": ["违禁词1", "违禁词2"],
"restricted_words": [{{"word": "xx", "condition": "使用条件", "suggestion": "替换建议"}}],
"duration": {{"min_seconds": 7, "max_seconds": null}},
"content_requirements": ["必须展示产品正面", "需要口播品牌名"],
"other_rules": [{{"rule": "规则名称", "description": "详细说明"}}]
}}

注意:
- forbidden_words: 明确禁止使用的词语
- restricted_words: 有条件限制的词语
- duration: 视频时长要求,如果文档未提及则为 null
- content_requirements: 内容上的硬性要求
- other_rules: 不属于以上分类的其他规则
- 如果某项没有提取到内容,使用空数组或 null
- 重要:JSON 字符串值中不要使用中文引号(\u201c\u201d),使用单引号或直接省略"""

        response = await ai_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            model=text_model,
            temperature=0.2,
            max_tokens=2000,
        )

        # Parse the AI answer.
        content = _extract_json_from_ai_response(response.content)
        parsed = json.loads(content)

        # Fill in missing fields. The prompt allows null for empty items, so an
        # explicit null in a list field is normalised to an empty list as well.
        return {
            "forbidden_words": parsed.get("forbidden_words") or [],
            "restricted_words": parsed.get("restricted_words") or [],
            "duration": parsed.get("duration"),
            "content_requirements": parsed.get("content_requirements") or [],
            "other_rules": parsed.get("other_rules") or [],
        }

    except json.JSONDecodeError as e:
        logger.warning(f"AI 返回内容非 JSON,降级为空规则: {e}")
        return _empty_parsed_rules()
    except Exception as e:
        # Deliberate best-effort: every other failure also degrades to empty rules.
        logger.error(f"AI 解析平台规则失败: {e}")
        return _empty_parsed_rules()
|
||
|
||
|
||
async def _ai_parse_platform_rules_vision(
    tenant_id: str,
    platform: str,
    image_b64_list: list[str],
    db: AsyncSession,
) -> dict:
    """Use the tenant's AI vision model to extract rules from PDF page images.

    Intended for scanned or screenshot-style PDFs. Any failure (no AI client or
    config, non-JSON answer, any other exception) degrades to the empty rule
    structure.

    Args:
        tenant_id: Tenant whose AI client and configuration are used.
        platform: Platform name inserted into the prompt.
        image_b64_list: Base64 page images; each is sent as a PNG data URL.
        db: Async database session passed to the AI service factory.

    Returns:
        A dict with the keys ``forbidden_words``, ``restricted_words``,
        ``duration``, ``content_requirements`` and ``other_rules``. The list
        fields are always lists; ``duration`` may be ``None``.
    """
    try:
        ai_client = await AIServiceFactory.get_client(tenant_id, db)
        if not ai_client:
            logger.warning(f"租户 {tenant_id} 未配置 AI 服务,返回空规则")
            return _empty_parsed_rules()

        config = await AIServiceFactory.get_config(tenant_id, db)
        if not config:
            return _empty_parsed_rules()

        # Prefer the configured vision model, then the text model, then "gpt-4o".
        vision_model = config.models.get("vision", config.models.get("text", "gpt-4o"))

        # Build the multimodal message: the instructions first, then one image per page.
        content: list[dict] = [
            {
                "type": "text",
                "text": f"""你是平台广告合规规则分析专家。以下是 {platform} 平台规则文档的页面截图。
请仔细阅读所有页面,从中提取结构化规则。

请以 JSON 格式返回,不要包含其他内容:
{{
"forbidden_words": ["违禁词1", "违禁词2"],
"restricted_words": [{{"word": "xx", "condition": "使用条件", "suggestion": "替换建议"}}],
"duration": {{"min_seconds": 7, "max_seconds": null}},
"content_requirements": ["必须展示产品正面", "需要口播品牌名"],
"other_rules": [{{"rule": "规则名称", "description": "详细说明"}}]
}}

注意:
- forbidden_words: 明确禁止使用的词语
- restricted_words: 有条件限制的词语
- duration: 视频时长要求,如果文档未提及则为 null
- content_requirements: 内容上的硬性要求
- other_rules: 不属于以上分类的其他规则
- 如果某项没有提取到内容,使用空数组或 null
- 重要:JSON 字符串值中不要使用中文引号(\u201c\u201d),使用单引号或直接省略""",
            }
        ]
        content.extend(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{b64}"},
            }
            for b64 in image_b64_list
        )

        response = await ai_client.chat_completion(
            messages=[{"role": "user", "content": content}],
            model=vision_model,
            temperature=0.2,
            max_tokens=3000,
        )

        # Parse the AI answer.
        resp_content = _extract_json_from_ai_response(response.content)
        parsed = json.loads(resp_content)

        # The prompt allows null for empty items, so an explicit null in a list
        # field is normalised to an empty list instead of being passed through.
        return {
            "forbidden_words": parsed.get("forbidden_words") or [],
            "restricted_words": parsed.get("restricted_words") or [],
            "duration": parsed.get("duration"),
            "content_requirements": parsed.get("content_requirements") or [],
            "other_rules": parsed.get("other_rules") or [],
        }

    except json.JSONDecodeError as e:
        logger.warning(f"AI 视觉解析返回内容非 JSON,降级为空规则: {e}")
        return _empty_parsed_rules()
    except Exception as e:
        # Deliberate best-effort: every other failure also degrades to empty rules.
        logger.error(f"AI 视觉解析平台规则失败: {e}")
        return _empty_parsed_rules()
|
||
|
||
|
||
def _extract_json_from_ai_response(raw: str) -> str:
    """Extract and clean the JSON text contained in an AI response.

    Handles an optional markdown code fence around the JSON (with or without a
    ``json`` tag in any letter case, with or without a newline after the
    opening fence) and then normalises Chinese quotes.

    Args:
        raw: The raw AI response text. ``None`` is treated as an empty string.

    Returns:
        The cleaned text, ready to be handed to ``json.loads``.
    """
    import re

    text = (raw or "").strip()
    # Strip a ```json ... ``` wrapper. The whitespace after the opening fence is
    # optional so single-line fenced answers are unwrapped too.
    fenced = re.search(r'```(?:json)?\s*(.*?)```', text, re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1).strip()
    return _sanitize_json_string(text)
|
||
|
||
|
||
def _sanitize_json_string(text: str) -> str:
|
||
"""
|
||
清理 AI 返回的 JSON 文本中的中文引号等特殊字符。
|
||
中文引号 "" 在 JSON 字符串值内会破坏解析。
|
||
"""
|
||
import re
|
||
result = []
|
||
in_string = False
|
||
i = 0
|
||
while i < len(text):
|
||
ch = text[i]
|
||
if ch == '\\' and in_string and i + 1 < len(text):
|
||
result.append(ch)
|
||
result.append(text[i + 1])
|
||
i += 2
|
||
continue
|
||
if ch == '"' and not in_string:
|
||
in_string = True
|
||
result.append(ch)
|
||
elif ch == '"' and in_string:
|
||
in_string = False
|
||
result.append(ch)
|
||
elif in_string and ch in '\u201c\u201d\u300c\u300d':
|
||
# 中文引号 "" 和「」 → 单引号
|
||
result.append("'")
|
||
elif not in_string and ch in '\u201c\u201d':
|
||
# JSON 结构层的中文引号 → 英文双引号
|
||
result.append('"')
|
||
else:
|
||
result.append(ch)
|
||
i += 1
|
||
return ''.join(result)
|
||
|
||
|
||
def _empty_parsed_rules() -> dict:
|
||
"""返回空的解析规则结构"""
|
||
return {
|
||
"forbidden_words": [],
|
||
"restricted_words": [],
|
||
"duration": None,
|
||
"content_requirements": [],
|
||
"other_rules": [],
|
||
}
|
||
|
||
|
||
# ==================== 辅助函数(供其他模块调用) ====================
|
||
|
||
async def get_whitelist_for_brand(
    tenant_id: str,
    brand_id: str,
    db: AsyncSession,
) -> list[str]:
    """Return the whitelisted terms of one brand within a tenant."""
    query = select(WhitelistItem).where(
        and_(
            WhitelistItem.tenant_id == tenant_id,
            WhitelistItem.brand_id == brand_id,
        )
    )
    rows = (await db.execute(query)).scalars().all()
    return [row.term for row in rows]
|
||
|
||
|
||
async def get_other_brands_whitelist_terms(
    tenant_id: str,
    brand_id: str,
    db: AsyncSession,
) -> list[tuple[str, str]]:
    """Return the whitelist terms of every other brand in the tenant.

    Used for brand-safety checks.

    Returns:
        A list of ``(term, owner_brand_id)`` pairs; entries belonging to
        ``brand_id`` itself are excluded.
    """
    query = select(WhitelistItem).where(
        and_(
            WhitelistItem.tenant_id == tenant_id,
            WhitelistItem.brand_id != brand_id,
        )
    )
    rows = (await db.execute(query)).scalars().all()
    return [(row.term, row.brand_id) for row in rows]
|
||
|
||
|
||
async def get_forbidden_words_for_tenant(
    tenant_id: str,
    db: AsyncSession,
    category: Optional[str] = None,
) -> list[dict]:
    """Return a tenant's forbidden words as plain dicts.

    Args:
        tenant_id: Tenant whose words are listed.
        db: Async database session.
        category: When given and non-empty, only words in this category are returned.

    Returns:
        One dict per word with the keys ``id``, ``word``, ``category`` and
        ``severity``.
    """
    query = select(ForbiddenWord).where(ForbiddenWord.tenant_id == tenant_id)
    if category:
        query = query.where(ForbiddenWord.category == category)

    words = (await db.execute(query)).scalars().all()
    return [
        {
            "id": w.id,
            "word": w.word,
            "category": w.category,
            "severity": w.severity,
        }
        for w in words
    ]
|
||
|
||
|
||
async def get_active_platform_rules(
    tenant_id: str,
    brand_id: str,
    platform: str,
    db: AsyncSession,
) -> Optional[dict]:
    """Return the brand's active parsed rules for a platform.

    Args:
        tenant_id: Tenant that owns the rule.
        brand_id: Brand the rule was uploaded for.
        platform: Platform identifier.
        db: Async database session.

    Returns:
        The ``parsed_rules`` dict of the active rule, or ``None`` when the brand
        has no active rule for this platform.
    """
    query = (
        select(PlatformRule)
        .where(
            and_(
                PlatformRule.tenant_id == tenant_id,
                PlatformRule.brand_id == brand_id,
                PlatformRule.platform == platform,
                PlatformRule.status == RuleStatus.ACTIVE.value,
            )
        )
        # At most one active rule is expected, but nothing visible here
        # enforces it (e.g. two concurrent confirmations). Take the newest row
        # instead of raising MultipleResultsFound when more than one exists.
        .order_by(PlatformRule.created_at.desc())
        .limit(1)
    )
    rule = (await db.execute(query)).scalars().first()
    if not rule:
        return None
    return rule.parsed_rules
|