"""Multimodal timestamp alignment module.

Provides timestamp alignment and fusion for ASR / OCR / CV multimodal events.

Acceptance criteria:
- Duration statistics error <= 0.5 s
- Frequency statistics accuracy >= 95%
- Timeline normalization precision <= 0.1 s
- Fuzzy-matching tolerance window +/-0.5 s
"""

from dataclasses import dataclass, field
from typing import Any
from statistics import median


# Modalities a complete alignment is expected to contain; anything absent is
# reported in AlignmentResult.missing_modalities.
EXPECTED_MODALITIES = frozenset({"asr", "ocr", "cv"})


@dataclass
class MultiModalEvent:
    """A single event on the shared timeline."""

    source: str  # producing modality: "asr", "ocr", "cv" ("merged" for fused events)
    timestamp_ms: int  # event time in milliseconds
    content: str
    confidence: float = 1.0
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class AlignmentResult:
    """Result of ``TimestampAligner.align_events``."""

    merged_events: list[MultiModalEvent]
    status: str = "success"
    missing_modalities: list[str] = field(default_factory=list)


@dataclass
class ConsistencyResult:
    """Result of ``TimestampAligner.check_consistency``."""

    is_consistent: bool
    cross_modality_score: float  # in [0.0, 1.0]; 1.0 means identical timestamps


@dataclass
class FusedResult:
    """Result of ``TimestampAligner.fuse_multimodal``."""

    has_asr: bool
    has_ocr: bool
    has_cv: bool
    timeline: list[dict[str, Any]]


