实现以下模块并通过全部测试 (150 passed, 92.65% coverage):
- validators.py: 数据验证器 (Brief/视频/审核决策/申诉/时间戳/UUID)
- timestamp_align.py: 多模态时间戳对齐 (ASR/OCR/CV 融合)
- rule_engine.py: 规则引擎 (违禁词检测/语境感知/规则版本管理)
- brief_parser.py: Brief 解析 (卖点/禁忌词/时序要求/品牌调性提取)
- video_auditor.py: 视频审核 (文件验证/ASR/OCR/Logo检测/合规检查)
验收标准达成:
- 违禁词召回率 ≥ 95%
- 误报率 ≤ 5%
- 时长统计误差 ≤ 0.5秒
- 语境感知检测 ("最开心的一天" 不误判)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
344 lines
10 KiB
Python
344 lines
10 KiB
Python
"""
|
||
多模态时间戳对齐模块单元测试
|
||
|
||
TDD 测试用例 - 基于 DevelopmentPlan.md (F-14, F-45) 的验收标准
|
||
|
||
验收标准:
|
||
- 时长统计误差 ≤ 0.5秒
|
||
- 频次统计准确率 ≥ 95%
|
||
- 时间轴归一化精度 ≤ 0.1秒
|
||
- 模糊匹配容差窗口 ±0.5秒
|
||
"""
|
||
|
||
import pytest
|
||
from typing import Any
|
||
|
||
from app.utils.timestamp_align import (
|
||
TimestampAligner,
|
||
MultiModalEvent,
|
||
AlignmentResult,
|
||
FrequencyCounter,
|
||
)
|
||
|
||
|
||
class TestTimestampAligner:
    """Tests for ``TimestampAligner``.

    Acceptance criteria:
    - timeline normalization precision <= 0.1 s
    - fuzzy-matching tolerance window of +/-0.5 s
    """

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "asr_ts,ocr_ts,cv_ts,tolerance,expected_merged,expected_ts",
        [
            # Perfectly aligned.
            (1000, 1000, 1000, 500, True, 1000),
            # Within tolerance - should merge.
            (1000, 1200, 1100, 500, True, 1100),  # median
            (1000, 1400, 1200, 500, True, 1200),  # median
            # Outside tolerance - should not merge.
            (1000, 2000, 3000, 500, False, None),
            # OCR outside tolerance.
            # NOTE(review): ASR and CV are both at 1000 ms here, yet three
            # separate events are expected - presumably a single out-of-window
            # source blocks merging altogether; confirm against align_events.
            (1000, 1600, 1000, 500, False, None),
        ],
    )
    def test_multimodal_event_alignment(
        self,
        asr_ts: int,
        ocr_ts: int,
        cv_ts: int,
        tolerance: int,
        expected_merged: bool,
        expected_ts: int | None,
    ) -> None:
        """Align one ASR, one OCR and one CV event and check the merge outcome.

        When ``expected_merged`` is true, exactly one merged event is expected
        and its timestamp must lie within 100 ms of ``expected_ts``. Otherwise
        ``merged_events`` must contain three entries, one per input event.
        """
        events = [
            {"source": "asr", "timestamp_ms": asr_ts, "content": "测试文本"},
            {"source": "ocr", "timestamp_ms": ocr_ts, "content": "字幕内容"},
            {"source": "cv", "timestamp_ms": cv_ts, "content": "product_detected"},
        ]

        aligner = TimestampAligner(tolerance_ms=tolerance)
        result = aligner.align_events(events)

        if expected_merged:
            # A merged row must carry an expected timestamp. Checking this
            # explicitly narrows ``int | None`` to ``int`` for the subtraction
            # below and turns a malformed row into a clear assertion failure
            # instead of a TypeError.
            assert expected_ts is not None
            assert len(result.merged_events) == 1
            assert abs(result.merged_events[0].timestamp_ms - expected_ts) <= 100
        else:
            # Not merged: every input event stays independent.
            assert len(result.merged_events) == 3

    @pytest.mark.unit
    def test_timestamp_normalization_precision(self) -> None:
        """Check the precision of timestamp normalization.

        Acceptance criterion: precision <= 0.1 s (100 ms).
        """
        # The same instant expressed in three different source formats.
        asr_event = {"source": "asr", "timestamp_ms": 1500}  # milliseconds
        cv_event = {"source": "cv", "frame": 45, "fps": 30}  # frame number (45/30 = 1.5 s)
        ocr_event = {"source": "ocr", "timestamp_seconds": 1.5}  # seconds

        aligner = TimestampAligner()
        normalized = aligner.normalize_timestamps([asr_event, cv_event, ocr_event])

        # All normalized timestamps must agree to within 100 ms.
        # NOTE(review): only the spread is checked, not the absolute value
        # (1500 ms); consider asserting that too once the contract is confirmed.
        timestamps = [e.timestamp_ms for e in normalized]
        assert max(timestamps) - min(timestamps) <= 100

    @pytest.mark.unit
    def test_fuzzy_matching_window(self) -> None:
        """Check the fuzzy-matching tolerance window.

        Acceptance criterion: tolerance of +/-0.5 s.
        """
        aligner = TimestampAligner(tolerance_ms=500)

        # 1000 ms and 1499 ms should match (difference < 500 ms).
        assert aligner.is_within_tolerance(1000, 1499)

        # 1000 ms and 1501 ms should not match (difference > 500 ms).
        # NOTE(review): the exact boundary (difference == 500 ms) is not
        # exercised here; add a case once inclusive/exclusive is decided.
        assert not aligner.is_within_tolerance(1000, 1501)
