"""
Video review background tasks.

Full review pipeline: download -> extract key frames -> ASR -> visual
analysis -> report generation.
"""

import asyncio
import inspect
import logging
import os
from datetime import datetime, timezone
from typing import Optional

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from app.config import settings
from app.models.review import ReviewTask, TaskStatus as DBTaskStatus
from app.models.rule import ForbiddenWord, Competitor
from app.models.ai_config import AIConfig
from app.services.video_download import VideoDownloadService, DownloadResult
from app.services.keyframe import KeyFrameExtractor, ExtractionResult
from app.services.asr import VideoASRService, TranscriptionResult
from app.services.vision import CompetitorLogoDetector, VideoOCRService
from app.services.video_review import VideoReviewService
from app.utils.crypto import decrypt_api_key

logger = logging.getLogger(__name__)

# Lazily created async database engine and session factory (module singletons).
_async_engine = None
_async_session_factory = None


def get_async_engine():
    """Return the module-wide async database engine, creating it on first use."""
    global _async_engine
    if _async_engine is None:
        _async_engine = create_async_engine(
            settings.DATABASE_URL,
            echo=False,
            pool_size=5,
            max_overflow=10,
        )
    return _async_engine


def get_async_session() -> sessionmaker:
    """Return the module-wide ``AsyncSession`` factory, creating it on first use."""
    global _async_session_factory
    if _async_session_factory is None:
        _async_session_factory = sessionmaker(
            get_async_engine(),
            class_=AsyncSession,
            expire_on_commit=False,
        )
    return _async_session_factory


async def _get_review_task(db: AsyncSession, review_id: str) -> Optional[ReviewTask]:
    """Load the ``ReviewTask`` row with the given id, or ``None`` if absent."""
    result = await db.execute(
        select(ReviewTask).where(ReviewTask.id == review_id)
    )
    return result.scalar_one_or_none()


async def update_review_progress(
    db: AsyncSession,
    review_id: str,
    progress: int,
    current_step: str,
    status: Optional[DBTaskStatus] = None,
):
    """
    Update the progress of a review task and commit.

    Args:
        db: Open async session.
        review_id: Review task ID.
        progress: Progress value to store.
        current_step: Label of the step being executed.
        status: New task status; left untouched when falsy.

    Does nothing when no task with ``review_id`` exists.
    """
    task = await _get_review_task(db, review_id)
    if task:
        task.progress = progress
        task.current_step = current_step
        if status:
            task.status = status
        await db.commit()


async def complete_review(
    db: AsyncSession,
    review_id: str,
    score: int,
    summary: str,
    violations: list[dict],
    status: DBTaskStatus = DBTaskStatus.COMPLETED,
):
    """
    Store the final result of a review and commit.

    Sets progress to 100, records score/summary/violations and stamps
    ``completed_at`` with the current UTC time. Does nothing when no task
    with ``review_id`` exists.
    """
    task = await _get_review_task(db, review_id)
    if task:
        task.status = status
        task.progress = 100
        task.current_step = "完成"
        task.score = score
        task.summary = summary
        task.violations = violations
        task.completed_at = datetime.now(timezone.utc)
        await db.commit()


async def fail_review(
    db: AsyncSession,
    review_id: str,
    error: str,
):
    """
    Mark a review task as failed, storing ``error`` in its summary, and commit.

    Does nothing when no task with ``review_id`` exists.
    """
    task = await _get_review_task(db, review_id)
    if task:
        task.status = DBTaskStatus.FAILED
        task.current_step = "失败"
        task.summary = f"审核失败: {error}"
        await db.commit()


async def get_ai_config(db: AsyncSession, tenant_id: str) -> Optional[dict]:
    """
    Fetch the tenant's AI configuration.

    Returns:
        A dict with ``api_key`` (decrypted), ``base_url`` and ``models``,
        or ``None`` when the tenant has no configured ``AIConfig`` row.
    """
    result = await db.execute(
        select(AIConfig).where(
            AIConfig.tenant_id == tenant_id,
            AIConfig.is_configured == True,  # noqa: E712 - SQL expression, not a Python bool test
        )
    )
    config = result.scalar_one_or_none()
    if not config:
        return None

    return {
        "api_key": decrypt_api_key(config.api_key_encrypted),
        "base_url": config.base_url,
        "models": config.models,
    }


