"""Add the platform_rules table.

Revision ID: 006
Revises: 005
Create Date: 2026-02-10

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = '006'
down_revision: Union[str, None] = '005'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``platform_rules`` table.

    The ``status`` column uses ``server_default`` rather than ``default``:
    a Python-side ``default`` is never rendered into the CREATE TABLE
    statement, so the NOT NULL column would have no database-level default
    and any INSERT that omits ``status`` would fail.
    """
    op.create_table(
        'platform_rules',
        sa.Column('id', sa.String(64), primary_key=True),
        sa.Column('tenant_id', sa.String(64), sa.ForeignKey('tenants.id', ondelete='CASCADE'), nullable=False, index=True),
        sa.Column('brand_id', sa.String(64), nullable=False, index=True),
        sa.Column('platform', sa.String(50), nullable=False, index=True),
        sa.Column('document_url', sa.String(2048), nullable=False),
        sa.Column('document_name', sa.String(512), nullable=False),
        # JSONB on PostgreSQL, generic JSON on every other dialect.
        sa.Column('parsed_rules', sa.JSON().with_variant(postgresql.JSONB, 'postgresql'), nullable=True),
        sa.Column('status', sa.String(20), nullable=False, server_default='draft', index=True),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
    )


def downgrade() -> None:
    """Drop the ``platform_rules`` table."""
    op.drop_table('platform_rules')
# ==================== Brand platform rules (document upload + AI parsing) ====================

def _format_platform_rule(rule: PlatformRule) -> PlatformRuleDBResponse:
    """Convert a ``PlatformRule`` ORM object into its response schema.

    A missing ``parsed_rules`` value is rendered as an empty rule set, and
    missing timestamps are rendered as empty strings.
    """
    return PlatformRuleDBResponse(
        id=rule.id,
        platform=rule.platform,
        brand_id=rule.brand_id,
        document_url=rule.document_url,
        document_name=rule.document_name,
        parsed_rules=ParsedRulesData(**(rule.parsed_rules or {})),
        status=rule.status,
        created_at=rule.created_at.isoformat() if rule.created_at else "",
        updated_at=rule.updated_at.isoformat() if rule.updated_at else "",
    )


