主要更新: - 更新代理商端文档,明确项目由品牌方分配流程 - 新增Brief配置详情页(已配置)设计稿 - 完善工作台紧急待办中品牌新任务功能 - 整理Pencil设计文件中代理商端页面顺序 - 新增后端FastAPI框架及核心API - 新增前端Next.js页面和组件库 - 添加.gitignore排除构建和缓存文件 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
367 lines
11 KiB
Python
367 lines
11 KiB
Python
"""
|
||
视频审核后台任务
|
||
完整的视频审核流程:下载 → 提取帧 → ASR → 视觉分析 → 生成报告
|
||
"""
|
||
import asyncio
|
||
import os
|
||
from datetime import datetime, timezone
|
||
from typing import Optional
|
||
|
||
from celery import shared_task
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||
from sqlalchemy.orm import sessionmaker
|
||
|
||
from app.config import settings
|
||
from app.models.review import ReviewTask, TaskStatus as DBTaskStatus
|
||
from app.models.rule import ForbiddenWord, Competitor
|
||
from app.models.ai_config import AIConfig
|
||
from app.services.video_download import VideoDownloadService, DownloadResult
|
||
from app.services.keyframe import KeyFrameExtractor, ExtractionResult
|
||
from app.services.asr import VideoASRService, TranscriptionResult
|
||
from app.services.vision import CompetitorLogoDetector, VideoOCRService
|
||
from app.services.video_review import VideoReviewService
|
||
from app.utils.crypto import decrypt_api_key
|
||
|
||
|
||
# 异步数据库引擎
|
||
_async_engine = None
|
||
_async_session_factory = None
|
||
|
||
|
||
def get_async_engine():
|
||
"""获取异步数据库引擎"""
|
||
global _async_engine
|
||
if _async_engine is None:
|
||
_async_engine = create_async_engine(
|
||
settings.DATABASE_URL,
|
||
echo=False,
|
||
pool_size=5,
|
||
max_overflow=10,
|
||
)
|
||
return _async_engine
|
||
|
||
|
||
def get_async_session() -> sessionmaker:
|
||
"""获取异步会话工厂"""
|
||
global _async_session_factory
|
||
if _async_session_factory is None:
|
||
_async_session_factory = sessionmaker(
|
||
get_async_engine(),
|
||
class_=AsyncSession,
|
||
expire_on_commit=False,
|
||
)
|
||
return _async_session_factory
|
||
|
||
|
||
async def update_review_progress(
|
||
db: AsyncSession,
|
||
review_id: str,
|
||
progress: int,
|
||
current_step: str,
|
||
status: Optional[DBTaskStatus] = None,
|
||
):
|
||
"""更新审核进度"""
|
||
result = await db.execute(
|
||
select(ReviewTask).where(ReviewTask.id == review_id)
|
||
)
|
||
task = result.scalar_one_or_none()
|
||
if task:
|
||
task.progress = progress
|
||
task.current_step = current_step
|
||
if status:
|
||
task.status = status
|
||
await db.commit()
|
||
|
||
|
||
async def complete_review(
|
||
db: AsyncSession,
|
||
review_id: str,
|
||
score: int,
|
||
summary: str,
|
||
violations: list[dict],
|
||
status: DBTaskStatus = DBTaskStatus.COMPLETED,
|
||
):
|
||
"""完成审核"""
|
||
result = await db.execute(
|
||
select(ReviewTask).where(ReviewTask.id == review_id)
|
||
)
|
||
task = result.scalar_one_or_none()
|
||
if task:
|
||
task.status = status
|
||
task.progress = 100
|
||
task.current_step = "完成"
|
||
task.score = score
|
||
task.summary = summary
|
||
task.violations = violations
|
||
task.completed_at = datetime.now(timezone.utc)
|
||
await db.commit()
|
||
|
||
|
||
async def fail_review(
|
||
db: AsyncSession,
|
||
review_id: str,
|
||
error: str,
|
||
):
|
||
"""审核失败"""
|
||
result = await db.execute(
|
||
select(ReviewTask).where(ReviewTask.id == review_id)
|
||
)
|
||
task = result.scalar_one_or_none()
|
||
if task:
|
||
task.status = DBTaskStatus.FAILED
|
||
task.current_step = "失败"
|
||
task.summary = f"审核失败: {error}"
|
||
await db.commit()
|
||
|
||
|
||
async def get_ai_config(db: AsyncSession, tenant_id: str) -> Optional[dict]:
|
||
"""获取租户 AI 配置"""
|
||
result = await db.execute(
|
||
select(AIConfig).where(
|
||
AIConfig.tenant_id == tenant_id,
|
||
AIConfig.is_configured == True,
|
||
)
|
||
)
|
||
config = result.scalar_one_or_none()
|
||
if not config:
|
||
return None
|
||
|
||
return {
|
||
"api_key": decrypt_api_key(config.api_key_encrypted),
|
||
"base_url": config.base_url,
|
||
"models": config.models,
|
||
}
|
||
|
||
|
||
async def get_forbidden_words(db: AsyncSession, tenant_id: str) -> list[str]:
|
||
"""获取违禁词列表"""
|
||
result = await db.execute(
|
||
select(ForbiddenWord.word).where(ForbiddenWord.tenant_id == tenant_id)
|
||
)
|
||
return [row[0] for row in result.fetchall()]
|
||
|
||
|
||
async def get_competitors(db: AsyncSession, tenant_id: str, brand_id: str) -> list[str]:
|
||
"""获取竞品列表"""
|
||
result = await db.execute(
|
||
select(Competitor.name).where(
|
||
Competitor.tenant_id == tenant_id,
|
||
Competitor.brand_id == brand_id,
|
||
)
|
||
)
|
||
return [row[0] for row in result.fetchall()]
|
||
|
||
|
||
async def process_video_review(
|
||
review_id: str,
|
||
tenant_id: str,
|
||
video_url: str,
|
||
brand_id: str,
|
||
platform: str,
|
||
):
|
||
"""
|
||
处理视频审核(异步核心逻辑)
|
||
|
||
流程:
|
||
1. 下载视频
|
||
2. 提取关键帧
|
||
3. ASR 语音转写
|
||
4. 视觉分析(竞品 Logo 检测)
|
||
5. OCR 字幕提取
|
||
6. 违规检测
|
||
7. 生成报告
|
||
"""
|
||
session_factory = get_async_session()
|
||
download_service = VideoDownloadService()
|
||
keyframe_extractor = KeyFrameExtractor()
|
||
review_service = VideoReviewService()
|
||
|
||
video_path = None
|
||
frames_dir = None
|
||
logo_detector = None
|
||
ocr_service = None
|
||
asr_service = None
|
||
|
||
async with session_factory() as db:
|
||
try:
|
||
# 更新状态:处理中
|
||
await update_review_progress(
|
||
db, review_id, 5, "开始处理",
|
||
status=DBTaskStatus.PROCESSING,
|
||
)
|
||
|
||
# 获取 AI 配置
|
||
ai_config = await get_ai_config(db, tenant_id)
|
||
if not ai_config:
|
||
await fail_review(db, review_id, "AI 服务未配置")
|
||
return
|
||
|
||
# 获取规则
|
||
forbidden_words = await get_forbidden_words(db, tenant_id)
|
||
competitors = await get_competitors(db, tenant_id, brand_id)
|
||
|
||
# 初始化 AI 服务
|
||
api_key = ai_config["api_key"]
|
||
base_url = ai_config["base_url"]
|
||
models = ai_config["models"]
|
||
|
||
asr_service = VideoASRService(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
model=models.get("audio", "whisper-1"),
|
||
)
|
||
logo_detector = CompetitorLogoDetector(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
model=models.get("vision", "gpt-4o"),
|
||
)
|
||
ocr_service = VideoOCRService(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
model=models.get("vision", "gpt-4o"),
|
||
)
|
||
|
||
# 1. 下载视频
|
||
await update_review_progress(db, review_id, 10, "下载视频")
|
||
download_result: DownloadResult = await download_service.download(video_url)
|
||
if not download_result.success:
|
||
await fail_review(db, review_id, f"视频下载失败: {download_result.error}")
|
||
return
|
||
video_path = download_result.file_path
|
||
|
||
# 2. 提取关键帧
|
||
await update_review_progress(db, review_id, 25, "提取关键帧")
|
||
extraction_result: ExtractionResult = await keyframe_extractor.extract_at_intervals(
|
||
video_path,
|
||
interval_seconds=2.0,
|
||
max_frames=30,
|
||
)
|
||
if not extraction_result.success:
|
||
await fail_review(db, review_id, f"关键帧提取失败: {extraction_result.error}")
|
||
return
|
||
frames_dir = extraction_result.output_dir
|
||
frames = extraction_result.frames
|
||
|
||
all_violations = []
|
||
|
||
# 3. ASR 语音转写
|
||
await update_review_progress(db, review_id, 40, "语音转写")
|
||
transcript_result: TranscriptionResult = await asr_service.transcribe_video(video_path)
|
||
transcript = []
|
||
if transcript_result.success:
|
||
transcript = [
|
||
{"text": seg.text, "start": seg.start, "end": seg.end}
|
||
for seg in transcript_result.segments
|
||
]
|
||
|
||
# 检测口播违禁词
|
||
speech_violations = await review_service.detect_forbidden_words_in_speech(
|
||
transcript,
|
||
forbidden_words,
|
||
context_aware=True,
|
||
)
|
||
all_violations.extend(speech_violations)
|
||
|
||
# 4. 视觉分析 - 竞品 Logo 检测
|
||
await update_review_progress(db, review_id, 60, "检测竞品 Logo")
|
||
if competitors and frames:
|
||
logo_violations = await logo_detector.detect(frames, competitors)
|
||
all_violations.extend(logo_violations)
|
||
|
||
# 5. OCR 字幕提取
|
||
await update_review_progress(db, review_id, 75, "提取字幕")
|
||
if frames:
|
||
subtitles = await ocr_service.extract_subtitles(frames)
|
||
|
||
# 检测字幕违禁词
|
||
subtitle_violations = await review_service.detect_forbidden_words_in_subtitle(
|
||
subtitles,
|
||
forbidden_words,
|
||
)
|
||
all_violations.extend(subtitle_violations)
|
||
|
||
# 6. 计算分数和生成报告
|
||
await update_review_progress(db, review_id, 90, "生成报告")
|
||
score = review_service.calculate_score(all_violations)
|
||
|
||
if not all_violations:
|
||
summary = "视频内容合规,未发现违规项"
|
||
else:
|
||
high_count = sum(1 for v in all_violations if v.get("risk_level") == "high")
|
||
medium_count = sum(1 for v in all_violations if v.get("risk_level") == "medium")
|
||
summary = f"发现 {len(all_violations)} 处违规"
|
||
if high_count > 0:
|
||
summary += f"({high_count} 处高风险)"
|
||
|
||
# 7. 完成审核
|
||
await complete_review(
|
||
db,
|
||
review_id,
|
||
score=score,
|
||
summary=summary,
|
||
violations=all_violations,
|
||
)
|
||
|
||
except Exception as e:
|
||
await fail_review(db, review_id, str(e))
|
||
|
||
finally:
|
||
# 清理资源
|
||
if video_path:
|
||
download_service.cleanup(video_path)
|
||
if frames_dir:
|
||
keyframe_extractor.cleanup(frames_dir)
|
||
if logo_detector:
|
||
await logo_detector.close()
|
||
if ocr_service:
|
||
await ocr_service.close()
|
||
|
||
|
||
@shared_task(
|
||
bind=True,
|
||
name="app.tasks.review.process_video_review_task",
|
||
max_retries=3,
|
||
default_retry_delay=60,
|
||
)
|
||
def process_video_review_task(
|
||
self,
|
||
review_id: str,
|
||
tenant_id: str,
|
||
video_url: str,
|
||
brand_id: str,
|
||
platform: str,
|
||
):
|
||
"""
|
||
视频审核 Celery 任务
|
||
|
||
Args:
|
||
review_id: 审核任务 ID
|
||
tenant_id: 租户 ID
|
||
video_url: 视频 URL
|
||
brand_id: 品牌 ID
|
||
platform: 平台
|
||
"""
|
||
try:
|
||
# 运行异步任务
|
||
asyncio.run(process_video_review(
|
||
review_id=review_id,
|
||
tenant_id=tenant_id,
|
||
video_url=video_url,
|
||
brand_id=brand_id,
|
||
platform=platform,
|
||
))
|
||
except Exception as e:
|
||
# 重试
|
||
raise self.retry(exc=e)
|
||
|
||
|
||
@shared_task(name="app.tasks.review.cleanup_old_files_task")
|
||
def cleanup_old_files_task():
|
||
"""清理过期的临时文件"""
|
||
from app.services.video_download import get_download_service
|
||
|
||
service = get_download_service()
|
||
deleted = service.cleanup_old_files(max_age_seconds=3600)
|
||
return {"deleted_files": deleted}
|