"""
Tests for SessionID Pool Service (T-021, T-022, T-027).

T-027 changes:
- Pool entries are now ``CookieConfig`` records instead of bare session ids.
- ``get_random_config()`` picks a random config from the pool.
- ``remove_by_auth_token()`` evicts a config whose cookie has gone stale.

Every ``async def`` test carries an explicit ``@pytest.mark.asyncio`` marker so
the suite runs the same way under pytest-asyncio's strict mode and auto mode.
"""

import pytest
from unittest.mock import AsyncMock, patch, MagicMock

import httpx

from app.services.session_pool import (
    SessionPool,
    CookieConfig,
    session_pool,
    get_session_with_retry,
    get_random_config,
    get_distinct_configs,
)


# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------


def _make_response(status_code: int = 200, payload=None) -> MagicMock:
    """Build a fake HTTP response object with the given status code.

    ``payload`` becomes the return value of ``response.json()``. When it is
    ``None``, ``json`` is left as an unconfigured MagicMock attribute (the
    error-status test never configures a body).
    """
    response = MagicMock()
    response.status_code = status_code
    if payload is not None:
        response.json.return_value = payload
    return response


def _make_client(response=None, get_side_effect=None) -> AsyncMock:
    """Build an AsyncMock standing in for ``httpx.AsyncClient``.

    The mock acts as an async context manager that yields itself. Awaiting
    ``client.get(...)`` either returns ``response`` or, when
    ``get_side_effect`` is given, applies that side effect instead (e.g.
    raises the supplied exception).
    """
    client = AsyncMock()
    if get_side_effect is not None:
        client.get.side_effect = get_side_effect
    else:
        client.get.return_value = response
    client.__aenter__.return_value = client
    client.__aexit__.return_value = None
    return client


def _brand1_config() -> CookieConfig:
    """Return a fresh Brand1 config whose auth token is ``sessionid=session_1``."""
    return CookieConfig(
        brand_id="533661",
        aadvid="1648829117232140",
        auth_token="sessionid=session_1",
        industry_id=20,
        brand_name="Brand1",
    )


def _brand2_config() -> CookieConfig:
    """Return a fresh Brand2 config whose auth token is ``sessionid=session_2``."""
    return CookieConfig(
        brand_id="10186612",
        aadvid="9876543210",
        auth_token="sessionid=session_2",
        industry_id=30,
        brand_name="Brand2",
    )


def _new_session_config() -> CookieConfig:
    """Return the config that simulated refreshes load (``sessionid=new_session``)."""
    return CookieConfig(
        brand_id="123",
        aadvid="456",
        auth_token="sessionid=new_session",
        industry_id=20,
        brand_name="New",
    )


def _make_configs(count: int) -> list[CookieConfig]:
    """Create ``count`` distinct CookieConfig objects for tests.

    Item ``i`` has ``auth_token == f"sessionid=session_{i}"``.
    """
    return [
        CookieConfig(
            brand_id=f"brand_{i}",
            aadvid=f"aadvid_{i}",
            auth_token=f"sessionid=session_{i}",
            industry_id=20 + i,
            brand_name=f"Brand{i}",
        )
        for i in range(count)
    ]


def _refresh_that_fills(pool: SessionPool, build_configs):
    """Return an async ``refresh`` replacement that loads configs into ``pool``.

    Used as the ``side_effect`` of a patched ``pool.refresh``. Each call
    replaces ``pool._configs`` with ``build_configs()`` and reports success.
    """

    async def _refresh():
        pool._configs = build_configs()
        return True

    return _refresh


# ---------------------------------------------------------------------------
# SessionPool
# ---------------------------------------------------------------------------


