""" 视频审核模块 提供视频上传验证、ASR/OCR/Logo检测、审核报告生成等功能 验收标准: - 100MB 视频审核 ≤ 5 分钟 - 竞品 Logo F1 ≥ 0.85 - ASR 字错率 ≤ 10% - OCR 准确率 ≥ 95% """ from dataclasses import dataclass, field from typing import Any from datetime import datetime from enum import Enum class ProcessingStatus(str, Enum): """处理状态""" PENDING = "pending" PROCESSING = "processing" COMPLETED = "completed" FAILED = "failed" @dataclass class ValidationResult: """验证结果""" is_valid: bool error_message: str = "" @dataclass class ASRSegment: """ASR 分段结果""" word: str start_ms: int end_ms: int confidence: float @dataclass class ASRResult: """ASR 识别结果""" text: str segments: list[ASRSegment] @dataclass class OCRFrame: """OCR 帧结果""" timestamp_ms: int text: str confidence: float bbox: list[int] @dataclass class OCRResult: """OCR 识别结果""" frames: list[OCRFrame] @dataclass class LogoDetection: """Logo 检测结果""" logo_id: str brand: str confidence: float bbox: list[int] @dataclass class CVResult: """CV 检测结果""" detections: list[dict[str, Any]] @dataclass class ViolationEvidence: """违规证据""" url: str timestamp_start: float timestamp_end: float screenshot_url: str = "" @dataclass class Violation: """违规项""" violation_id: str type: str description: str severity: str evidence: ViolationEvidence @dataclass class BriefComplianceResult: """Brief 合规检查结果""" selling_point_coverage: dict[str, Any] duration_check: dict[str, Any] frequency_check: dict[str, Any] @dataclass class AuditReport: """审核报告""" report_id: str video_id: str processing_status: ProcessingStatus asr_results: dict[str, Any] ocr_results: dict[str, Any] cv_results: dict[str, Any] violations: list[Violation] brief_compliance: BriefComplianceResult | None created_at: datetime = field(default_factory=datetime.now) class VideoFileValidator: """视频文件验证器""" MAX_SIZE_BYTES = 100 * 1024 * 1024 # 100MB SUPPORTED_FORMATS = { "mp4": "video/mp4", "mov": "video/quicktime", } def validate_size(self, file_size_bytes: int) -> ValidationResult: """验证文件大小""" if file_size_bytes <= self.MAX_SIZE_BYTES: return ValidationResult(is_valid=True) return ValidationResult( is_valid=False, error_message=f"文件大小超过限制,最大支持 100MB,当前 {file_size_bytes / (1024*1024):.1f}MB" ) def validate_format(self, file_format: str, mime_type: str) -> ValidationResult: """验证文件格式""" format_lower = file_format.lower() if format_lower in self.SUPPORTED_FORMATS: expected_mime = self.SUPPORTED_FORMATS[format_lower] if mime_type == expected_mime: return ValidationResult(is_valid=True) return ValidationResult( is_valid=False, error_message=f"MIME 类型不匹配,期望 {expected_mime},实际 {mime_type}" ) return ValidationResult( is_valid=False, error_message=f"不支持的文件格式 {file_format},仅支持 MP4/MOV" ) class ASRService: """ASR 语音识别服务""" def transcribe(self, audio_path: str) -> dict[str, Any]: """ 语音转文字 Returns: 包含 text 和 segments 的字典 """ # 实际实现需要调用 ASR API(如阿里云、讯飞等) return { "text": "示例转写文本", "segments": [ { "word": "示例", "start_ms": 0, "end_ms": 500, "confidence": 0.98, }, { "word": "转写", "start_ms": 500, "end_ms": 1000, "confidence": 0.97, }, { "word": "文本", "start_ms": 1000, "end_ms": 1500, "confidence": 0.96, }, ], } def calculate_wer(self, hypothesis: str, reference: str) -> float: """ 计算字错率 (Word Error Rate) Args: hypothesis: 识别结果 reference: 参考文本 Returns: WER 值 (0-1) """ # 简化实现:字符级别计算 if not reference: return 0.0 if not hypothesis else 1.0 h_chars = list(hypothesis) r_chars = list(reference) # 使用编辑距离 m, n = len(r_chars), len(h_chars) dp = [[0] * (n + 1) for _ in range(m + 1)] for i in range(m + 1): dp[i][0] = i for j in range(n + 1): dp[0][j] = j for i in range(1, m + 1): for j in range(1, n + 1): if r_chars[i-1] == h_chars[j-1]: dp[i][j] = dp[i-1][j-1] else: dp[i][j] = min( dp[i-1][j] + 1, # 删除 dp[i][j-1] + 1, # 插入 dp[i-1][j-1] + 1, # 替换 ) return dp[m][n] / m if m > 0 else 0.0 class OCRService: """OCR 字幕识别服务""" def extract_text(self, image_path: str) -> dict[str, Any]: """ 从图片中提取文字 Returns: 包含 frames 的字典 """ # 实际实现需要调用 OCR API(如百度、阿里等) return { "frames": [ { "timestamp_ms": 0, "text": "示例字幕", "confidence": 0.98, "bbox": [100, 450, 300, 480], }, ], } def extract_from_video(self, video_path: str, sample_rate_ms: int = 1000) -> dict[str, Any]: """从视频中提取字幕""" # 实际实现需要视频帧采样 + OCR return { "frames": [], } class LogoDetector: """Logo 检测器""" def __init__(self): self.known_logos: dict[str, dict[str, Any]] = {} def detect(self, image_path: str) -> dict[str, Any]: """ 检测图片中的 Logo Returns: 包含 detections 的字典 """ # 实际实现需要调用 CV 模型 return { "detections": [], } def add_logo(self, logo_path: str, brand: str) -> None: """添加新 Logo 到检测库""" logo_id = f"logo_{len(self.known_logos) + 1}" self.known_logos[logo_id] = { "brand": brand, "path": logo_path, "added_at": datetime.now(), } def detect_in_video(self, video_path: str) -> dict[str, Any]: """在视频中检测 Logo""" # 实际实现需要视频帧采样 + Logo 检测 return { "detections": [], } class BriefComplianceChecker: """Brief 合规检查器""" def check_selling_points( self, video_content: dict[str, Any], selling_points: list[dict[str, Any]] ) -> dict[str, Any]: """检查卖点覆盖""" detected = [] asr_text = video_content.get("asr_text", "") ocr_text = video_content.get("ocr_text", "") combined_text = asr_text + " " + ocr_text for sp in selling_points: sp_text = sp.get("text", "") if sp_text and sp_text in combined_text: detected.append(sp_text) coverage_rate = len(detected) / len(selling_points) if selling_points else 0 return { "coverage_rate": coverage_rate, "detected": detected, "missing": [sp.get("text") for sp in selling_points if sp.get("text") not in detected], } def check_duration( self, cv_detections: list[dict[str, Any]], timing_requirements: list[dict[str, Any]] ) -> dict[str, Any]: """检查时长要求""" results = {} for req in timing_requirements: req_type = req.get("type", "") min_duration = req.get("min_duration_seconds", 0) if req_type == "product_visible": # 计算产品可见总时长 total_duration_ms = 0 for det in cv_detections: if det.get("object_type") == "product": start = det.get("start_ms", 0) end = det.get("end_ms", 0) total_duration_ms += end - start detected_seconds = total_duration_ms / 1000 results["product_visible"] = { "status": "passed" if detected_seconds >= min_duration else "failed", "detected_seconds": detected_seconds, "required_seconds": min_duration, } return results def check_frequency( self, asr_segments: list[dict[str, Any]], timing_requirements: list[dict[str, Any]], brand_keyword: str ) -> dict[str, Any]: """检查频次要求""" results = {} # 统计品牌名出现次数 count = 0 for seg in asr_segments: text = seg.get("text", "") count += text.count(brand_keyword) for req in timing_requirements: req_type = req.get("type", "") min_frequency = req.get("min_frequency", 0) if req_type == "brand_mention": results["brand_mention"] = { "status": "passed" if count >= min_frequency else "failed", "detected_count": count, "required_count": min_frequency, } return results class VideoAuditor: """视频审核器""" def __init__(self): self.asr_service = ASRService() self.ocr_service = OCRService() self.logo_detector = LogoDetector() self.compliance_checker = BriefComplianceChecker() def audit( self, video_path: str, brief_rules: dict[str, Any] | None = None ) -> dict[str, Any]: """ 执行视频审核 Args: video_path: 视频文件路径 brief_rules: Brief 规则(可选) Returns: 审核报告 """ import uuid report_id = f"report_{uuid.uuid4().hex[:8]}" video_id = f"video_{uuid.uuid4().hex[:8]}" # 执行各项检测 asr_results = self.asr_service.transcribe(video_path) ocr_results = self.ocr_service.extract_from_video(video_path) cv_results = self.logo_detector.detect_in_video(video_path) # 收集违规项 violations = [] # Brief 合规检查 brief_compliance = None if brief_rules: video_content = { "asr_text": asr_results.get("text", ""), "ocr_text": " ".join(f.get("text", "") for f in ocr_results.get("frames", [])), } sp_check = self.compliance_checker.check_selling_points( video_content, brief_rules.get("selling_points", []) ) duration_check = self.compliance_checker.check_duration( cv_results.get("detections", []), brief_rules.get("timing_requirements", []) ) frequency_check = self.compliance_checker.check_frequency( asr_results.get("segments", []), brief_rules.get("timing_requirements", []), brief_rules.get("brand_keyword", "品牌") ) brief_compliance = { "selling_point_coverage": sp_check, "duration_check": duration_check, "frequency_check": frequency_check, } return { "report_id": report_id, "video_id": video_id, "processing_status": ProcessingStatus.COMPLETED.value, "asr_results": asr_results, "ocr_results": ocr_results, "cv_results": cv_results, "violations": [ { "violation_id": v.violation_id, "type": v.type, "description": v.description, "severity": v.severity, "evidence": { "url": v.evidence.url, "timestamp_start": v.evidence.timestamp_start, "timestamp_end": v.evidence.timestamp_end, }, } for v in violations ], "brief_compliance": brief_compliance, }