async def get_forbidden_words(db: AsyncSession, tenant_id: str) -> list[str]:
    """Return the forbidden words registered for the tenant."""
    result = await db.execute(
        select(ForbiddenWord.word).where(ForbiddenWord.tenant_id == tenant_id)
    )
    return list(result.scalars())


async def get_competitors(db: AsyncSession, tenant_id: str, brand_id: str) -> list[str]:
    """Return the competitor names registered for the tenant's brand."""
    result = await db.execute(
        select(Competitor.name).where(
            Competitor.tenant_id == tenant_id,
            Competitor.brand_id == brand_id,
        )
    )
    return list(result.scalars())


def _build_summary(violations: list[dict]) -> str:
    """Build the human-readable report summary from the collected violations."""
    if not violations:
        return "视频内容合规,未发现违规项"

    summary = f"发现 {len(violations)} 处违规"
    high_count = sum(1 for v in violations if v.get("risk_level") == "high")
    if high_count > 0:
        summary += f"({high_count} 处高风险)"
    return summary


async def _best_effort(action: str, func, *args) -> None:
    """Run one cleanup callable (awaiting its result if needed); log instead of raising."""
    try:
        outcome = func(*args)
        if inspect.isawaitable(outcome):
            await outcome
    except Exception:
        logger.warning("Cleanup step failed: %s", action, exc_info=True)


async def process_video_review(
    review_id: str,
    tenant_id: str,
    video_url: str,
    brand_id: str,
    platform: str,
):
    """
    Run the video review pipeline (async core logic).

    Steps:
    1. Download the video
    2. Extract key frames
    3. ASR speech transcription
    4. Visual analysis (competitor logo detection)
    5. OCR subtitle extraction
    6. Violation detection
    7. Report generation

    Args:
        review_id: Review task ID.
        tenant_id: Tenant ID used to load AI config and rules.
        video_url: URL of the video to review.
        brand_id: Brand ID used to load the competitor list.
        platform: Platform name (currently unused by the pipeline).

    Any exception raised by a pipeline step is caught, logged and recorded
    on the task via ``fail_review``; it is not re-raised. An exception only
    escapes when that failure bookkeeping itself fails.
    """
    session_factory = get_async_session()
    download_service = VideoDownloadService()
    keyframe_extractor = KeyFrameExtractor()
    review_service = VideoReviewService()

    video_path = None
    frames_dir = None
    logo_detector = None
    ocr_service = None
    asr_service = None

    async with session_factory() as db:
        try:
            # Mark the task as processing.
            await update_review_progress(
                db, review_id, 5, "开始处理",
                status=DBTaskStatus.PROCESSING,
            )

            # Load the tenant AI configuration.
            ai_config = await get_ai_config(db, tenant_id)
            if not ai_config:
                await fail_review(db, review_id, "AI 服务未配置")
                return

            # Load the review rules.
            forbidden_words = await get_forbidden_words(db, tenant_id)
            competitors = await get_competitors(db, tenant_id, brand_id)

            # Initialise the AI services.
            api_key = ai_config["api_key"]
            base_url = ai_config["base_url"]
            # A stored ``models`` of None must fall back to the defaults
            # below instead of raising AttributeError on ``.get``.
            models = ai_config["models"] or {}

            asr_service = VideoASRService(
                api_key=api_key,
                base_url=base_url,
                model=models.get("audio", "whisper-1"),
            )
            logo_detector = CompetitorLogoDetector(
                api_key=api_key,
                base_url=base_url,
                model=models.get("vision", "gpt-4o"),
            )
            ocr_service = VideoOCRService(
                api_key=api_key,
                base_url=base_url,
                model=models.get("vision", "gpt-4o"),
            )

            # 1. Download the video.
            await update_review_progress(db, review_id, 10, "下载视频")
            download_result: DownloadResult = await download_service.download(video_url)
            if not download_result.success:
                await fail_review(db, review_id, f"视频下载失败: {download_result.error}")
                return
            video_path = download_result.file_path

            # 2. Extract key frames.
            await update_review_progress(db, review_id, 25, "提取关键帧")
            extraction_result: ExtractionResult = await keyframe_extractor.extract_at_intervals(
                video_path,
                interval_seconds=2.0,
                max_frames=30,
            )
            if not extraction_result.success:
                await fail_review(db, review_id, f"关键帧提取失败: {extraction_result.error}")
                return
            frames_dir = extraction_result.output_dir
            frames = extraction_result.frames

            all_violations = []

            # 3. ASR speech transcription.
            await update_review_progress(db, review_id, 40, "语音转写")
            transcript_result: TranscriptionResult = await asr_service.transcribe_video(video_path)

            transcript = []
            if transcript_result.success:
                transcript = [
                    {"text": seg.text, "start": seg.start, "end": seg.end}
                    for seg in transcript_result.segments
                ]

            # Detect forbidden words in the spoken content.
            speech_violations = await review_service.detect_forbidden_words_in_speech(
                transcript,
                forbidden_words,
                context_aware=True,
            )
            all_violations.extend(speech_violations)

            # 4. Visual analysis - competitor logo detection.
            await update_review_progress(db, review_id, 60, "检测竞品 Logo")
            if competitors and frames:
                logo_violations = await logo_detector.detect(frames, competitors)
                all_violations.extend(logo_violations)

            # 5. OCR subtitle extraction.
            await update_review_progress(db, review_id, 75, "提取字幕")
            if frames:
                subtitles = await ocr_service.extract_subtitles(frames)

                # Detect forbidden words in the subtitles.
                subtitle_violations = await review_service.detect_forbidden_words_in_subtitle(
                    subtitles,
                    forbidden_words,
                )
                all_violations.extend(subtitle_violations)

            # 6. Compute the score and build the report.
            await update_review_progress(db, review_id, 90, "生成报告")
            score = review_service.calculate_score(all_violations)
            summary = _build_summary(all_violations)

            # 7. Finish the review.
            await complete_review(
                db,
                review_id,
                score=score,
                summary=summary,
                violations=all_violations,
            )

        except Exception as e:
            # Keep the traceback in the logs; only str(e) is stored on the task.
            logger.exception("Video review %s failed", review_id)
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; every helper commits on its own, so nothing
            # pending is lost here.
            await db.rollback()
            await fail_review(db, review_id, str(e))

        finally:
            # Release resources. Each step is best-effort so that one
            # failing cleanup cannot skip the remaining ones.
            if video_path:
                await _best_effort(
                    "remove downloaded video", download_service.cleanup, video_path
                )
            if frames_dir:
                await _best_effort(
                    "remove extracted frames", keyframe_extractor.cleanup, frames_dir
                )
            # NOTE(review): the ASR service was never closed before; it is
            # closed here only if it exposes ``close()`` - confirm its API.
            for label, service in (
                ("close ASR service", asr_service),
                ("close logo detector", logo_detector),
                ("close OCR service", ocr_service),
            ):
                close = getattr(service, "close", None)
                if callable(close):
                    await _best_effort(label, close)