class TimestampAligner:
    """Aligns, merges and measures multimodal events on a millisecond timeline."""

    def __init__(self, tolerance_ms: int = 500):
        """Create an aligner.

        Args:
            tolerance_ms: Fuzzy-matching tolerance window in milliseconds.
                Defaults to 500 ms (+/-0.5 s).

        Raises:
            ValueError: If ``tolerance_ms`` is negative.
        """
        if tolerance_ms < 0:
            raise ValueError(f"tolerance_ms must be non-negative, got {tolerance_ms}")
        self.tolerance_ms = tolerance_ms

    def is_within_tolerance(self, ts1: int, ts2: int) -> bool:
        """Return True if the two timestamps differ by at most ``tolerance_ms``."""
        return abs(ts1 - ts2) <= self.tolerance_ms

    @staticmethod
    def _timestamp_ms(event: dict[str, Any]) -> int:
        """Return an event's timestamp in milliseconds, whatever format it uses.

        Precedence: ``timestamp_ms`` (used as-is), then ``timestamp_seconds``,
        then ``frame`` + ``fps``. Events without a usable timestamp map to 0.
        """
        if "timestamp_ms" in event:
            return event["timestamp_ms"]
        if "timestamp_seconds" in event:
            # round() rather than int(): int() truncates float error, e.g.
            # 1.001 * 1000 == 1000.9999999999999 would become 1000.
            return round(event["timestamp_seconds"] * 1000)
        fps = event.get("fps")
        if "frame" in event and fps:
            return round(event["frame"] / fps * 1000)
        # No timestamp information (a zero/missing fps would divide by zero).
        return 0

    def normalize_timestamps(self, events: list[dict[str, Any]]) -> list[MultiModalEvent]:
        """Normalize events with heterogeneous timestamp formats to milliseconds.

        Supported formats (first match wins):
        - ``timestamp_ms``: milliseconds
        - ``timestamp_seconds``: seconds (rounded to the nearest millisecond)
        - ``frame`` + ``fps``: frame index (rounded to the nearest millisecond)

        Missing ``source`` becomes ``"unknown"``, missing ``content`` becomes
        ``""``, missing ``confidence`` becomes 1.0, and any ``metadata`` mapping
        is shallow-copied onto the event.
        """
        return [
            MultiModalEvent(
                source=event.get("source", "unknown"),
                timestamp_ms=self._timestamp_ms(event),
                content=event.get("content", ""),
                confidence=event.get("confidence", 1.0),
                metadata=dict(event.get("metadata") or {}),
            )
            for event in events
        ]

    def align_events(self, events: list[dict[str, Any]]) -> AlignmentResult:
        """Align multimodal events, merging them when they are close in time.

        If there are at least two events and the whole set spans no more than
        ``tolerance_ms``, they are merged into a single ``"merged"`` event at the
        median timestamp with contents joined by ``"; "``. Otherwise each event
        is returned individually, normalized, in input order.

        ``missing_modalities`` lists (sorted) which of asr/ocr/cv did not appear;
        it is left empty when ``events`` is empty.
        """
        if not events:
            return AlignmentResult(merged_events=[], status="success")

        # Normalize first so every timestamp format is compared on the same scale.
        normalized = self.normalize_timestamps(events)
        present = {event.source for event in normalized}
        missing = sorted(EXPECTED_MODALITIES - present)

        timestamps = [event.timestamp_ms for event in normalized]
        if len(timestamps) >= 2 and self.is_within_tolerance(min(timestamps), max(timestamps)):
            merged_event = MultiModalEvent(
                source="merged",
                timestamp_ms=int(median(timestamps)),
                content="; ".join(event.content for event in normalized),
            )
            return AlignmentResult(
                merged_events=[merged_event],
                status="success",
                missing_modalities=missing,
            )

        # Not mergeable: return each event on its own.
        return AlignmentResult(
            merged_events=normalized,
            status="success",
            missing_modalities=missing,
        )

    def calculate_duration(self, events: list[dict[str, Any]]) -> int:
        """Return the time (ms) from ``object_appear`` to ``object_disappear``.

        When several events of the same type occur, the last one of each type
        wins. Returns 0 if either event is missing or if the disappearance
        precedes the appearance.
        """
        appear_ts = None
        disappear_ts = None
        for event in events:
            event_type = event.get("type", "")
            if event_type == "object_appear":
                appear_ts = self._timestamp_ms(event)
            elif event_type == "object_disappear":
                disappear_ts = self._timestamp_ms(event)

        if appear_ts is None or disappear_ts is None:
            return 0
        return max(0, disappear_ts - appear_ts)

    def calculate_object_duration(
        self, detections: list[dict[str, Any]], object_type: str
    ) -> int:
        """Return the summed visible duration (ms) of detections of one object type.

        Args:
            detections: Detection records with ``object_type``, ``start_ms``
                and ``end_ms`` keys (missing bounds default to 0).
            object_type: The object type to total, e.g. ``"product"``.
        """
        return self.calculate_total_duration(
            [d for d in detections if d.get("object_type") == object_type]
        )

    def calculate_total_duration(self, segments: list[dict[str, Any]]) -> int:
        """Return the sum of segment lengths (ms).

        Each segment contributes ``end_ms - start_ms`` (missing keys default to
        0); malformed segments whose end precedes their start contribute 0
        instead of reducing the total. Overlapping segments are not merged.
        """
        return sum(
            max(0, segment.get("end_ms", 0) - segment.get("start_ms", 0))
            for segment in segments
        )

    def fuse_multimodal(
        self,
        asr_result: dict[str, Any],
        ocr_result: dict[str, Any],
        cv_result: dict[str, Any],
    ) -> FusedResult:
        """Fuse per-modality results.

        Currently only reports which modality results are non-empty; the
        returned ``timeline`` is always empty.
        """
        return FusedResult(
            has_asr=bool(asr_result),
            has_ocr=bool(ocr_result),
            has_cv=bool(cv_result),
            timeline=[],
        )

    def check_consistency(
        self, events: list[dict[str, Any]]
    ) -> ConsistencyResult:
        """Check cross-modality timestamp consistency.

        Events are consistent when their timestamps span at most
        ``tolerance_ms``. The score falls linearly from 1.0 (identical
        timestamps) to 0.0 at a span of ``2 * tolerance_ms`` or more. Fewer
        than two events are trivially consistent.
        """
        if len(events) < 2:
            return ConsistencyResult(is_consistent=True, cross_modality_score=1.0)

        timestamps = [self._timestamp_ms(event) for event in events]
        max_diff = max(timestamps) - min(timestamps)

        window = self.tolerance_ms * 2
        if window > 0:
            score = 1.0 - max_diff / window
        else:
            # Zero tolerance: only an exact match is consistent (avoids 0/0).
            score = 1.0 if max_diff == 0 else 0.0

        return ConsistencyResult(
            is_consistent=max_diff <= self.tolerance_ms,
            cross_modality_score=max(0.0, min(1.0, score)),
        )


class FrequencyCounter:
    """Counts keyword frequencies across text segments."""

    def count_mentions(
        self, segments: list[dict[str, Any]], keyword: str
    ) -> int:
        """Count non-overlapping occurrences of ``keyword`` across all segments.

        Segments are read via their ``text`` key; missing or ``None`` text
        counts as empty. An empty keyword returns 0 (``str.count("")`` would
        otherwise count every character boundary).
        """
        if not keyword:
            return 0
        return sum((segment.get("text") or "").count(keyword) for segment in segments)

    def count_keyword(
        self, segments: list[dict[str, str]], keyword: str
    ) -> int:
        """Count keyword frequency; alias of ``count_mentions``."""
        return self.count_mentions(segments, keyword)