"""
Unit tests for the competitor logo detection service.

TDD test cases based on the acceptance criteria of FeatureSummary.md F-12.

Acceptance criteria:
- F1 >= 0.85 (including scenes with 30% occlusion)
- A newly uploaded logo takes effect immediately
"""

import time
from typing import Any

import pytest

from app.services.ai.logo_detector import (
    LogoDetector,
    LogoDetection,
    LogoDetectionResult,
    load_logo_labeled_dataset,
    calculate_f1_score,
    calculate_precision_recall,
)


class TestLogoDetector:
    """Basic behaviour of the logo detector."""

    @pytest.mark.ai
    @pytest.mark.unit
    def test_logo_detector_initialization(self) -> None:
        """A freshly constructed detector is ready and has logos loaded."""
        detector = LogoDetector()
        assert detector.is_ready()
        assert detector.logo_count > 0

    @pytest.mark.ai
    @pytest.mark.unit
    def test_detect_logo_in_image(self) -> None:
        """Detecting on an image with a competitor logo yields detections."""
        detector = LogoDetector()
        result = detector.detect("tests/fixtures/images/with_competitor_logo.jpg")

        assert result.status == "success"
        assert len(result.detections) > 0

    @pytest.mark.ai
    @pytest.mark.unit
    def test_logo_detection_output_format(self) -> None:
        """Each detection exposes the expected attributes and value ranges."""
        detector = LogoDetector()
        result = detector.detect("tests/fixtures/images/with_competitor_logo.jpg")

        # Verify the structure of the output.
        assert hasattr(result, "detections")
        for detection in result.detections:
            assert hasattr(detection, "logo_id")
            assert hasattr(detection, "brand_name")
            assert hasattr(detection, "confidence")
            assert hasattr(detection, "bbox")
            assert 0 <= detection.confidence <= 1
            assert len(detection.bbox) == 4


class TestLogoDetectionAccuracy:
    """Accuracy of logo detection on the labeled dataset."""

    @pytest.mark.ai
    @pytest.mark.unit
    def test_f1_score_threshold(self) -> None:
        """
        Check the F1 score of logo detection.

        Acceptance criterion: F1 >= 0.85.
        """
        detector = LogoDetector()
        test_set = load_logo_labeled_dataset()

        predictions = []
        ground_truths = []

        # Single pass over the dataset so it is consumed only once.
        for sample in test_set:
            result = detector.detect(sample["image_path"])
            predictions.append(result.detections)
            ground_truths.append(sample["ground_truth_logos"])

        f1 = calculate_f1_score(predictions, ground_truths)
        assert f1 >= 0.85, f"F1 {f1:.2f} 低于阈值 0.85"

    @pytest.mark.ai
    @pytest.mark.unit
    def test_precision_recall(self) -> None:
        """Precision and recall must both reach 0.80."""
        detector = LogoDetector()
        test_set = load_logo_labeled_dataset()

        precision, recall = calculate_precision_recall(detector, test_set)

        assert precision >= 0.80
        assert recall >= 0.80


class TestLogoOcclusion:
    """Logo detection under occlusion."""

    @pytest.mark.ai
    @pytest.mark.unit
    @pytest.mark.parametrize("occlusion_percent,should_detect", [
        (0, True),
        (10, True),
        (20, True),
        (30, True),
        (40, False),
        (50, False),
    ])
    def test_logo_detection_with_occlusion(
        self,
        occlusion_percent: int,
        should_detect: bool,
    ) -> None:
        """
        Check logo detection on occluded fixtures.

        Acceptance criterion: a logo with 30% occlusion is still detected.
        """
        detector = LogoDetector()
        image_path = f"tests/fixtures/images/logo_occluded_{occlusion_percent}pct.jpg"
        result = detector.detect(image_path)

        # NOTE(review): when should_detect is False nothing is asserted; the
        # test only verifies that detect() completes for those fixtures.
        if should_detect:
            assert len(result.detections) > 0, (
                f"{occlusion_percent}% 遮挡应能检测到 Logo"
            )
            assert result.detections[0].confidence >= 0.5

    @pytest.mark.ai
    @pytest.mark.unit
    def test_partial_logo_detection(self) -> None:
        """A detection on the partially visible logo is flagged as partial."""
        detector = LogoDetector()
        result = detector.detect("tests/fixtures/images/logo_partial.jpg")

        # NOTE(review): passes vacuously when no detection is returned.
        if len(result.detections) > 0:
            assert result.detections[0].is_partial


