from unittest.mock import AsyncMock, patch

import pytest
from httpx import ASGITransport, AsyncClient

from app.database import get_db
from app.main import app
from app.models import KolVideo


class TestQueryAPI:
    """Tests for the ``POST /api/v1/query`` endpoint."""

    @pytest.fixture
    async def client(self, override_get_db):
        """Yield an async HTTP client wired directly to the ASGI app.

        ``override_get_db`` is requested only for its side effect; it is
        presumably the fixture that installs the test-database dependency
        override (defined outside this file -- confirm in conftest).
        """
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac

    @pytest.fixture
    async def seed_data(self, test_session, sample_video_data):
        """Insert three ``KolVideo`` rows and return them.

        Each row is a shallow copy of ``sample_video_data`` with
        ``item_id``/``star_id``/``star_unique_id`` suffixed ``_000`` .. ``_002``
        and a nickname of the search term followed by the row index.
        """
        videos = []
        for i in range(3):
            data = sample_video_data.copy()
            data["item_id"] = f"item_{i:03d}"
            data["star_id"] = f"star_{i:03d}"
            data["star_unique_id"] = f"unique_{i:03d}"
            data["star_nickname"] = f"测试达人{i}"
            videos.append(KolVideo(**data))
        test_session.add_all(videos)
        await test_session.commit()
        return videos

    @patch("app.api.v1.query.get_brand_names", new_callable=AsyncMock)
    async def test_query_by_star_id_success(
        self, mock_brand, client, test_session, seed_data
    ):
        """Querying by star_id returns exactly the requested stars."""
        # Brand lookup is an external dependency; stub it with an empty mapping.
        mock_brand.return_value = {}
        response = await client.post(
            "/api/v1/query",
            json={"type": "star_id", "values": ["star_000", "star_001"]},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["total"] == 2

    @patch("app.api.v1.query.get_brand_names", new_callable=AsyncMock)
    async def test_query_by_unique_id_success(
        self, mock_brand, client, test_session, seed_data
    ):
        """Querying by unique_id returns the single matching record."""
        mock_brand.return_value = {}
        response = await client.post(
            "/api/v1/query",
            json={"type": "unique_id", "values": ["unique_000"]},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["total"] == 1

    @patch("app.api.v1.query.get_brand_names", new_callable=AsyncMock)
    async def test_query_by_nickname_like(
        self, mock_brand, client, test_session, seed_data
    ):
        """Querying by nickname uses fuzzy (substring) matching."""
        mock_brand.return_value = {}
        response = await client.post(
            "/api/v1/query",
            json={"type": "nickname", "values": ["测试达人"]},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        # All three seeded nicknames contain the search term.
        assert data["total"] == 3

    async def test_query_empty_values(self, client):
        """An empty ``values`` list is rejected by request validation."""
        response = await client.post(
            "/api/v1/query",
            json={"type": "star_id", "values": []},
        )
        assert response.status_code == 422  # Validation error

    async def test_query_invalid_type(self, client):
        """An unknown query ``type`` is rejected by request validation."""
        response = await client.post(
            "/api/v1/query",
            json={"type": "invalid_type", "values": ["test"]},
        )
        assert response.status_code == 422

    @patch("app.api.v1.query.get_brand_names", new_callable=AsyncMock)
    async def test_query_no_results(self, mock_brand, client, test_session, seed_data):
        """A query that matches nothing succeeds with an empty result set."""
        mock_brand.return_value = {}
        response = await client.post(
            "/api/v1/query",
            json={"type": "star_id", "values": ["nonexistent_id"]},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["total"] == 0
        assert data["data"] == []

    @patch("app.api.v1.query.get_brand_names", new_callable=AsyncMock)
    async def test_query_limit_enforcement(self, mock_brand, client, test_session):
        """Smoke test for the query endpoint without seeded data.

        NOTE: this does not actually exercise the result limit. Creating more
        than 1000 records is skipped here; the test only verifies that the
        API responds normally.
        """
        mock_brand.return_value = {}
        response = await client.post(
            "/api/v1/query",
            json={"type": "star_id", "values": ["star_000"]},
        )
        assert response.status_code == 200
        assert response.json()["success"] is True

    @patch("app.api.v1.query.get_brand_names", new_callable=AsyncMock)
    async def test_query_returns_calculated_fields(
        self, mock_brand, client, test_session, seed_data
    ):
        """Each returned record includes the calculated estimate fields."""
        mock_brand.return_value = {}
        response = await client.post(
            "/api/v1/query",
            json={"type": "star_id", "values": ["star_000"]},
        )
        assert response.status_code == 200
        data = response.json()
        # star_000 is seeded, so a record must come back; asserting this
        # (instead of guarding with `if`) keeps the test from passing vacuously.
        assert data["total"] == 1
        video = data["data"][0]
        # Check that the calculated fields are present.
        assert "estimated_natural_cpm" in video
        assert "estimated_natural_search_uv" in video
        assert "estimated_natural_search_cost" in video