""" Brief API 项目 Brief 文档的 CRUD + AI 解析 """ import json import logging from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel from typing import Optional from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from sqlalchemy.orm import selectinload from app.database import get_db from app.models.user import User, UserRole from app.models.project import Project from app.models.brief import Brief from app.models.organization import Brand, Agency from app.api.deps import get_current_user from app.schemas.brief import ( BriefCreateRequest, BriefUpdateRequest, AgencyBriefUpdateRequest, BriefResponse, ) from app.services.auth import generate_id logger = logging.getLogger(__name__) router = APIRouter(prefix="/projects/{project_id}/brief", tags=["Brief"]) async def _get_project_with_permission( project_id: str, current_user: User, db: AsyncSession, require_write: bool = False, ) -> Project: """获取项目并检查权限""" result = await db.execute( select(Project) .options(selectinload(Project.brand), selectinload(Project.agencies)) .where(Project.id == project_id) ) project = result.scalar_one_or_none() if not project: raise HTTPException(status_code=404, detail="项目不存在") if current_user.role == UserRole.BRAND: brand_result = await db.execute( select(Brand).where(Brand.user_id == current_user.id) ) brand = brand_result.scalar_one_or_none() if not brand or project.brand_id != brand.id: raise HTTPException(status_code=403, detail="无权访问此项目") elif current_user.role == UserRole.AGENCY: if require_write: raise HTTPException(status_code=403, detail="代理商无权修改 Brief") agency_result = await db.execute( select(Agency).where(Agency.user_id == current_user.id) ) agency = agency_result.scalar_one_or_none() if not agency or agency not in project.agencies: raise HTTPException(status_code=403, detail="无权访问此项目") elif current_user.role == UserRole.CREATOR: # 达人可以查看 Brief(只读) if require_write: raise HTTPException(status_code=403, detail="达人无权修改 Brief") else: raise HTTPException(status_code=403, detail="无权访问") return project def _brief_to_response(brief: Brief) -> BriefResponse: """转换 Brief 为响应""" return BriefResponse( id=brief.id, project_id=brief.project_id, project_name=brief.project.name if brief.project else None, file_url=brief.file_url, file_name=brief.file_name, selling_points=brief.selling_points, min_selling_points=brief.min_selling_points, blacklist_words=brief.blacklist_words, competitors=brief.competitors, brand_tone=brief.brand_tone, min_duration=brief.min_duration, max_duration=brief.max_duration, other_requirements=brief.other_requirements, attachments=brief.attachments, agency_attachments=brief.agency_attachments, created_at=brief.created_at, updated_at=brief.updated_at, ) @router.get("", response_model=BriefResponse) async def get_brief( project_id: str, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """获取项目 Brief""" await _get_project_with_permission(project_id, current_user, db) result = await db.execute( select(Brief) .options(selectinload(Brief.project)) .where(Brief.project_id == project_id) ) brief = result.scalar_one_or_none() if not brief: raise HTTPException(status_code=404, detail="Brief 不存在") return _brief_to_response(brief) @router.post("", response_model=BriefResponse, status_code=status.HTTP_201_CREATED) async def create_brief( project_id: str, request: BriefCreateRequest, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """创建项目 Brief(品牌方操作)""" await _get_project_with_permission(project_id, current_user, db, require_write=True) # 检查是否已存在 existing = await db.execute( select(Brief).where(Brief.project_id == project_id) ) if existing.scalar_one_or_none(): raise HTTPException(status_code=400, detail="该项目已有 Brief,请使用更新接口") brief = Brief( id=generate_id("BF"), project_id=project_id, file_url=request.file_url, file_name=request.file_name, selling_points=request.selling_points, blacklist_words=request.blacklist_words, competitors=request.competitors, brand_tone=request.brand_tone, min_duration=request.min_duration, max_duration=request.max_duration, other_requirements=request.other_requirements, attachments=request.attachments, agency_attachments=request.agency_attachments, ) db.add(brief) await db.flush() # 重新加载 result = await db.execute( select(Brief) .options(selectinload(Brief.project)) .where(Brief.id == brief.id) ) brief = result.scalar_one() return _brief_to_response(brief) @router.put("", response_model=BriefResponse) async def update_brief( project_id: str, request: BriefUpdateRequest, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """更新项目 Brief(品牌方操作)""" await _get_project_with_permission(project_id, current_user, db, require_write=True) result = await db.execute( select(Brief) .options(selectinload(Brief.project)) .where(Brief.project_id == project_id) ) brief = result.scalar_one_or_none() if not brief: raise HTTPException(status_code=404, detail="Brief 不存在") # 更新字段 update_fields = request.model_dump(exclude_unset=True) for field, value in update_fields.items(): setattr(brief, field, value) await db.flush() await db.refresh(brief) return _brief_to_response(brief) @router.patch("/agency-attachments", response_model=BriefResponse) async def update_brief_agency_attachments( project_id: str, request: AgencyBriefUpdateRequest, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """更新 Brief 代理商配置(代理商操作) 代理商可更新:agency_attachments、selling_points、blacklist_words。 不能修改品牌方设置的核心 Brief 内容(文件、时长、竞品等)。 """ # 权限检查:代理商必须属于该项目 result = await db.execute( select(Project) .options(selectinload(Project.brand), selectinload(Project.agencies)) .where(Project.id == project_id) ) project = result.scalar_one_or_none() if not project: raise HTTPException(status_code=404, detail="项目不存在") if current_user.role == UserRole.AGENCY: agency_result = await db.execute( select(Agency).where(Agency.user_id == current_user.id) ) agency = agency_result.scalar_one_or_none() if not agency or agency not in project.agencies: raise HTTPException(status_code=403, detail="无权访问此项目") elif current_user.role == UserRole.BRAND: # 品牌方也可以更新代理商附件 brand_result = await db.execute( select(Brand).where(Brand.user_id == current_user.id) ) brand = brand_result.scalar_one_or_none() if not brand or project.brand_id != brand.id: raise HTTPException(status_code=403, detail="无权访问此项目") else: raise HTTPException(status_code=403, detail="无权修改代理商附件") # 获取 Brief brief_result = await db.execute( select(Brief) .options(selectinload(Brief.project)) .where(Brief.project_id == project_id) ) brief = brief_result.scalar_one_or_none() if not brief: raise HTTPException(status_code=404, detail="Brief 不存在") # 更新代理商可编辑的字段 update_fields = request.model_dump(exclude_unset=True) for field, value in update_fields.items(): setattr(brief, field, value) await db.flush() await db.refresh(brief) return _brief_to_response(brief) # ==================== AI 解析 ==================== class BriefParseResponse(BaseModel): """Brief AI 解析响应""" product_name: str = "" target_audience: str = "" content_requirements: str = "" selling_points: list[dict] = [] blacklist_words: list[dict] = [] @router.post("/parse", response_model=BriefParseResponse) async def parse_brief_with_ai( project_id: str, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """ AI 解析 Brief 文档 从品牌方上传的 Brief 文件中提取结构化信息: - 产品名称 - 目标人群 - 内容要求 - 卖点建议 - 违禁词建议 """ # 权限检查(代理商需要属于该项目) project = await _get_project_with_permission(project_id, current_user, db) # 获取 Brief result = await db.execute( select(Brief) .options(selectinload(Brief.project)) .where(Brief.project_id == project_id) ) brief = result.scalar_one_or_none() if not brief: raise HTTPException(status_code=404, detail="Brief 不存在,请先让品牌方创建 Brief") # 收集所有可解析的文档 URL documents: list[dict] = [] # [{"url": ..., "name": ...}] if brief.file_url and brief.file_name: documents.append({"url": brief.file_url, "name": brief.file_name}) if brief.attachments: for att in brief.attachments: if att.get("url") and att.get("name"): documents.append({"url": att["url"], "name": att["name"]}) if not documents: raise HTTPException(status_code=400, detail="Brief 没有可解析的文件") # 提取文本(每个文档限时 60 秒) import asyncio from app.services.document_parser import DocumentParser all_texts = [] for doc in documents: try: text = await asyncio.wait_for( DocumentParser.download_and_parse(doc["url"], doc["name"]), timeout=60.0, ) if text and text.strip(): all_texts.append(f"=== {doc['name']} ===\n{text}") logger.info(f"成功解析文档 {doc['name']},提取 {len(text)} 字符") except asyncio.TimeoutError: logger.warning(f"解析文档 {doc['name']} 超时(60s),已跳过") except Exception as e: logger.warning(f"解析文档 {doc['name']} 失败: {e}") if not all_texts: raise HTTPException(status_code=400, detail="所有文档均解析失败,无法提取文本内容") combined_text = "\n\n".join(all_texts) # 截断过长文本 max_chars = 15000 if len(combined_text) > max_chars: combined_text = combined_text[:max_chars] + "\n...(内容已截断)" # 获取 AI 客户端 from app.services.ai_service import AIServiceFactory tenant_id = project.brand_id or "default" ai_client = await AIServiceFactory.get_client(tenant_id, db) if not ai_client: raise HTTPException( status_code=400, detail="AI 服务未配置,请在品牌方设置中配置 AI 服务", ) config = await AIServiceFactory.get_config(tenant_id, db) text_model = "gpt-4o" if config and config.models: text_model = config.models.get("text", "gpt-4o") # AI 解析 prompt = f"""你是营销内容合规审核专家。请从以下品牌方 Brief 文档中提取结构化信息。 文档内容: {combined_text} 请以 JSON 格式返回,不要包含其他内容: {{ "product_name": "产品名称", "target_audience": "目标人群描述", "content_requirements": "内容创作要求的简要总结", "selling_points": [ {{"content": "卖点1", "priority": "core"}}, {{"content": "卖点2", "priority": "recommended"}}, {{"content": "卖点3", "priority": "reference"}} ], "blacklist_words": [ {{"word": "违禁词1", "reason": "原因"}}, {{"word": "违禁词2", "reason": "原因"}} ] }} 说明: - product_name: 从文档中识别的产品/品牌名称 - target_audience: 目标消费人群 - content_requirements: 对达人创作内容的要求(时长、风格、场景等) - selling_points: 产品卖点,priority 说明: - "core": 核心卖点,品牌方重点关注,建议优先传达 - "recommended": 推荐卖点,建议提及 - "reference": 参考信息,不要求出现在脚本中 - blacklist_words: 从文档中识别的需要避免的词语(绝对化用语、竞品名、敏感词等)""" last_error = None for attempt in range(2): try: response = await ai_client.chat_completion( messages=[{"role": "user", "content": prompt}], model=text_model, temperature=0.2 if attempt == 0 else 0.1, max_tokens=2000, ) # 提取 JSON logger.info(f"AI 原始响应 (attempt={attempt}): {response.content[:500]}") content = _extract_json_from_response(response.content) logger.info(f"提取的 JSON: {content[:500]}") parsed = json.loads(content) return BriefParseResponse( product_name=parsed.get("product_name", ""), target_audience=parsed.get("target_audience", ""), content_requirements=parsed.get("content_requirements", ""), selling_points=parsed.get("selling_points", []), blacklist_words=parsed.get("blacklist_words", []), ) except json.JSONDecodeError as e: last_error = e logger.warning(f"AI 返回内容非 JSON (attempt={attempt}): {e}, raw={response.content[:300]}") continue except Exception as e: logger.error(f"AI 解析 Brief 失败: {e}") raise HTTPException(status_code=500, detail=f"AI 解析失败: {str(e)[:200]}") # 两次都失败 logger.error(f"AI 解析 Brief JSON 格式错误,两次重试均失败: {last_error}") raise HTTPException(status_code=500, detail="AI 解析结果格式错误,请重试") def _extract_json_from_response(raw: str) -> str: """从 AI 响应中提取 JSON 内容(处理 markdown 代码块、中文引号等)""" import re text = raw.strip() # 移除 markdown ```json ... ``` 代码块包裹 m = re.search(r'```(?:json)?\s*\n(.*?)```', text, re.DOTALL) if m: text = m.group(1).strip() # 尝试找到第一个 { 和最后一个 } first_brace = text.find("{") last_brace = text.rfind("}") if first_brace != -1 and last_brace != -1 and last_brace > first_brace: text = text[first_brace:last_brace + 1] # 清理中文引号等特殊字符 text = _sanitize_json_string(text) return text def _sanitize_json_string(text: str) -> str: """ 清理 AI 返回的 JSON 文本中的中文引号等特殊字符。 中文引号 "" 在 JSON 字符串值内会破坏解析。 """ result = [] in_string = False i = 0 while i < len(text): ch = text[i] if ch == '\\' and in_string and i + 1 < len(text): result.append(ch) result.append(text[i + 1]) i += 2 continue if ch == '"' and not in_string: in_string = True result.append(ch) elif ch == '"' and in_string: in_string = False result.append(ch) elif in_string and ch in '\u201c\u201d\u300c\u300d': # 中文引号 "" 和「」 → 单引号 result.append("'") elif not in_string and ch in '\u201c\u201d': # JSON 结构层的中文引号 → 英文双引号 result.append('"') else: result.append(ch) i += 1 return ''.join(result)