@router.post(
    "/platform-rules/parse",
    response_model=PlatformRuleParseResponse,
    status_code=status.HTTP_201_CREATED,
)
async def parse_platform_rule_document(
    request: PlatformRuleParseRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleParseResponse:
    """Download an uploaded document and parse platform rules from it with AI.

    Flow:
    1. Download the document and extract its plain text.
    2. Ask the AI service for structured rules.
    3. Validate the structured rules against ``ParsedRulesData``.
    4. Store the rule in the DB with status ``draft``.
    5. Return the parsed result so the brand can confirm it.

    Raises:
        HTTPException: 400 when the document cannot be downloaded or parsed,
            or when it contains no text.
    """
    await _ensure_tenant_exists(x_tenant_id, db)

    # 1. Download and parse the document.
    try:
        document_text = await DocumentParser.download_and_parse(
            request.document_url, request.document_name,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("文档解析失败: %s", e)
        raise HTTPException(status_code=400, detail=f"文档下载或解析失败: {e}") from e

    if not document_text.strip():
        raise HTTPException(status_code=400, detail="文档内容为空,无法解析")

    # 2. AI parsing.
    parsed_rules = await _ai_parse_platform_rules(x_tenant_id, request.platform, document_text, db)

    # 3. Validate BEFORE persisting. Previously the schema was only built
    # after the row had been added and flushed, so a malformed AI result
    # failed the request after the insert. Fall back to an empty rule set,
    # which the brand can then edit by hand.
    try:
        parsed_model = ParsedRulesData(**parsed_rules)
    except (TypeError, ValueError):
        logger.warning("AI 解析结果不符合规则结构,降级为空规则")
        parsed_model = ParsedRulesData()
    parsed_rules = parsed_model.model_dump()

    # 4. Store in the DB as a draft.
    rule_id = f"pr-{uuid.uuid4().hex[:8]}"
    rule = PlatformRule(
        id=rule_id,
        tenant_id=x_tenant_id,
        brand_id=request.brand_id,
        platform=request.platform,
        document_url=request.document_url,
        document_name=request.document_name,
        parsed_rules=parsed_rules,
        status=RuleStatus.DRAFT.value,
    )
    db.add(rule)
    await db.flush()

    return PlatformRuleParseResponse(
        id=rule.id,
        platform=rule.platform,
        brand_id=rule.brand_id,
        document_url=rule.document_url,
        document_name=rule.document_name,
        parsed_rules=parsed_model,
        status=rule.status,
    )


@router.put(
    "/platform-rules/{rule_id}/confirm",
    response_model=PlatformRuleDBResponse,
)
async def confirm_platform_rule(
    rule_id: str,
    request: PlatformRuleConfirmRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleDBResponse:
    """Confirm (and optionally edit) a parsed platform rule.

    The rule is set to ``active`` and every other ``active`` rule with the
    same (tenant_id, brand_id, platform) is set to ``inactive``.

    NOTE(review): the current status of the rule is not checked, so a rule
    that is not a draft can also be (re)activated here; confirm this is
    intended.

    Raises:
        HTTPException: 404 when the rule does not exist for this tenant.
    """
    result = await db.execute(
        select(PlatformRule).where(
            and_(
                PlatformRule.id == rule_id,
                PlatformRule.tenant_id == x_tenant_id,
            )
        )
    )
    rule = result.scalar_one_or_none()
    if not rule:
        raise HTTPException(status_code=404, detail=f"规则不存在: {rule_id}")

    # Deactivate other active rules with the same (tenant_id, brand_id, platform).
    existing_active = await db.execute(
        select(PlatformRule).where(
            and_(
                PlatformRule.tenant_id == x_tenant_id,
                PlatformRule.brand_id == rule.brand_id,
                PlatformRule.platform == rule.platform,
                PlatformRule.status == RuleStatus.ACTIVE.value,
                PlatformRule.id != rule_id,
            )
        )
    )
    for old_rule in existing_active.scalars().all():
        old_rule.status = RuleStatus.INACTIVE.value

    # Update the rule being confirmed.
    rule.parsed_rules = request.parsed_rules.model_dump()
    rule.status = RuleStatus.ACTIVE.value
    await db.flush()

    return _format_platform_rule(rule)


@router.get(
    "/platform-rules",
    response_model=PlatformRuleDBListResponse,
)
async def list_brand_platform_rules(
    brand_id: Optional[str] = Query(None),
    platform: Optional[str] = Query(None),
    rule_status: Optional[str] = Query(None, alias="status"),
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> PlatformRuleDBListResponse:
    """List the tenant's platform rules, newest first.

    ``brand_id``, ``platform`` and ``status`` are optional equality filters.
    """
    query = select(PlatformRule).where(PlatformRule.tenant_id == x_tenant_id)

    if brand_id:
        query = query.where(PlatformRule.brand_id == brand_id)
    if platform:
        query = query.where(PlatformRule.platform == platform)
    if rule_status:
        query = query.where(PlatformRule.status == rule_status)

    result = await db.execute(query.order_by(PlatformRule.created_at.desc()))
    rules = result.scalars().all()

    return PlatformRuleDBListResponse(
        items=[_format_platform_rule(r) for r in rules],
        total=len(rules),
    )


@router.delete(
    "/platform-rules/{rule_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_platform_rule(
    rule_id: str,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
):
    """Delete a platform rule that belongs to the calling tenant.

    Raises:
        HTTPException: 404 when the rule does not exist for this tenant.
    """
    result = await db.execute(
        select(PlatformRule).where(
            and_(
                PlatformRule.id == rule_id,
                PlatformRule.tenant_id == x_tenant_id,
            )
        )
    )
    rule = result.scalar_one_or_none()
    if not rule:
        raise HTTPException(status_code=404, detail=f"规则不存在: {rule_id}")

    await db.delete(rule)
    await db.flush()
async def _ai_parse_platform_rules(
    tenant_id: str,
    platform: str,
    document_text: str,
    db: AsyncSession,
) -> dict:
    """Parse document text into structured platform rules using AI.

    Any failure (no AI client or config for the tenant, a failed AI call,
    a non-JSON answer) degrades to the empty rule structure so the brand
    can fill the rules in by hand.

    Args:
        tenant_id: Tenant whose AI client and model configuration are used.
        platform: Platform name, interpolated into the prompt.
        document_text: Plain text extracted from the uploaded document.
        db: Database session passed to ``AIServiceFactory``.

    Returns:
        A dict with exactly the keys ``forbidden_words``,
        ``restricted_words``, ``duration``, ``content_requirements`` and
        ``other_rules``. List-valued keys are never ``None``.
    """
    try:
        ai_client = await AIServiceFactory.get_client(tenant_id, db)
        if not ai_client:
            logger.warning(f"租户 {tenant_id} 未配置 AI 服务,返回空规则")
            return _empty_parsed_rules()

        config = await AIServiceFactory.get_config(tenant_id, db)
        if not config:
            return _empty_parsed_rules()

        text_model = config.models.get("text", "gpt-4o")

        # Truncate overly long text so the prompt stays within token limits.
        max_chars = 15000
        if len(document_text) > max_chars:
            document_text = document_text[:max_chars] + "\n...(文档内容已截断)"

        prompt = f"""你是平台广告合规规则分析专家。请从以下 {platform} 平台规则文档中提取结构化规则。

文档内容:
{document_text}

请以 JSON 格式返回,不要包含其他内容:
{{
  "forbidden_words": ["违禁词1", "违禁词2"],
  "restricted_words": [{{"word": "xx", "condition": "使用条件", "suggestion": "替换建议"}}],
  "duration": {{"min_seconds": 7, "max_seconds": null}},
  "content_requirements": ["必须展示产品正面", "需要口播品牌名"],
  "other_rules": [{{"rule": "规则名称", "description": "详细说明"}}]
}}

注意:
- forbidden_words: 明确禁止使用的词语
- restricted_words: 有条件限制的词语
- duration: 视频时长要求,如果文档未提及则为 null
- content_requirements: 内容上的硬性要求
- other_rules: 不属于以上分类的其他规则
- 如果某项没有提取到内容,使用空数组或 null"""

        response = await ai_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            model=text_model,
            temperature=0.2,
            max_tokens=2000,
        )

        # Parse the AI answer, then coerce it into the expected shape.
        parsed = json.loads(_extract_json_text(response.content or ""))
        return _normalize_parsed_rules(parsed)

    except json.JSONDecodeError:
        logger.warning("AI 返回内容非 JSON,降级为空规则")
        return _empty_parsed_rules()
    except Exception as e:
        logger.error(f"AI 解析平台规则失败: {e}")
        return _empty_parsed_rules()


def _extract_json_text(content: str) -> str:
    """Return the JSON object embedded in an AI answer.

    Models often wrap JSON in a Markdown code fence or add a short preamble.
    Slicing from the first ``{`` to the last ``}`` handles both, including a
    fence written on a single line. When no such pair exists the stripped
    text is returned unchanged and ``json.loads`` reports the error.
    """
    text = content.strip()
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        return text[start:end + 1]
    return text


def _normalize_parsed_rules(parsed: object) -> dict:
    """Coerce decoded AI output into the parsed-rules structure.

    The prompt allows ``null`` for missing sections, so a plain
    ``dict.get(key, [])`` would let ``None`` through and later fail schema
    validation. Items of the wrong type are dropped, and blank strings are
    removed because an empty forbidden word matches every position of a
    text.
    """
    if not isinstance(parsed, dict):
        return _empty_parsed_rules()

    def _strings(value: object) -> list:
        if not isinstance(value, list):
            return []
        return [item.strip() for item in value if isinstance(item, str) and item.strip()]

    def _dicts(value: object) -> list:
        if not isinstance(value, list):
            return []
        return [item for item in value if isinstance(item, dict)]

    duration = parsed.get("duration")
    return {
        "forbidden_words": _strings(parsed.get("forbidden_words")),
        "restricted_words": _dicts(parsed.get("restricted_words")),
        "duration": duration if isinstance(duration, dict) else None,
        "content_requirements": _strings(parsed.get("content_requirements")),
        "other_rules": _dicts(parsed.get("other_rules")),
    }


def _empty_parsed_rules() -> dict:
    """Return an empty parsed-rules structure (a fresh dict on every call)."""
    return {
        "forbidden_words": [],
        "restricted_words": [],
        "duration": None,
        "content_requirements": [],
        "other_rules": [],
    }


async def get_active_platform_rules(
    tenant_id: str,
    brand_id: str,
    platform: str,
    db: AsyncSession,
) -> Optional[dict]:
    """Return the brand's active rules for a platform.

    Nothing visible here guarantees at most one active row per
    (tenant, brand, platform), so the most recently updated active row is
    selected instead of calling ``scalar_one_or_none()``, which raises when
    several rows match.

    Returns:
        The ``parsed_rules`` dict, or ``None`` when no active rule exists.
    """
    result = await db.execute(
        select(PlatformRule)
        .where(
            and_(
                PlatformRule.tenant_id == tenant_id,
                PlatformRule.brand_id == brand_id,
                PlatformRule.platform == platform,
                PlatformRule.status == RuleStatus.ACTIVE.value,
            )
        )
        .order_by(PlatformRule.updated_at.desc())
        .limit(1)
    )
    rule = result.scalars().first()
    if not rule:
        return None
    return rule.parsed_rules
平台规则违禁词(优先从 DB 读取,硬编码兜底) + already_checked = set(ABSOLUTE_WORDS + [w["word"] for w in tenant_forbidden_words]) + platform_forbidden_words: list[str] = [] + + # 优先从 DB 获取品牌方上传的 active 平台规则 + db_platform_rules = await get_active_platform_rules( + x_tenant_id, request.brand_id, request.platform.value, db, + ) + if db_platform_rules: + platform_forbidden_words = db_platform_rules.get("forbidden_words", []) + else: + # 兜底:从硬编码 _platform_rules 读取 + platform_rule = _platform_rules.get(request.platform.value) + if platform_rule: + for rule in platform_rule.get("rules", []): + if rule.get("type") == "forbidden_word": + platform_forbidden_words.extend(rule.get("words", [])) + + for word in platform_forbidden_words: + if word in already_checked or word in whitelist: + continue + start = 0 + while True: + pos = content.find(word, start) + if pos == -1: + break + if not _is_ad_context(content, word): + start = pos + 1 + continue + violations.append(Violation( + type=ViolationType.FORBIDDEN_WORD, + content=word, + severity=RiskLevel.MEDIUM, + suggestion=f"违反{request.platform.value}平台规则,建议删除:{word}", + position=Position(start=pos, end=pos + len(word)), + )) + start = pos + 1 + + # 3B. Brief 黑名单词 + if request.blacklist_words: + for item in request.blacklist_words: + word = item.get("word", "") + reason = item.get("reason", "") + if not word or word in whitelist: + continue + start_pos = 0 + while True: + pos = content.find(word, start_pos) + if pos == -1: + break + suggestion = f"Brief 黑名单词:{word}" + if reason: + suggestion += f"({reason})" + violations.append(Violation( + type=ViolationType.FORBIDDEN_WORD, + content=word, + severity=RiskLevel.HIGH, + suggestion=suggestion, + position=Position(start=pos, end=pos + len(word)), + )) + start_pos = pos + 1 + # 4. 检查遗漏卖点 missing_points: list[str] | None = None if request.required_points: missing = _check_selling_point_coverage(content, request.required_points) missing_points = missing if missing else [] - # 5. 
可选:AI 深度分析 - ai_violations = await _ai_deep_analysis(x_tenant_id, content, db) + # 5. 可选:AI 深度分析(返回 violations + warnings) + ai_violations, ai_warnings = await _ai_deep_analysis(x_tenant_id, content, db) if ai_violations: violations.extend(ai_violations) - # 6. 计算分数 - score = 100 - len(violations) * 25 + # 6. 计算分数(按严重程度加权) + score = 100 + for v in violations: + if v.severity == RiskLevel.HIGH: + score -= 25 + elif v.severity == RiskLevel.MEDIUM: + score -= 15 + else: + score -= 5 if missing_points: score -= len(missing_points) * 5 score = max(0, score) @@ -209,6 +281,19 @@ async def review_script( if request.soft_risk_context: soft_warnings = evaluate_soft_risk(request.soft_risk_context) + # 合并 AI 产出的 soft_warnings + if ai_warnings: + soft_warnings.extend(ai_warnings) + + # 遗漏卖点也加入 soft_warnings + if missing_points: + soft_warnings.append(SoftRiskWarning( + code="missing_selling_points", + message=f"遗漏 {len(missing_points)} 个卖点:{', '.join(missing_points)}", + action_required=SoftRiskAction.NOTE, + blocking=False, + )) + return ScriptReviewResponse( score=score, summary=summary, @@ -222,26 +307,27 @@ async def _ai_deep_analysis( tenant_id: str, content: str, db: AsyncSession, -) -> list[Violation]: +) -> tuple[list[Violation], list[SoftRiskWarning]]: """ 使用 AI 进行深度分析 + 返回 (violations, soft_warnings) AI 分析失败时返回空列表,降级到规则检测 """ try: # 获取 AI 客户端 ai_client = await AIServiceFactory.get_client(tenant_id, db) if not ai_client: - return [] + return [], [] # 获取模型配置 config = await AIServiceFactory.get_config(tenant_id, db) if not config: - return [] + return [], [] text_model = config.models.get("text", "gpt-4o") - # 构建分析提示 + # 构建分析提示(两类输出) analysis_prompt = f"""作为广告合规审核专家,请分析以下广告脚本内容,检测潜在的合规风险: 脚本内容: @@ -253,12 +339,17 @@ async def _ai_deep_analysis( 3. 是否存在夸大描述 4. 
是否存在可能违反广告法的其他内容 -如果发现问题,请以 JSON 数组格式返回,每项包含: +请以 JSON 数组返回,每项包含: +- category: "violation"(硬性违规,明确违法/违规)或 "warning"(软性提醒,需人工判断) - type: 违规类型 (forbidden_word/efficacy_claim/brand_safety) -- content: 违规内容 +- content: 问题内容 - severity: 严重程度 (high/medium/low) - suggestion: 修改建议 +分类标准: +- violation: 违禁词、功效宣称、品牌安全等明确违规 +- warning: 夸大描述、易误解表述、潜在风险 + 如果未发现问题,返回空数组 [] 请只返回 JSON 数组,不要包含其他内容。""" @@ -283,7 +374,10 @@ async def _ai_deep_analysis( ai_results = json.loads(response_content) violations = [] + warnings = [] for item in ai_results: + category = item.get("category", "violation") # 默认当硬性违规(安全兜底) + violation_type = item.get("type", "forbidden_word") if violation_type == "forbidden_word": vtype = ViolationType.FORBIDDEN_WORD @@ -300,19 +394,28 @@ async def _ai_deep_analysis( else: slevel = RiskLevel.MEDIUM - violations.append(Violation( - type=vtype, - content=item.get("content", ""), - severity=slevel, - suggestion=item.get("suggestion", "建议修改"), - )) + if category == "warning": + # 软性提醒 → SoftRiskWarning + warnings.append(SoftRiskWarning( + code="ai_warning", + message=f"{item.get('content', '')}: {item.get('suggestion', '建议修改')}", + action_required=SoftRiskAction.NOTE, + blocking=False, + context={"type": violation_type, "severity": severity}, + )) + else: + # 硬性违规 → Violation + violations.append(Violation( + type=vtype, + content=item.get("content", ""), + severity=slevel, + suggestion=item.get("suggestion", "建议修改"), + )) - return violations + return violations, warnings except json.JSONDecodeError: - # JSON 解析失败,返回空列表 - return [] + return [], [] except Exception: - # AI 调用失败,降级到规则检测 - return [] + return [], [] diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 840ac88..6eee7df 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -10,7 +10,7 @@ from app.models.task import Task, TaskStage, TaskStatus from app.models.brief import Brief from app.models.ai_config import AIConfig from app.models.review import 
class RuleStatus(str, enum.Enum):
    """Status of a platform rule."""
    DRAFT = "draft"        # AI parsing finished, awaiting confirmation
    ACTIVE = "active"      # confirmed by the brand, in effect
    INACTIVE = "inactive"  # deactivated


class PlatformRule(Base, TimestampMixin):
    """Platform rules table — a brand-uploaded document plus its AI-parsed rules."""
    __tablename__ = "platform_rules"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    tenant_id: Mapped[str] = mapped_column(
        String(64),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    brand_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    platform: Mapped[str] = mapped_column(String(50), nullable=False, index=True)

    # Document information
    document_url: Mapped[str] = mapped_column(String(2048), nullable=False)
    document_name: Mapped[str] = mapped_column(String(512), nullable=False)

    # AI parsing result (JSON)
    parsed_rules: Mapped[Optional[dict]] = mapped_column(JSONType, nullable=True)

    # Status; holds one of the RuleStatus values as a plain string
    status: Mapped[str] = mapped_column(
        String(20), nullable=False, default=RuleStatus.DRAFT.value, index=True,
    )

    # Relationship back to the owning tenant
    tenant: Mapped["Tenant"] = relationship("Tenant", back_populates="platform_rules")

    def __repr__(self) -> str:
        # An empty repr makes log lines and debugger output useless, so
        # include the identifying fields.
        return f"<PlatformRule(id={self.id}, platform={self.platform}, status={self.status})>"
class DocumentParser:
    """Extract plain text from PDF / Word / Excel / text documents."""

    # Extensions that parse_file dispatches on.
    # NOTE(review): legacy ".doc" / ".xls" are routed to the python-docx and
    # openpyxl parsers exactly as before; those libraries appear to target
    # the newer OOXML formats only, so such files will likely fail to parse.
    # Confirm whether they should be rejected up front.
    _SUPPORTED_EXTENSIONS = frozenset({"pdf", "doc", "docx", "xls", "xlsx", "txt"})

    @staticmethod
    def _extension(file_name: str) -> str:
        """Return the lower-cased extension of *file_name* without the dot ("" if none)."""
        return file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""

    @staticmethod
    async def download_and_parse(document_url: str, document_name: str) -> str:
        """Download a document and parse it into plain text.

        Args:
            document_url: Document URL (TOS).
            document_name: Original file name, used to determine the format.

        Returns:
            The extracted plain text.

        Raises:
            ValueError: If the file extension is not supported. This is
                checked before any network access.
        """
        import asyncio

        # Fail fast: do not download a file we cannot parse anyway. This
        # also guarantees the temp-file suffix below is a known extension.
        ext = DocumentParser._extension(document_name)
        if ext not in DocumentParser._SUPPORTED_EXTENSIONS:
            raise ValueError(f"不支持的文件格式: {ext}")

        # NOTE(review): document_url is fetched as given. If it can be
        # supplied by an untrusted client, restrict it to the storage host
        # to avoid server-side request forgery.
        tmp_path: Optional[str] = None
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                resp = await client.get(document_url)
                resp.raise_for_status()

            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as tmp:
                # Record the path before writing so the file is removed in
                # the finally block even if the write fails.
                tmp_path = tmp.name
                tmp.write(resp.content)

            # Parsing is synchronous and may be slow for large documents;
            # run it in a worker thread so the event loop is not blocked.
            return await asyncio.to_thread(
                DocumentParser.parse_file, tmp_path, document_name,
            )
        finally:
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)

    @staticmethod
    def parse_file(file_path: str, file_name: str) -> str:
        """Pick a parser by file extension and return the plain text.

        Args:
            file_path: Local file path.
            file_name: Original file name.

        Returns:
            The extracted plain text.

        Raises:
            ValueError: If the file extension is not supported.
        """
        ext = DocumentParser._extension(file_name)

        if ext == "pdf":
            return DocumentParser._parse_pdf(file_path)
        elif ext in ("doc", "docx"):
            return DocumentParser._parse_docx(file_path)
        elif ext in ("xls", "xlsx"):
            return DocumentParser._parse_xlsx(file_path)
        elif ext == "txt":
            return DocumentParser._parse_txt(file_path)
        else:
            raise ValueError(f"不支持的文件格式: {ext}")

    @staticmethod
    def _parse_pdf(path: str) -> str:
        """Extract PDF text with pdfplumber, one entry per non-empty page."""
        import pdfplumber

        texts = []
        with pdfplumber.open(path) as pdf:
            for page in pdf.pages:
                text = page.extract_text()
                if text:
                    texts.append(text)
        return "\n".join(texts)

    @staticmethod
    def _parse_docx(path: str) -> str:
        """Extract Word text with python-docx (paragraphs, then table rows)."""
        from docx import Document

        doc = Document(path)
        texts = []
        for para in doc.paragraphs:
            if para.text.strip():
                texts.append(para.text)
        # Also extract table content, one tab-separated line per row.
        for table in doc.tables:
            for row in table.rows:
                row_text = "\t".join(cell.text.strip() for cell in row.cells if cell.text.strip())
                if row_text:
                    texts.append(row_text)
        return "\n".join(texts)

    @staticmethod
    def _parse_xlsx(path: str) -> str:
        """Extract Excel text with openpyxl (all sheets concatenated)."""
        from openpyxl import load_workbook

        wb = load_workbook(path, read_only=True, data_only=True)
        try:
            texts = []
            for sheet in wb.worksheets:
                for row in sheet.iter_rows(values_only=True):
                    row_text = "\t".join(str(cell) for cell in row if cell is not None)
                    if row_text.strip():
                        texts.append(row_text)
            return "\n".join(texts)
        finally:
            # Close the workbook even when reading a row raises.
            wb.close()

    @staticmethod
    def _parse_txt(path: str) -> str:
        """Read a plain-text file.

        UTF-8 (with or without BOM) is tried first, then GB18030, since
        Chinese rule documents are often saved in a GBK-family encoding. As
        a last resort undecodable bytes are replaced rather than failing the
        whole upload.
        """
        for encoding in ("utf-8-sig", "gb18030"):
            try:
                with open(path, "r", encoding=encoding) as f:
                    return f.read()
            except UnicodeDecodeError:
                continue
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
process_video_review( review_id: str, tenant_id: str, @@ -199,6 +220,13 @@ async def process_video_review( # 获取规则 forbidden_words = await get_forbidden_words(db, tenant_id) + # 合并平台规则中的违禁词 + platform_fw = await get_platform_forbidden_words(db, tenant_id, brand_id, platform) + existing_set = set(forbidden_words) + for w in platform_fw: + if w not in existing_set: + forbidden_words.append(w) + existing_set.add(w) competitors = await get_competitors(db, tenant_id, brand_id) # 初始化 AI 服务 @@ -281,16 +309,37 @@ async def process_video_review( ) all_violations.extend(subtitle_violations) - # 6. 计算分数和生成报告 + # 6. 分流 violations / soft_warnings await update_review_progress(db, review_id, 90, "生成报告") - score = review_service.calculate_score(all_violations) - if not all_violations: + hard_violations = [] + soft_warnings_data = [] + + for v in all_violations: + v_type = v.get("type", "") + if v_type in ("forbidden_word", "efficacy_claim", "competitor_logo", "brand_safety"): + hard_violations.append(v) + elif v_type in ("duration_short", "mention_missing"): + soft_warnings_data.append({ + "code": f"video_{v_type}", + "message": v.get("content", ""), + "action_required": "note", + "blocking": False, + "context": {"suggestion": v.get("suggestion", "")}, + }) + else: + hard_violations.append(v) # 默认当硬性违规 + + # 计算分数(仅硬性违规影响分数) + score = review_service.calculate_score(hard_violations) + + if not hard_violations: summary = "视频内容合规,未发现违规项" + if soft_warnings_data: + summary += f"({len(soft_warnings_data)} 条提醒)" else: - high_count = sum(1 for v in all_violations if v.get("risk_level") == "high") - medium_count = sum(1 for v in all_violations if v.get("risk_level") == "medium") - summary = f"发现 {len(all_violations)} 处违规" + high_count = sum(1 for v in hard_violations if v.get("risk_level") == "high") + summary = f"发现 {len(hard_violations)} 处违规" if high_count > 0: summary += f"({high_count} 处高风险)" @@ -300,7 +349,8 @@ async def process_video_review( review_id, score=score, summary=summary, - 
violations=all_violations, + violations=hard_violations, + soft_warnings=soft_warnings_data if soft_warnings_data else None, ) except Exception as e: diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 543b4e0..335e974 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -21,6 +21,9 @@ dependencies = [ "openai>=1.12.0", "cachetools>=5.3.0", "sse-starlette>=2.0.0", + "pdfplumber>=0.10.0", + "python-docx>=1.1.0", + "openpyxl>=3.1.0", ] [project.optional-dependencies]