videos1.0/backend/tests/unit/test_timestamp_alignment.py
Your Name e77af7f8f0 feat: 实现 TDD 绿色阶段核心模块
实现以下模块并通过全部测试 (150 passed, 92.65% coverage):

- validators.py: 数据验证器 (Brief/视频/审核决策/申诉/时间戳/UUID)
- timestamp_align.py: 多模态时间戳对齐 (ASR/OCR/CV 融合)
- rule_engine.py: 规则引擎 (违禁词检测/语境感知/规则版本管理)
- brief_parser.py: Brief 解析 (卖点/禁忌词/时序要求/品牌调性提取)
- video_auditor.py: 视频审核 (文件验证/ASR/OCR/Logo检测/合规检查)

验收标准达成:
- 违禁词召回率 ≥ 95%
- 误报率 ≤ 5%
- 时长统计误差 ≤ 0.5秒
- 语境感知检测 ("最开心的一天" 不误判)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 17:41:37 +08:00

344 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
多模态时间戳对齐模块单元测试
TDD 测试用例 - 基于 DevelopmentPlan.md (F-14, F-45) 的验收标准
验收标准:
- 时长统计误差 ≤ 0.5秒
- 频次统计准确率 ≥ 95%
- 时间轴归一化精度 ≤ 0.1秒
- 模糊匹配容差窗口 ±0.5秒
"""
import pytest
from typing import Any
from app.utils.timestamp_align import (
TimestampAligner,
MultiModalEvent,
AlignmentResult,
FrequencyCounter,
)
class TestTimestampAligner:
    """
    Unit tests for TimestampAligner.

    Acceptance criteria:
    - timeline normalization precision <= 0.1 s
    - fuzzy-match tolerance window of +/-0.5 s
    """

    @pytest.mark.unit
    @pytest.mark.parametrize("asr_ts,ocr_ts,cv_ts,tolerance,expected_merged,expected_ts", [
        # All three modalities fire at exactly the same moment
        (1000, 1000, 1000, 500, True, 1000),
        # Spread inside the tolerance -> merged; expected timestamp is the median
        (1000, 1200, 1100, 500, True, 1100),
        (1000, 1400, 1200, 500, True, 1200),
        # Spread beyond the tolerance -> not merged
        (1000, 2000, 3000, 500, False, None),
        # Only OCR is out of range. NOTE(review): ASR and CV coincide here, yet the
        # test expects three separate events (no partial ASR/CV merge) — confirm
        # this all-or-nothing behavior is the intended aligner contract.
        (1000, 1600, 1000, 500, False, None),
    ])
    def test_multimodal_event_alignment(
        self,
        asr_ts: int,
        ocr_ts: int,
        cv_ts: int,
        tolerance: int,
        expected_merged: bool,
        expected_ts: int | None,
    ) -> None:
        """ASR/OCR/CV events collapse into one entry only when they all fall inside the tolerance."""
        payloads = (
            ("asr", asr_ts, "测试文本"),
            ("ocr", ocr_ts, "字幕内容"),
            ("cv", cv_ts, "product_detected"),
        )
        events = [
            {"source": src, "timestamp_ms": ts, "content": text}
            for src, ts, text in payloads
        ]

        merged = TimestampAligner(tolerance_ms=tolerance).align_events(events).merged_events

        if not expected_merged:
            # Nothing merged: every input event survives on its own
            assert len(merged) == 3
            return
        assert len(merged) == 1
        assert abs(merged[0].timestamp_ms - expected_ts) <= 100

    @pytest.mark.unit
    def test_timestamp_normalization_precision(self) -> None:
        """
        Timestamps expressed in different units normalize onto one millisecond timeline.

        Acceptance: precision <= 0.1 s (100 ms).
        """
        raw_events = [
            {"source": "asr", "timestamp_ms": 1500},     # already in milliseconds
            {"source": "cv", "frame": 45, "fps": 30},     # frame index: 45 / 30 = 1.5 s
            {"source": "ocr", "timestamp_seconds": 1.5},  # in seconds
        ]
        normalized = TimestampAligner().normalize_timestamps(raw_events)
        stamps = [event.timestamp_ms for event in normalized]
        # The spread across all normalized values must stay within 100 ms
        assert max(stamps) - min(stamps) <= 100

    @pytest.mark.unit
    def test_fuzzy_matching_window(self) -> None:
        """
        The fuzzy-match window is +/-500 ms.
        """
        aligner = TimestampAligner(tolerance_ms=500)
        base = 1000
        # 499 ms apart (< 500 ms) -> inside the window
        assert aligner.is_within_tolerance(base, base + 499)
        # 501 ms apart (> 500 ms) -> outside the window
        assert not aligner.is_within_tolerance(base, base + 501)
class TestDurationCalculation:
    """
    Duration statistics tests.

    Acceptance criteria (FeatureSummary.md F-45):
    - duration error <= 0.5 s
    """

    @pytest.mark.unit
    @pytest.mark.parametrize("start_ms,end_ms,expected_duration_ms,tolerance_ms", [
        (0, 5000, 5000, 500),
        (1000, 6500, 5500, 500),
        (0, 10000, 10000, 500),
        (500, 3200, 2700, 500),
    ])
    def test_duration_calculation_accuracy(
        self,
        start_ms: int,
        end_ms: int,
        expected_duration_ms: int,
        tolerance_ms: int,
    ) -> None:
        """Duration between an appear/disappear pair is within 0.5 s of the truth."""
        appear = {"timestamp_ms": start_ms, "type": "object_appear"}
        disappear = {"timestamp_ms": end_ms, "type": "object_disappear"}
        measured = TimestampAligner().calculate_duration([appear, disappear])
        error = abs(measured - expected_duration_ms)
        assert error <= tolerance_ms

    @pytest.mark.unit
    def test_product_visible_duration(
        self,
        sample_cv_result: dict[str, Any],
    ) -> None:
        """Visible duration of the "product" object computed from CV detections."""
        # Per the original author's note, the fixture spans start_frame=30 to
        # end_frame=180 at fps=30, so the expected visible time is
        # (180 - 30) / 30 = 5 s. The fixture lives outside this file — TODO
        # confirm these values against conftest.
        expected_duration_ms = 5000
        aligner = TimestampAligner()
        visible_ms = aligner.calculate_object_duration(
            sample_cv_result["detections"],
            object_type="product",
        )
        assert abs(visible_ms - expected_duration_ms) <= 500

    @pytest.mark.unit
    def test_multiple_segments_duration(self) -> None:
        """Durations of several disjoint appearance segments add up."""
        # The product appears three times: 3 s + 2 s + 5 s = 10 s in total
        spans = [(0, 3000), (10000, 12000), (25000, 30000)]
        segments = [{"start_ms": lo, "end_ms": hi} for lo, hi in spans]
        total = TimestampAligner().calculate_total_duration(segments)
        assert abs(total - 10000) <= 500
class TestFrequencyCount:
    """
    Frequency counting tests.

    Acceptance criteria (FeatureSummary.md F-45):
    - frequency counting accuracy >= 95%
    """

    @pytest.mark.unit
    def test_brand_mention_frequency(
        self,
        sample_asr_result: dict[str, Any],
    ) -> None:
        """Brand-name mention count over the ASR fixture is consistent with its text.

        The previous ``count >= 0`` check could never fail. The exact semantics
        of ``count_mentions`` (substring occurrences vs. matching segments) are
        not visible from this file, so the checks below are chosen to hold under
        either interpretation. The count is zero exactly when the keyword never
        appears, and it never exceeds the raw number of substring occurrences.
        """
        keyword = "品牌"
        segments = sample_asr_result["segments"]
        counter = FrequencyCounter()
        count = counter.count_mentions(segments, keyword=keyword)

        # NOTE(review): assumes each ASR segment stores its transcript under
        # "text" (the same shape count_keyword consumes below) — confirm
        # against the sample_asr_result fixture in conftest.
        occurrences = sum(seg.get("text", "").count(keyword) for seg in segments)
        assert 0 <= count <= occurrences
        assert (count > 0) == (occurrences > 0)

    @pytest.mark.unit
    @pytest.mark.parametrize("text_segments,keyword,expected_count", [
        # Simple case: one hit per segment
        (
            [{"text": "这个品牌真不错"}, {"text": "品牌介绍"}, {"text": "品牌故事"}],
            "品牌",
            3
        ),
        # No match at all
        (
            [{"text": "产品介绍"}, {"text": "使用方法"}],
            "品牌",
            0
        ),
        # Several hits inside a single segment
        (
            [{"text": "品牌品牌品牌"}],
            "品牌",
            3
        ),
    ])
    def test_keyword_frequency_accuracy(
        self,
        text_segments: list[dict[str, str]],
        keyword: str,
        expected_count: int,
    ) -> None:
        """count_keyword returns the exact number of keyword occurrences."""
        counter = FrequencyCounter()
        count = counter.count_keyword(text_segments, keyword)
        assert count == expected_count

    @pytest.mark.unit
    def test_frequency_count_accuracy_rate(self) -> None:
        """
        Frequency-count accuracy must reach >= 95%.

        With only three cases, a 95% threshold effectively requires all of them
        to pass. Mismatching cases are listed in the assertion message so a
        failure is diagnosable without re-running under a debugger.
        """
        test_cases = [
            {"segments": [{"text": "测试品牌提及"}], "keyword": "品牌", "expected_count": 1},
            {"segments": [{"text": "品牌品牌"}], "keyword": "品牌", "expected_count": 2},
            {"segments": [{"text": "无关内容"}], "keyword": "品牌", "expected_count": 0},
        ]
        counter = FrequencyCounter()
        mismatches = []
        for case in test_cases:
            actual = counter.count_keyword(case["segments"], case["keyword"])
            if actual != case["expected_count"]:
                mismatches.append(
                    (case["segments"], case["keyword"], case["expected_count"], actual)
                )
        accuracy = (len(test_cases) - len(mismatches)) / len(test_cases)
        assert accuracy >= 0.95, (
            f"accuracy {accuracy:.2%} < 95%; "
            f"mismatches (segments, keyword, expected, actual): {mismatches}"
        )
class TestMultiModalFusion:
    """
    Multi-modal (ASR + OCR + CV) fusion tests.
    """

    @pytest.mark.unit
    def test_asr_ocr_cv_fusion(
        self,
        sample_asr_result: dict[str, Any],
        sample_ocr_result: dict[str, Any],
        sample_cv_result: dict[str, Any],
    ) -> None:
        """Fusing all three modality results yields a result that reports each of them."""
        fused = TimestampAligner().fuse_multimodal(
            asr_result=sample_asr_result,
            ocr_result=sample_ocr_result,
            cv_result=sample_cv_result,
        )
        # Every modality must be represented in the fused result
        for flag in ("has_asr", "has_ocr", "has_cv"):
            assert getattr(fused, flag)

    @pytest.mark.unit
    def test_cross_modality_consistency(self) -> None:
        """Temporally co-located events from different modalities are judged consistent."""
        # ASR says "产品名", OCR shows "产品名", and CV detects the product,
        # all within about 100 ms of each other, so they should agree in time.
        events = [
            {"source": "asr", "timestamp_ms": 5000, "content": "产品名"},
            {"source": "ocr", "timestamp_ms": 5100, "content": "产品名"},
            {"source": "cv", "timestamp_ms": 5050, "content": "product"},
        ]
        report = TimestampAligner(tolerance_ms=500).check_consistency(events)
        assert report.is_consistent
        assert report.cross_modality_score >= 0.9

    @pytest.mark.unit
    def test_handle_missing_modality(self) -> None:
        """An absent modality is tolerated and reported as missing."""
        asr_events = [{"source": "asr", "timestamp_ms": 1000, "content": "测试"}]
        ocr_events: list[dict] = []  # video without subtitles: OCR yields nothing
        cv_events = [{"source": "cv", "timestamp_ms": 1000, "content": "product"}]
        combined = [*asr_events, *ocr_events, *cv_events]
        result = TimestampAligner().align_events(combined)
        # Alignment must still succeed, with OCR flagged as missing
        assert result.status == "success"
        assert "ocr" in result.missing_modalities
class TestTimestampOutput:
    """
    Output-format tests for aligned timestamps.
    """

    @pytest.mark.unit
    def test_unified_timeline_format(self) -> None:
        """Every entry on the unified timeline exposes timestamp_ms, source and content."""
        events = [
            {"source": "asr", "timestamp_ms": 1000, "content": "测试"},
        ]
        aligner = TimestampAligner()
        result = aligner.align_events(events)
        # Guard against a vacuous pass: without this, the attribute checks below
        # would silently succeed if the aligner returned no entries at all.
        assert len(result.merged_events) >= 1
        for entry in result.merged_events:
            assert hasattr(entry, "timestamp_ms")
            assert hasattr(entry, "source")
            assert hasattr(entry, "content")

    @pytest.mark.unit
    def test_violation_with_timestamp(self) -> None:
        """A violation record carries a start/end time span with end after start."""
        # NOTE(review): this only inspects a literal built inside the test and
        # never exercises production code, so it cannot fail. TODO: obtain the
        # violation from the component that actually emits it.
        violation = {
            "type": "forbidden_word",
            "content": "最好的",
            "timestamp_start": 5.0,
            "timestamp_end": 5.5,
        }
        assert violation["timestamp_end"] > violation["timestamp_start"]