class TestSessionPool:
    """Tests for SessionPool class."""

    @pytest.mark.asyncio
    async def test_refresh_success(self):
        """Test successful session pool refresh (T-027 format)."""
        pool = SessionPool()
        payload = {
            "data": [
                {
                    "brand_id": "533661",
                    "aadvid": "1648829117232140",
                    "auth_token": "sessionid=session_001",
                    "industry_id": 20,
                    "brand_name": "Brand1",
                },
                {
                    "brand_id": "10186612",
                    "aadvid": "9876543210",
                    "auth_token": "sessionid=session_002",
                    "industry_id": 30,
                    "brand_name": "Brand2",
                },
            ]
        }
        mock_client = _make_client(_make_response(payload=payload))

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is True
        assert pool.size == 2
        assert not pool.is_empty

    @pytest.mark.asyncio
    async def test_refresh_with_sessionid_cookie_field(self):
        """Test refresh using the ``sessionid_cookie`` field (fallback)."""
        pool = SessionPool()
        payload = {
            "data": [
                {
                    "brand_id": "533661",
                    "aadvid": "1648829117232140",
                    "sessionid_cookie": "sessionid=session_001",
                    "industry_id": 20,
                    "brand_name": "Brand1",
                },
            ]
        }
        mock_client = _make_client(_make_response(payload=payload))

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is True
        assert pool.size == 1

    @pytest.mark.asyncio
    async def test_refresh_empty_data(self):
        """Test refresh with empty data array."""
        pool = SessionPool()
        mock_client = _make_client(_make_response(payload={"data": []}))

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is False
        assert pool.size == 0

    @pytest.mark.asyncio
    async def test_refresh_api_error(self):
        """Test refresh with API error."""
        pool = SessionPool()
        mock_client = _make_client(_make_response(status_code=500))

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is False

    @pytest.mark.asyncio
    async def test_refresh_timeout(self):
        """Test refresh with timeout."""
        pool = SessionPool()
        mock_client = _make_client(
            get_side_effect=httpx.TimeoutException("Timeout")
        )

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is False

    @pytest.mark.asyncio
    async def test_refresh_request_error(self):
        """Test refresh with request error."""
        pool = SessionPool()
        mock_client = _make_client(
            get_side_effect=httpx.RequestError("Connection failed")
        )

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is False

    @pytest.mark.asyncio
    async def test_refresh_unexpected_error(self):
        """Test refresh with unexpected error."""
        pool = SessionPool()
        mock_client = _make_client(get_side_effect=ValueError("Unexpected"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is False

    @pytest.mark.asyncio
    async def test_refresh_with_auth_header(self):
        """Test that refresh includes Authorization header."""
        pool = SessionPool()
        payload = {
            "data": [
                {
                    "brand_id": "123",
                    "aadvid": "456",
                    "auth_token": "sessionid=test",
                    "industry_id": 20,
                    "brand_name": "Test",
                }
            ]
        }
        mock_client = _make_client(_make_response(payload=payload))

        with patch("httpx.AsyncClient", return_value=mock_client):
            with patch("app.services.session_pool.settings") as mock_settings:
                mock_settings.YUNTU_API_TOKEN = "test_token"
                mock_settings.YUNTU_API_TIMEOUT = 10.0
                mock_settings.BRAND_API_BASE_URL = "https://api.test.com"
                await pool.refresh()

        mock_client.get.assert_called_once()
        call_args = mock_client.get.call_args
        assert "headers" in call_args.kwargs
        assert call_args.kwargs["headers"]["Authorization"] == "Bearer test_token"

    def test_get_random_config_from_pool(self):
        """Test getting random config from pool (T-027)."""
        pool = SessionPool()
        pool._configs = [_brand1_config(), _brand2_config()]

        config = pool.get_random_config()

        assert config is not None
        assert "aadvid" in config
        assert "auth_token" in config
        assert config["auth_token"] in ["sessionid=session_1", "sessionid=session_2"]

    def test_get_random_config_from_empty_pool(self):
        """Test getting random config from empty pool."""
        pool = SessionPool()
        assert pool.get_random_config() is None

    def test_get_random_from_pool_compat(self):
        """Test get_random compatibility method."""
        pool = SessionPool()
        pool._configs = [_brand1_config()]

        # The compat API returns the bare session id, without "sessionid=".
        assert pool.get_random() == "session_1"

    def test_get_random_from_empty_pool_compat(self):
        """Test get_random from empty pool."""
        pool = SessionPool()
        assert pool.get_random() is None

    def test_remove_by_auth_token(self):
        """Test removing config by auth_token (T-027)."""
        pool = SessionPool()
        pool._configs = [_brand1_config(), _brand2_config()]

        pool.remove_by_auth_token("sessionid=session_1")

        assert pool.size == 1
        config = pool.get_random_config()
        assert config["auth_token"] == "sessionid=session_2"

    def test_remove_session_compat(self):
        """Test remove compatibility method (takes the bare session id)."""
        pool = SessionPool()
        pool._configs = [_brand1_config(), _brand2_config()]

        pool.remove("session_1")

        assert pool.size == 1
        # The surviving entry must be the one that was NOT removed.
        assert pool.get_random_config()["auth_token"] == "sessionid=session_2"

    def test_remove_nonexistent_session(self):
        """Test removing a session that doesn't exist."""
        pool = SessionPool()
        pool._configs = [_brand1_config()]

        # Should not raise
        pool.remove_by_auth_token("nonexistent")

        assert pool.size == 1

    def test_size_property(self):
        """Test size property."""
        pool = SessionPool()
        assert pool.size == 0

        pool._configs = _make_configs(2)
        assert pool.size == 2

    def test_is_empty_property(self):
        """Test is_empty property."""
        pool = SessionPool()
        assert pool.is_empty is True

        pool._configs = _make_configs(1)
        assert pool.is_empty is False


# ---------------------------------------------------------------------------
# Module-level get_random_config
# ---------------------------------------------------------------------------


class TestGetRandomConfig:
    """Tests for get_random_config function (T-027)."""

    @pytest.mark.asyncio
    async def test_get_config_success(self):
        """Test successful config retrieval."""
        pool = SessionPool()
        pool._configs = [_brand1_config()]

        with patch("app.services.session_pool.session_pool", pool):
            result = await get_random_config()

        assert result is not None
        assert result["aadvid"] == "1648829117232140"
        assert result["auth_token"] == "sessionid=session_1"

    @pytest.mark.asyncio
    async def test_get_config_refresh_on_empty(self):
        """Test that pool is refreshed when empty."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(pool, "refresh", new_callable=AsyncMock) as mock_refresh:
                mock_refresh.side_effect = _refresh_that_fills(
                    pool, lambda: [_new_session_config()]
                )
                result = await get_random_config()

        assert mock_refresh.called
        assert result["auth_token"] == "sessionid=new_session"

    @pytest.mark.asyncio
    async def test_get_config_retry_on_refresh_failure(self):
        """Test retry behavior when refresh fails."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(pool, "refresh", new_callable=AsyncMock) as mock_refresh:
                mock_refresh.return_value = False
                result = await get_random_config(max_retries=3)

        assert result is None
        assert mock_refresh.call_count == 3


# ---------------------------------------------------------------------------
# Module-level get_session_with_retry (legacy API)
# ---------------------------------------------------------------------------


class TestGetSessionWithRetry:
    """Tests for get_session_with_retry function (T-022 compat)."""

    @pytest.mark.asyncio
    async def test_get_session_success(self):
        """Test successful session retrieval."""
        pool = SessionPool()
        pool._configs = [_brand1_config()]

        with patch("app.services.session_pool.session_pool", pool):
            result = await get_session_with_retry()

        assert result == "session_1"

    @pytest.mark.asyncio
    async def test_get_session_refresh_on_empty(self):
        """Test that pool is refreshed when empty."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(pool, "refresh", new_callable=AsyncMock) as mock_refresh:
                mock_refresh.side_effect = _refresh_that_fills(
                    pool, lambda: [_new_session_config()]
                )
                result = await get_session_with_retry()

        assert mock_refresh.called
        assert result == "new_session"

    @pytest.mark.asyncio
    async def test_get_session_retry_on_refresh_failure(self):
        """Test retry behavior when refresh fails."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(pool, "refresh", new_callable=AsyncMock) as mock_refresh:
                mock_refresh.return_value = False
                result = await get_session_with_retry(max_retries=3)

        assert result is None
        assert mock_refresh.call_count == 3

    @pytest.mark.asyncio
    async def test_get_session_max_retries(self):
        """Test max retries limit."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(pool, "refresh", new_callable=AsyncMock) as mock_refresh:
                mock_refresh.return_value = False
                result = await get_session_with_retry(max_retries=5)

        assert result is None
        assert mock_refresh.call_count == 5


# ---------------------------------------------------------------------------
# Response parsing
# ---------------------------------------------------------------------------


class TestSessionPoolIntegration:
    """Integration tests for session pool."""

    @pytest.mark.asyncio
    async def test_refresh_filters_invalid_items(self):
        """Test that refresh filters out invalid items (T-027 format)."""
        pool = SessionPool()
        payload = {
            "data": [
                {
                    "brand_id": "533661",
                    "aadvid": "1648829117232140",
                    "auth_token": "sessionid=valid_session",
                    "industry_id": 20,
                    "brand_name": "Valid1",
                },
                {"no_auth_token": "missing"},
                None,
                {
                    "brand_id": "10186612",
                    "aadvid": "",  # Empty aadvid should be filtered
                    "auth_token": "sessionid=xxx",
                    "industry_id": 30,
                    "brand_name": "Invalid",
                },
                {
                    "brand_id": "789012",
                    "aadvid": "9876543210",
                    "auth_token": "sessionid=another_valid",
                    "industry_id": 40,
                    "brand_name": "Valid2",
                },
            ]
        }
        mock_client = _make_client(_make_response(payload=payload))

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is True
        assert pool.size == 2

    @pytest.mark.asyncio
    async def test_refresh_handles_non_dict_data(self):
        """Test refresh with non-dict response."""
        pool = SessionPool()
        mock_client = _make_client(_make_response(payload=["not", "a", "dict"]))

        with patch("httpx.AsyncClient", return_value=mock_client):
            result = await pool.refresh()

        assert result is False


# ---------------------------------------------------------------------------
# get_distinct_configs
# ---------------------------------------------------------------------------


class TestGetDistinctConfigs:
    """Tests for SessionPool.get_distinct_configs and module-level get_distinct_configs."""

    def test_enough_configs_returns_distinct(self):
        """Pool holds >= count configs -> returns `count` distinct ones."""
        pool = SessionPool()
        pool._configs = _make_configs(5)

        result = pool.get_distinct_configs(3)

        assert len(result) == 3
        tokens = [r["auth_token"] for r in result]
        assert len(set(tokens)) == 3

    def test_exact_count(self):
        """Pool holds exactly count configs -> all of them are returned."""
        pool = SessionPool()
        pool._configs = _make_configs(3)

        result = pool.get_distinct_configs(3)

        assert len(result) == 3
        tokens = {r["auth_token"] for r in result}
        assert len(tokens) == 3

    def test_fewer_configs_wraps_around(self):
        """Pool holds < count configs -> pool entries are reused to pad the result."""
        pool = SessionPool()
        pool._configs = _make_configs(2)

        result = pool.get_distinct_configs(5)

        assert len(result) == 5
        # The first two entries must be distinct.
        first_two_tokens = {result[0]["auth_token"], result[1]["auth_token"]}
        assert len(first_two_tokens) == 2
        # Padding must only reuse entries that actually exist in the pool.
        assert {r["auth_token"] for r in result} == {
            "sessionid=session_0",
            "sessionid=session_1",
        }

    def test_empty_pool_returns_empty(self):
        """Empty pool -> empty list."""
        pool = SessionPool()
        assert pool.get_distinct_configs(3) == []

    def test_zero_count_returns_empty(self):
        """count=0 -> empty list."""
        pool = SessionPool()
        pool._configs = _make_configs(3)
        assert pool.get_distinct_configs(0) == []

    def test_result_contains_all_fields(self):
        """Each returned dict carries every required field."""
        pool = SessionPool()
        pool._configs = _make_configs(1)

        result = pool.get_distinct_configs(1)

        assert len(result) == 1
        item = result[0]
        for field_name in ("brand_id", "aadvid", "auth_token", "industry_id", "brand_name"):
            assert field_name in item

    @pytest.mark.asyncio
    async def test_module_level_get_distinct_configs(self):
        """Module-level async get_distinct_configs returns `count` configs."""
        pool = SessionPool()
        pool._configs = _make_configs(3)

        with patch("app.services.session_pool.session_pool", pool):
            result = await get_distinct_configs(2)

        assert len(result) == 2

    @pytest.mark.asyncio
    async def test_module_level_refreshes_on_empty(self):
        """Empty pool triggers an automatic refresh."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(pool, "refresh", new_callable=AsyncMock) as mock_refresh:
                mock_refresh.side_effect = _refresh_that_fills(
                    pool, lambda: _make_configs(3)
                )
                result = await get_distinct_configs(2)

        assert mock_refresh.called
        assert len(result) == 2

    @pytest.mark.asyncio
    async def test_module_level_returns_empty_on_refresh_failure(self):
        """Refresh keeps failing -> empty list after max_retries attempts."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(pool, "refresh", new_callable=AsyncMock) as mock_refresh:
                mock_refresh.return_value = False
                result = await get_distinct_configs(2, max_retries=2)

        assert result == []
        assert mock_refresh.call_count == 2