class TestLogoDynamicUpdate:
    """Dynamic updates of the logo library."""

    @pytest.mark.ai
    @pytest.mark.unit
    def test_add_new_logo_instant_effect(self) -> None:
        """
        Check that a newly uploaded logo takes effect immediately.

        Acceptance criterion: a newly added competitor logo is detectable
        right away.
        """
        detector = LogoDetector()

        # Before adding, the brand must not be recognised.
        result_before = detector.detect("tests/fixtures/images/with_new_logo.jpg")
        assert not any(d.brand_name == "NewBrand" for d in result_before.detections)

        # Add the new logo.
        detector.add_logo(
            logo_image="tests/fixtures/logos/new_brand_logo.png",
            brand_name="NewBrand"
        )

        # After adding, the brand must be recognised.
        result_after = detector.detect("tests/fixtures/images/with_new_logo.jpg")
        assert any(d.brand_name == "NewBrand" for d in result_after.detections)

    @pytest.mark.ai
    @pytest.mark.unit
    def test_remove_logo(self) -> None:
        """A removed logo is no longer reported by the detector."""
        detector = LogoDetector()

        # Detectable before removal.
        result_before = detector.detect("tests/fixtures/images/with_existing_logo.jpg")
        assert any(d.brand_name == "ExistingBrand" for d in result_before.detections)

        # Remove the logo.
        detector.remove_logo(brand_name="ExistingBrand")

        # No longer detected after removal.
        result_after = detector.detect("tests/fixtures/images/with_existing_logo.jpg")
        assert not any(d.brand_name == "ExistingBrand" for d in result_after.detections)

    @pytest.mark.ai
    @pytest.mark.unit
    def test_update_logo_variants(self) -> None:
        """After adding a logo variant, the dark-logo fixture yields detections."""
        detector = LogoDetector()

        # Register an additional variant for the brand.
        detector.add_logo_variant(
            brand_name="Brand",
            variant_image="tests/fixtures/logos/brand_variant_dark.png",
            variant_type="dark_mode"
        )

        # The new variant should be detectable.
        result = detector.detect("tests/fixtures/images/with_dark_logo.jpg")
        assert len(result.detections) > 0


class TestLogoVideoProcessing:
    """Logo detection on video frames."""

    @pytest.mark.ai
    @pytest.mark.unit
    def test_detect_logo_in_video_frames(self) -> None:
        """Batch detection returns one result per input frame."""
        detector = LogoDetector()
        frame_paths = [
            f"tests/fixtures/images/video_frame_{i}.jpg"
            for i in range(30)
        ]

        results = detector.batch_detect(frame_paths)

        assert len(results) == 30

    @pytest.mark.ai
    @pytest.mark.unit
    def test_logo_tracking_across_frames(self) -> None:
        """The first detection of every frame carries the same track_id."""
        detector = LogoDetector()

        frame_results = []
        for i in range(10):
            result = detector.detect(f"tests/fixtures/images/tracking_frame_{i}.jpg")
            frame_results.append(result)

        # Tracking should hand out the same track_id; frames without any
        # detection are skipped.
        track_ids = [
            r.detections[0].track_id
            for r in frame_results
            if len(r.detections) > 0
        ]
        assert len(set(track_ids)) == 1  # one and the same logo


class TestLogoSpecialCases:
    """Special cases of logo detection."""

    @pytest.mark.ai
    @pytest.mark.unit
    def test_no_logo_image(self) -> None:
        """An image without logos succeeds with an empty detection list."""
        detector = LogoDetector()
        result = detector.detect("tests/fixtures/images/no_logo.jpg")

        assert result.status == "success"
        assert len(result.detections) == 0

    @pytest.mark.ai
    @pytest.mark.unit
    def test_multiple_logos_detection(self) -> None:
        """Several logos in one image are all reported with distinct IDs."""
        detector = LogoDetector()
        result = detector.detect("tests/fixtures/images/multiple_logos.jpg")

        assert len(result.detections) >= 2

        # Every detection should have a unique ID.
        logo_ids = [d.logo_id for d in result.detections]
        assert len(logo_ids) == len(set(logo_ids))

    @pytest.mark.ai
    @pytest.mark.unit
    def test_similar_logo_distinction(self) -> None:
        """Similar-looking logos are reported under both brand names."""
        detector = LogoDetector()
        result = detector.detect("tests/fixtures/images/similar_logos.jpg")

        brand_names = [d.brand_name for d in result.detections]
        assert "BrandA" in brand_names
        assert "BrandB" in brand_names

    @pytest.mark.ai
    @pytest.mark.unit
    def test_distorted_logo_detection(self) -> None:
        """Stretched, rotated and skewed logo fixtures all yield detections."""
        detector = LogoDetector()

        test_cases = [
            "logo_stretched.jpg",
            "logo_rotated.jpg",
            "logo_skewed.jpg",
        ]

        for image_name in test_cases:
            result = detector.detect(f"tests/fixtures/images/{image_name}")
            assert len(result.detections) > 0, f"变形 Logo {image_name} 应被检测"


class TestLogoPerformance:
    """Performance of logo detection."""

    @pytest.mark.ai
    @pytest.mark.performance
    def test_detection_speed(self) -> None:
        """A single detection completes in under 0.2 seconds."""
        detector = LogoDetector()

        # perf_counter() is monotonic and high-resolution, so the measured
        # duration cannot be skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        result = detector.detect("tests/fixtures/images/1080p_sample.jpg")
        processing_time = time.perf_counter() - start_time

        # A simulated run should be very fast.
        assert processing_time < 0.2
        assert result.status == "success"

    @pytest.mark.ai
    @pytest.mark.performance
    def test_batch_detection_speed(self) -> None:
        """Batch detection of 30 frames completes in under 2 seconds."""
        detector = LogoDetector()
        frame_paths = [
            f"tests/fixtures/images/frame_{i}.jpg"
            for i in range(30)
        ]

        # Monotonic clock for the same reason as in test_detection_speed.
        start_time = time.perf_counter()
        results = detector.batch_detect(frame_paths)
        processing_time = time.perf_counter() - start_time

        assert processing_time < 2.0
        assert len(results) == 30