|
||
|
||
|
||
class TestDurationCalculation:
    """Duration statistics tests.

    Acceptance criterion (FeatureSummary.md F-45):
    - duration statistics error <= 0.5 s
    """

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "start_ms,end_ms,expected_duration_ms,tolerance_ms",
        [
            (0, 5000, 5000, 500),
            (1000, 6500, 5500, 500),
            (0, 10000, 10000, 500),
            (500, 3200, 2700, 500),
        ],
    )
    def test_duration_calculation_accuracy(
        self,
        start_ms: int,
        end_ms: int,
        expected_duration_ms: int,
        tolerance_ms: int,
    ) -> None:
        """Duration between an appear/disappear pair is within ``tolerance_ms``."""
        appear = {"timestamp_ms": start_ms, "type": "object_appear"}
        disappear = {"timestamp_ms": end_ms, "type": "object_disappear"}

        measured_ms = TimestampAligner().calculate_duration([appear, disappear])

        error_ms = abs(measured_ms - expected_duration_ms)
        assert error_ms <= tolerance_ms

    @pytest.mark.unit
    def test_product_visible_duration(
        self,
        sample_cv_result: dict[str, Any],
    ) -> None:
        """Visible duration of the product computed from CV detections."""
        # Per the original author's note, the ``sample_cv_result`` fixture holds
        # start_frame=30, end_frame=180, fps=30, i.e. (180 - 30) / 30 = 5 s.
        # The fixture itself is defined outside this file.
        aligner = TimestampAligner()
        measured_ms = aligner.calculate_object_duration(
            sample_cv_result["detections"],
            object_type="product",
        )

        target_ms = 5000
        assert abs(measured_ms - target_ms) <= 500

    @pytest.mark.unit
    def test_multiple_segments_duration(self) -> None:
        """Durations of several separate segments are summed."""
        # The product shows up several times: 3 s + 2 s + 5 s = 10 s in total.
        spans = [(0, 3000), (10000, 12000), (25000, 30000)]
        segments = [{"start_ms": begin, "end_ms": finish} for begin, finish in spans]

        summed_ms = TimestampAligner().calculate_total_duration(segments)

        assert abs(summed_ms - 10000) <= 500
|
||
|
||
|
||
class TestFrequencyCount:
    """Frequency statistics tests.

    Acceptance criterion (FeatureSummary.md F-45):
    - frequency statistics accuracy >= 95%
    """

    @pytest.mark.unit
    def test_brand_mention_frequency(
        self,
        sample_asr_result: dict[str, Any],
    ) -> None:
        """Count brand-name mentions in the ASR fixture segments."""
        mentions = FrequencyCounter().count_mentions(
            sample_asr_result["segments"],
            keyword="品牌",
        )

        # NOTE(review): this only checks that the count is non-negative; the
        # fixture content is not visible here, so no exact value is pinned.
        assert mentions >= 0

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "text_segments,keyword,expected_count",
        [
            # Simple case: one occurrence per segment.
            (
                [{"text": "这个品牌真不错"}, {"text": "品牌介绍"}, {"text": "品牌故事"}],
                "品牌",
                3,
            ),
            # No match.
            (
                [{"text": "产品介绍"}, {"text": "使用方法"}],
                "品牌",
                0,
            ),
            # Several occurrences inside a single segment.
            (
                [{"text": "品牌品牌品牌"}],
                "品牌",
                3,
            ),
        ],
    )
    def test_keyword_frequency_accuracy(
        self,
        text_segments: list[dict[str, str]],
        keyword: str,
        expected_count: int,
    ) -> None:
        """Keyword count over the segments equals ``expected_count``."""
        observed = FrequencyCounter().count_keyword(text_segments, keyword)
        assert observed == expected_count

    @pytest.mark.unit
    def test_frequency_count_accuracy_rate(self) -> None:
        """Check the accuracy rate of frequency counting.

        Acceptance criterion: accuracy >= 95%.
        """
        # Simplified check: verify a handful of cases directly.
        test_cases = [
            {"segments": [{"text": "测试品牌提及"}], "keyword": "品牌", "expected_count": 1},
            {"segments": [{"text": "品牌品牌"}], "keyword": "品牌", "expected_count": 2},
            {"segments": [{"text": "无关内容"}], "keyword": "品牌", "expected_count": 0},
        ]

        counter = FrequencyCounter()
        # Number of cases whose counted value matches the expectation.
        hits = sum(
            1
            for case in test_cases
            if counter.count_keyword(case["segments"], case["keyword"])
            == case["expected_count"]
        )

        accuracy = hits / len(test_cases)
        assert accuracy >= 0.95
