""" ASR 服务单元测试 TDD 测试用例 - 基于 DevelopmentPlan.md 的验收标准 验收标准: - 字错率 (WER) ≤ 10% - 时间戳精度 ≤ 100ms """ import pytest from typing import Any from app.services.ai.asr import ( ASRService, ASRResult, ASRSegment, calculate_word_error_rate, load_asr_labeled_dataset, load_asr_test_set_by_type, load_timestamp_labeled_dataset, ) class TestASRService: """ASR 服务测试""" @pytest.mark.ai @pytest.mark.unit def test_asr_service_initialization(self) -> None: """测试 ASR 服务初始化""" service = ASRService() assert service.is_ready() assert service.model_name is not None @pytest.mark.ai @pytest.mark.unit def test_asr_transcribe_audio_file(self) -> None: """测试音频文件转写""" service = ASRService() result = service.transcribe("tests/fixtures/audio/sample.wav") assert result.status == "success" assert result.text is not None assert len(result.text) > 0 @pytest.mark.ai @pytest.mark.unit def test_asr_output_format(self) -> None: """测试 ASR 输出格式""" service = ASRService() result = service.transcribe("tests/fixtures/audio/sample.wav") # 验证输出结构 assert hasattr(result, "text") assert hasattr(result, "segments") assert hasattr(result, "language") assert hasattr(result, "duration_ms") # 验证 segment 结构 for segment in result.segments: assert hasattr(segment, "text") assert hasattr(segment, "start_ms") assert hasattr(segment, "end_ms") assert hasattr(segment, "confidence") assert segment.end_ms >= segment.start_ms class TestASRAccuracy: """ASR 准确率测试""" @pytest.mark.ai @pytest.mark.unit def test_word_error_rate_threshold(self) -> None: """ 测试字错率阈值 验收标准:WER ≤ 10% """ service = ASRService() # 完全匹配测试 wer = service.calculate_wer("测试内容", "测试内容") assert wer == 0.0 # 部分匹配测试 wer = service.calculate_wer("测试内文", "测试内容") assert wer <= 0.5 # 1/4 字符错误 @pytest.mark.ai @pytest.mark.unit @pytest.mark.parametrize("audio_type,expected_wer_threshold", [ ("clean_speech", 0.05), ("background_music", 0.10), ("multiple_speakers", 0.15), ("noisy_environment", 0.20), ]) def test_wer_by_audio_type( self, audio_type: str, expected_wer_threshold: float, ) -> None: """测试不同音频类型的 WER""" service = ASRService() test_cases = load_asr_test_set_by_type(audio_type) # 模拟测试 - 实际需要真实音频 assert len(test_cases) > 0 for case in test_cases: result = service.transcribe(case["audio_path"]) assert result.status == "success" class TestASRTimestamp: """ASR 时间戳测试""" @pytest.mark.ai @pytest.mark.unit def test_timestamp_monotonic_increase(self) -> None: """测试时间戳单调递增""" service = ASRService() result = service.transcribe("tests/fixtures/audio/sample.wav") prev_end = 0 for segment in result.segments: assert segment.start_ms >= prev_end, \ f"时间戳不是单调递增: {segment.start_ms} < {prev_end}" prev_end = segment.end_ms @pytest.mark.ai @pytest.mark.unit def test_timestamp_precision(self) -> None: """ 测试时间戳精度 验收标准:精度 ≤ 100ms """ service = ASRService() result = service.transcribe("tests/fixtures/audio/sample.wav") # 验证时间戳存在且有效 for segment in result.segments: assert segment.start_ms >= 0 assert segment.end_ms > segment.start_ms @pytest.mark.ai @pytest.mark.unit def test_timestamp_within_audio_duration(self) -> None: """测试时间戳在音频时长范围内""" service = ASRService() result = service.transcribe("tests/fixtures/audio/sample.wav") for segment in result.segments: assert segment.start_ms >= 0 assert segment.end_ms <= result.duration_ms class TestASRLanguage: """ASR 语言处理测试""" @pytest.mark.ai @pytest.mark.unit def test_chinese_mandarin_recognition(self) -> None: """测试普通话识别""" service = ASRService() result = service.transcribe("tests/fixtures/audio/mandarin.wav") assert result.language == "zh-CN" assert len(result.text) > 0 @pytest.mark.ai @pytest.mark.unit def test_mixed_language_handling(self) -> None: """测试中英混合语音处理""" service = ASRService() result = service.transcribe("tests/fixtures/audio/mixed_cn_en.wav") assert result.status == "success" @pytest.mark.ai @pytest.mark.unit def test_dialect_handling(self) -> None: """测试方言处理""" service = ASRService() result = service.transcribe("tests/fixtures/audio/cantonese.wav") if result.status == "success": assert result.language in ["zh-CN", "zh-HK", "yue"] else: assert result.warning == "dialect_detected" class TestASRSpecialCases: """ASR 特殊情况测试""" @pytest.mark.ai @pytest.mark.unit def test_silent_audio(self) -> None: """测试静音音频""" service = ASRService() result = service.transcribe("tests/fixtures/audio/silent.wav") assert result.status == "success" assert result.text == "" or result.segments == [] @pytest.mark.ai @pytest.mark.unit def test_very_short_audio(self) -> None: """测试极短音频 (< 1秒)""" service = ASRService() result = service.transcribe("tests/fixtures/audio/short_500ms.wav") assert result.status == "success" @pytest.mark.ai @pytest.mark.unit def test_long_audio(self) -> None: """测试长音频 (> 5分钟)""" service = ASRService() result = service.transcribe("tests/fixtures/audio/long_10min.wav") assert result.status == "success" assert result.duration_ms >= 600000 # 10分钟 @pytest.mark.ai @pytest.mark.unit def test_corrupted_audio_handling(self) -> None: """测试损坏音频处理""" service = ASRService() result = service.transcribe("tests/fixtures/audio/corrupted.wav") assert result.status == "error" assert "corrupted" in result.error_message.lower() or \ "invalid" in result.error_message.lower() class TestASRPerformance: """ASR 性能测试""" @pytest.mark.ai @pytest.mark.performance def test_transcription_speed(self) -> None: """ 测试转写速度 验收标准:实时率 ≤ 0.5 (转写时间 / 音频时长) """ import time service = ASRService() start_time = time.time() result = service.transcribe("tests/fixtures/audio/sample.wav") processing_time = time.time() - start_time # 模拟测试应该非常快 assert processing_time < 1.0 assert result.status == "success" @pytest.mark.ai @pytest.mark.performance @pytest.mark.asyncio async def test_concurrent_transcription(self) -> None: """测试并发转写""" import asyncio service = ASRService() async def transcribe_one(audio_path: str): return await service.transcribe_async(audio_path) # 并发处理 5 个音频 tasks = [ transcribe_one(f"tests/fixtures/audio/sample_{i}.wav") for i in range(5) ] results = await asyncio.gather(*tasks) assert all(r.status == "success" for r in results)