kol-insight/backend/tests/test_brand_api.py
zfc 8fbcb72a3f feat(core): 完成 Phase 2 核心功能开发
- 实现查询API (query.py): 支持star_id/unique_id/nickname三种查询方式
- 实现计算模块 (calculator.py): CPM/自然搜索UV/搜索成本计算
- 实现品牌API集成 (brand_api.py): 批量并发调用,10并发限制
- 实现导出服务 (export_service.py): Excel/CSV导出
- 前端组件: QueryForm/ResultTable/ExportButton
- 主页面集成: 支持6种页面状态
- 测试: 44个测试全部通过,覆盖率88%

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 14:38:38 +08:00

118 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pytest
import asyncio
from unittest.mock import AsyncMock, patch
import httpx
from app.services.brand_api import get_brand_names, fetch_brand_name
class TestBrandAPI:
    """Tests for Brand API integration (``app.services.brand_api``).

    ``get_brand_names`` tests patch ``fetch_brand_name`` so only the batch
    logic (ID filtering, de-duplication, result assembly) is exercised.
    The ``fetch_brand_name`` tests and the concurrency test patch
    ``httpx.AsyncClient`` instead, so the real lookup code runs against a
    fake HTTP client.

    NOTE(review): the coroutine tests carry no ``pytest.mark.asyncio``
    marker; presumably pytest-asyncio runs in auto mode -- confirm in the
    project's pytest configuration.
    """

    @staticmethod
    def _patch_fetch(names):
        """Patch ``fetch_brand_name`` with a fake that answers from *names*.

        The fake returns ``(brand_id, names.get(brand_id, brand_id))`` -- the
        ``(id, name)`` tuple shape of the real function, with the ID itself
        as the fallback name. Answering per argument (instead of from a fixed
        ``side_effect`` list) keeps the tests independent of the order in
        which ``get_brand_names`` schedules lookups, and makes them fail if
        the wrong ID is forwarded.

        Returns the ``patch`` object; enter it as a context manager to get
        the ``AsyncMock`` for call assertions.
        """

        def fake_fetch(brand_id, *args, **kwargs):
            return brand_id, names.get(brand_id, brand_id)

        return patch(
            "app.services.brand_api.fetch_brand_name",
            new_callable=AsyncMock,
            side_effect=fake_fetch,
        )

    @staticmethod
    def _mock_client(**get_config):
        """Build an ``AsyncMock`` that stands in for ``httpx.AsyncClient()``.

        ``get_config`` (e.g. ``return_value=`` or ``side_effect=``) is applied
        to the client's ``get`` coroutine mock.
        """
        client = AsyncMock()
        client.get.configure_mock(**get_config)
        # `async with httpx.AsyncClient() as c` must bind this same mock.
        client.__aenter__.return_value = client
        # A falsy __aexit__ result lets exceptions raised inside the
        # `async with` body propagate instead of being suppressed.
        client.__aexit__.return_value = None
        return client

    async def test_get_brand_names_success(self):
        """Every requested ID is mapped to the name its lookup returned."""
        names = {"brand_001": "品牌A", "brand_002": "品牌B"}
        with self._patch_fetch(names) as mock_fetch:
            result = await get_brand_names(["brand_001", "brand_002"])
        assert result["brand_001"] == "品牌A"
        assert result["brand_002"] == "品牌B"
        assert mock_fetch.call_count == 2

    async def test_get_brand_names_empty_list(self):
        """An empty ID list yields an empty mapping without any lookup."""
        with self._patch_fetch({}) as mock_fetch:
            result = await get_brand_names([])
        assert result == {}
        mock_fetch.assert_not_called()

    async def test_get_brand_names_with_none_values(self):
        """None and empty-string IDs are filtered out before any lookup."""
        with self._patch_fetch({"brand_001": "品牌A"}) as mock_fetch:
            result = await get_brand_names(["brand_001", None, ""])
        assert result == {"brand_001": "品牌A"}
        # Only the real ID may reach fetch_brand_name.
        mock_fetch.assert_called_once()
        assert mock_fetch.call_args.args[0] == "brand_001"

    async def test_get_brand_names_deduplication(self):
        """Duplicate brand IDs are collapsed into a single lookup."""
        with self._patch_fetch({"brand_001": "品牌A"}) as mock_fetch:
            result = await get_brand_names(["brand_001", "brand_001", "brand_001"])
        # Should only call once due to deduplication.
        assert mock_fetch.call_count == 1
        assert result == {"brand_001": "品牌A"}

    async def test_get_brand_names_partial_failure(self):
        """One lookup falling back to its ID does not break the whole batch.

        ``brand_002`` has no name in the fake, so it resolves to its own ID,
        mimicking fetch_brand_name's error fallback.
        """
        names = {"brand_001": "品牌A", "brand_003": "品牌C"}
        with self._patch_fetch(names):
            result = await get_brand_names(["brand_001", "brand_002", "brand_003"])
        assert result == {
            "brand_001": "品牌A",
            "brand_002": "brand_002",  # Fallback to ID
            "brand_003": "品牌C",
        }

    async def test_fetch_brand_name_success(self):
        """Successful single-brand lookup, exercised via get_brand_names.

        fetch_brand_name itself is patched here; the flow is tested from the
        higher-level entry point.
        """
        with self._patch_fetch({"test_id": "测试品牌"}) as mock_fetch:
            result = await get_brand_names(["test_id"])
        assert result == {"test_id": "测试品牌"}
        mock_fetch.assert_called_once()

    async def test_fetch_brand_name_failure(self):
        """A timeout during the request makes the lookup fall back to the ID."""
        mock_client = self._mock_client(
            side_effect=httpx.TimeoutException("Timeout")
        )
        with patch("httpx.AsyncClient", return_value=mock_client):
            semaphore = asyncio.Semaphore(10)
            brand_id, brand_name = await fetch_brand_name("test_id", semaphore)
        assert brand_id == "test_id"
        assert brand_name == "test_id"  # Fallback to ID

    async def test_fetch_brand_name_404(self):
        """An HTTP 404 makes the lookup fall back to the ID.

        A real ``httpx.Response`` is used because its methods
        (``raise_for_status``, ``json``) are synchronous; an ``AsyncMock``
        response would turn them into never-awaited coroutines and could let
        a status-handling bug go unnoticed.
        """
        response = httpx.Response(
            404,
            request=httpx.Request("GET", "https://brand-api.invalid/nonexistent"),
        )
        mock_client = self._mock_client(return_value=response)
        with patch("httpx.AsyncClient", return_value=mock_client):
            semaphore = asyncio.Semaphore(10)
            brand_id, brand_name = await fetch_brand_name("nonexistent", semaphore)
        assert brand_id == "nonexistent"
        assert brand_name == "nonexistent"

    async def test_concurrency_limit(self):
        """No more than 10 HTTP requests are ever in flight at once.

        fetch_brand_name is NOT patched here: the limit is enforced by the
        semaphore it receives, so only ``httpx.AsyncClient`` is faked. The
        fake ``get`` records how many calls overlap and yields to the event
        loop, so unthrottled requests would visibly pile up.
        """
        max_concurrency = 10
        # IDs distinct from the other tests so no shared state can interfere.
        brand_ids = [f"concurrency_{i:03d}" for i in range(15)]
        in_flight = 0
        peak = 0

        async def fake_get(*args, **kwargs):
            nonlocal in_flight, peak
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                # Suspend so other lookups get the chance to start meanwhile.
                await asyncio.sleep(0.01)
            finally:
                in_flight -= 1
            # 404 -> fetch_brand_name falls back to the ID (see the 404 test).
            return httpx.Response(
                404, request=httpx.Request("GET", "https://brand-api.invalid/brand")
            )

        mock_client = self._mock_client(side_effect=fake_get)
        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await get_brand_names(brand_ids)

        # All lookups completed ...
        assert result == {bid: bid for bid in brand_ids}
        # ... concurrently, but never beyond the limit.
        assert 1 < peak <= max_concurrency