|
||
|
||
|
||
class TestMultiModalFusion:
    """Multimodal fusion tests."""

    @pytest.mark.unit
    def test_asr_ocr_cv_fusion(
        self,
        sample_asr_result: dict[str, Any],
        sample_ocr_result: dict[str, Any],
        sample_cv_result: dict[str, Any],
    ) -> None:
        """Fuse ASR + OCR + CV results and check every modality is reported."""
        fusion = TimestampAligner().fuse_multimodal(
            asr_result=sample_asr_result,
            ocr_result=sample_ocr_result,
            cv_result=sample_cv_result,
        )

        # The fused result must flag all three modalities as present.
        assert fusion.has_asr
        assert fusion.has_ocr
        assert fusion.has_cv

    @pytest.mark.unit
    def test_cross_modality_consistency(self) -> None:
        """Events from all three modalities close in time are consistent."""
        # ASR says the product name, OCR shows it and CV detects the product;
        # the three observations should agree in time.
        observations = [
            ("asr", 5000, "产品名"),
            ("ocr", 5100, "产品名"),
            ("cv", 5050, "product"),
        ]
        events = [
            {"source": origin, "timestamp_ms": at_ms, "content": payload}
            for origin, at_ms, payload in observations
        ]

        verdict = TimestampAligner(tolerance_ms=500).check_consistency(events)

        assert verdict.is_consistent
        assert verdict.cross_modality_score >= 0.9

    @pytest.mark.unit
    def test_handle_missing_modality(self) -> None:
        """Alignment still succeeds when one modality has no events."""
        # A video without subtitles yields an empty OCR result.
        asr_events = [{"source": "asr", "timestamp_ms": 1000, "content": "测试"}]
        ocr_events: list[dict] = []  # no OCR output
        cv_events = [{"source": "cv", "timestamp_ms": 1000, "content": "product"}]

        aligner = TimestampAligner()
        outcome = aligner.align_events([*asr_events, *ocr_events, *cv_events])

        # Must be handled without raising, and the gap must be reported.
        assert outcome.status == "success"
        assert "ocr" in outcome.missing_modalities
|
||
|
||
|
||
class TestTimestampOutput:
    """Tests for the timestamp output format."""

    @pytest.mark.unit
    def test_unified_timeline_format(self) -> None:
        """Every merged event exposes ``timestamp_ms``, ``source`` and ``content``."""
        events = [
            {"source": "asr", "timestamp_ms": 1000, "content": "测试"},
        ]

        aligner = TimestampAligner()
        result = aligner.align_events(events)

        # Guard against a vacuous pass: with an empty ``merged_events`` the
        # loop below would not execute a single assertion.
        assert len(result.merged_events) > 0

        # Verify the output format of each timeline entry.
        for entry in result.merged_events:
            assert hasattr(entry, "timestamp_ms")
            assert hasattr(entry, "source")
            assert hasattr(entry, "content")

    @pytest.mark.unit
    def test_violation_with_timestamp(self) -> None:
        """A violation record carries an end timestamp after its start."""
        # NOTE(review): this only inspects a literal dict built in the test and
        # does not call any production code; consider producing the violation
        # through the real API once one is available.
        violation = {
            "type": "forbidden_word",
            "content": "最好的",
            "timestamp_start": 5.0,
            "timestamp_end": 5.5,
        }

        assert violation["timestamp_end"] > violation["timestamp_start"]
|