async def _run_review_in_fresh_loop(**kwargs) -> None:
    """Run one review, then dispose the engine pool before the event loop closes."""
    try:
        await process_video_review(**kwargs)
    finally:
        # Every asyncio.run() call uses a new event loop, while pooled
        # connections stay bound to the loop that opened them. Disposing
        # the pool here keeps the next task from reusing connections that
        # belong to an already closed loop.
        if _async_engine is not None:
            await _async_engine.dispose()


@shared_task(
    bind=True,
    name="app.tasks.review.process_video_review_task",
    max_retries=3,
    default_retry_delay=60,
)
def process_video_review_task(
    self,
    review_id: str,
    tenant_id: str,
    video_url: str,
    brand_id: str,
    platform: str,
):
    """
    Celery task wrapping the async video review.

    Args:
        review_id: Review task ID
        tenant_id: Tenant ID
        video_url: Video URL
        brand_id: Brand ID
        platform: Platform

    Retries (up to ``max_retries``) when the async run raises.
    """
    try:
        # Drive the async pipeline from the synchronous Celery worker.
        asyncio.run(_run_review_in_fresh_loop(
            review_id=review_id,
            tenant_id=tenant_id,
            video_url=video_url,
            brand_id=brand_id,
            platform=platform,
        ))
    except Exception as e:
        # Hand the failure to Celery's retry machinery.
        raise self.retry(exc=e)


@shared_task(name="app.tasks.review.cleanup_old_files_task")
def cleanup_old_files_task():
    """Delete expired temporary files and report how many were removed."""
    # Imported lazily, as in the original module.
    from app.services.video_download import get_download_service

    service = get_download_service()
    deleted = service.cleanup_old_files(max_age_seconds=3600)
    return {"deleted_files": deleted}