# Change summary:
# - SessionPool 新增 get_distinct_configs 方法,支持获取不同配置用于并发调用
# - video_analysis 重构为缓存优先策略:数据库有 A3/Cost 数据时直接使用
# - 并发 API 调用预分配不同 cookie,避免 session 冲突
# - API 数据写回数据库,实现下次查询缓存命中
# - 新增 heated_cost 字段追踪
# - 测试全面重写,覆盖缓存/API/混合/降级场景
# (source view: 702 lines, 22 KiB, Python)
"""
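Tests for SessionID Pool Service (T-021, T-022, T-027).

T-027 changes:
- Configs are stored using the CookieConfig data structure.
- get_random_config() picks a config at random.
- remove_by_auth_token() removes an invalidated config.

Also covers get_distinct_configs(), which returns distinct configs so that
concurrent calls can each use a different cookie.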
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
import httpx
|
|
|
|
from app.services.session_pool import (
|
|
SessionPool,
|
|
CookieConfig,
|
|
session_pool,
|
|
get_session_with_retry,
|
|
get_random_config,
|
|
get_distinct_configs,
|
|
)
|
|
|
|
|
|
class TestSessionPool:
    """Tests for SessionPool class."""

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: names start with an underscore)
    # ------------------------------------------------------------------

    @staticmethod
    def _json_response(payload=None, status_code=200):
        """Build a fake HTTP response with ``status_code`` and an optional JSON body.

        When ``payload`` is None, ``json()`` is left as the default MagicMock
        (used for the error-status case, where no body is configured).
        """
        response = MagicMock()
        response.status_code = status_code
        if payload is not None:
            response.json.return_value = payload
        return response

    @staticmethod
    def _mock_http_client(response=None, error=None):
        """Build an AsyncMock that stands in for an ``httpx.AsyncClient`` instance.

        It supports ``async with`` (``__aenter__`` yields the client itself).
        Its ``get`` coroutine returns ``response``, or raises ``error`` when one
        is given.
        """
        client = AsyncMock()
        if error is not None:
            client.get.side_effect = error
        else:
            client.get.return_value = response
        client.__aenter__.return_value = client
        client.__aexit__.return_value = None
        return client

    @staticmethod
    async def _refresh_with(pool, client):
        """Run ``pool.refresh()`` with ``httpx.AsyncClient`` patched to return ``client``."""
        with patch("httpx.AsyncClient", return_value=client):
            return await pool.refresh()

    @staticmethod
    def _two_configs():
        """Return two distinct CookieConfig fixtures (session_1 first, session_2 second)."""
        return [
            CookieConfig(
                brand_id="533661",
                aadvid="1648829117232140",
                auth_token="sessionid=session_1",
                industry_id=20,
                brand_name="Brand1",
            ),
            CookieConfig(
                brand_id="10186612",
                aadvid="9876543210",
                auth_token="sessionid=session_2",
                industry_id=30,
                brand_name="Brand2",
            ),
        ]

    # ------------------------------------------------------------------
    # refresh()
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_refresh_success(self):
        """Test successful session pool refresh (T-027 format)."""
        pool = SessionPool()
        response = self._json_response(
            {
                "data": [
                    {
                        "brand_id": "533661",
                        "aadvid": "1648829117232140",
                        "auth_token": "sessionid=session_001",
                        "industry_id": 20,
                        "brand_name": "Brand1",
                    },
                    {
                        "brand_id": "10186612",
                        "aadvid": "9876543210",
                        "auth_token": "sessionid=session_002",
                        "industry_id": 30,
                        "brand_name": "Brand2",
                    },
                ]
            }
        )

        result = await self._refresh_with(pool, self._mock_http_client(response))

        assert result is True
        assert pool.size == 2
        assert not pool.is_empty

    @pytest.mark.asyncio
    async def test_refresh_with_sessionid_cookie_field(self):
        """Test refresh using sessionid_cookie field (fallback)."""
        pool = SessionPool()
        response = self._json_response(
            {
                "data": [
                    {
                        "brand_id": "533661",
                        "aadvid": "1648829117232140",
                        "sessionid_cookie": "sessionid=session_001",
                        "industry_id": 20,
                        "brand_name": "Brand1",
                    },
                ]
            }
        )

        result = await self._refresh_with(pool, self._mock_http_client(response))

        assert result is True
        assert pool.size == 1

    @pytest.mark.asyncio
    async def test_refresh_empty_data(self):
        """Test refresh with empty data array."""
        pool = SessionPool()
        response = self._json_response({"data": []})

        result = await self._refresh_with(pool, self._mock_http_client(response))

        assert result is False
        assert pool.size == 0

    @pytest.mark.asyncio
    async def test_refresh_api_error(self):
        """Test refresh with API error."""
        pool = SessionPool()
        response = self._json_response(status_code=500)

        result = await self._refresh_with(pool, self._mock_http_client(response))

        assert result is False

    @pytest.mark.asyncio
    async def test_refresh_timeout(self):
        """Test refresh with timeout."""
        pool = SessionPool()
        client = self._mock_http_client(error=httpx.TimeoutException("Timeout"))

        result = await self._refresh_with(pool, client)

        assert result is False

    @pytest.mark.asyncio
    async def test_refresh_request_error(self):
        """Test refresh with request error."""
        pool = SessionPool()
        client = self._mock_http_client(error=httpx.RequestError("Connection failed"))

        result = await self._refresh_with(pool, client)

        assert result is False

    @pytest.mark.asyncio
    async def test_refresh_unexpected_error(self):
        """Test refresh with unexpected error."""
        pool = SessionPool()
        client = self._mock_http_client(error=ValueError("Unexpected"))

        result = await self._refresh_with(pool, client)

        assert result is False

    @pytest.mark.asyncio
    async def test_refresh_with_auth_header(self):
        """Test that refresh includes Authorization header."""
        pool = SessionPool()
        response = self._json_response(
            {
                "data": [
                    {
                        "brand_id": "123",
                        "aadvid": "456",
                        "auth_token": "sessionid=test",
                        "industry_id": 20,
                        "brand_name": "Test",
                    }
                ]
            }
        )
        client = self._mock_http_client(response)

        with patch("httpx.AsyncClient", return_value=client):
            with patch("app.services.session_pool.settings") as mock_settings:
                mock_settings.YUNTU_API_TOKEN = "test_token"
                mock_settings.YUNTU_API_TIMEOUT = 10.0
                mock_settings.BRAND_API_BASE_URL = "https://api.test.com"

                await pool.refresh()

        client.get.assert_called_once()
        call_args = client.get.call_args
        assert "headers" in call_args.kwargs
        assert call_args.kwargs["headers"]["Authorization"] == "Bearer test_token"

    # ------------------------------------------------------------------
    # Selection / removal / properties
    # ------------------------------------------------------------------

    def test_get_random_config_from_pool(self):
        """Test getting random config from pool (T-027)."""
        pool = SessionPool()
        pool._configs = self._two_configs()

        config = pool.get_random_config()

        assert config is not None
        assert "aadvid" in config
        assert "auth_token" in config
        assert config["auth_token"] in {"sessionid=session_1", "sessionid=session_2"}

    def test_get_random_config_from_empty_pool(self):
        """Test getting random config from empty pool."""
        pool = SessionPool()

        config = pool.get_random_config()

        assert config is None

    def test_get_random_from_pool_compat(self):
        """Test get_random compatibility method."""
        pool = SessionPool()
        pool._configs = self._two_configs()[:1]

        session = pool.get_random()

        # The compat API returns the bare session id, without the "sessionid=" prefix.
        assert session == "session_1"

    def test_get_random_from_empty_pool_compat(self):
        """Test get_random from empty pool."""
        pool = SessionPool()

        session = pool.get_random()

        assert session is None

    def test_remove_by_auth_token(self):
        """Test removing config by auth_token (T-027)."""
        pool = SessionPool()
        pool._configs = self._two_configs()

        pool.remove_by_auth_token("sessionid=session_1")

        assert pool.size == 1
        config = pool.get_random_config()
        assert config["auth_token"] == "sessionid=session_2"

    def test_remove_session_compat(self):
        """Test remove compatibility method."""
        pool = SessionPool()
        pool._configs = self._two_configs()

        pool.remove("session_1")

        assert pool.size == 1
        # The *matching* config must be the one removed, not an arbitrary one.
        config = pool.get_random_config()
        assert config["auth_token"] == "sessionid=session_2"

    def test_remove_nonexistent_session(self):
        """Test removing a session that doesn't exist."""
        pool = SessionPool()
        pool._configs = self._two_configs()[:1]

        # Should not raise
        pool.remove_by_auth_token("nonexistent")

        assert pool.size == 1

    def test_size_property(self):
        """Test size property."""
        pool = SessionPool()
        assert pool.size == 0

        pool._configs = self._two_configs()
        assert pool.size == 2

    def test_is_empty_property(self):
        """Test is_empty property."""
        pool = SessionPool()
        assert pool.is_empty is True

        pool._configs = self._two_configs()[:1]
        assert pool.is_empty is False


class TestGetRandomConfig:
    """Tests for the module-level get_random_config() coroutine (T-027)."""

    @pytest.mark.asyncio
    async def test_get_config_success(self):
        """Test successful config retrieval."""
        pool = SessionPool()
        pool._configs = [
            CookieConfig(
                brand_id="533661",
                aadvid="1648829117232140",
                auth_token="sessionid=session_1",
                industry_id=20,
                brand_name="Brand1",
            ),
        ]

        with patch("app.services.session_pool.session_pool", pool):
            result = await get_random_config()

        assert result is not None
        assert result["aadvid"] == "1648829117232140"
        assert result["auth_token"] == "sessionid=session_1"

    @pytest.mark.asyncio
    async def test_get_config_refresh_on_empty(self):
        """Test that pool is refreshed when empty."""
        pool = SessionPool()

        async def refresh_side_effect():
            # Simulate a successful refresh that populates the (patched) global pool.
            pool._configs = [
                CookieConfig(
                    brand_id="123",
                    aadvid="456",
                    auth_token="sessionid=new_session",
                    industry_id=20,
                    brand_name="New",
                ),
            ]
            return True

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(
                pool, "refresh", new_callable=AsyncMock, side_effect=refresh_side_effect
            ) as mock_refresh:
                result = await get_random_config()

        # assert_awaited() (unlike .called) fails if refresh() was called but never awaited.
        mock_refresh.assert_awaited()
        assert result["auth_token"] == "sessionid=new_session"

    @pytest.mark.asyncio
    async def test_get_config_retry_on_refresh_failure(self):
        """Test retry behavior when refresh fails."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(
                pool, "refresh", new_callable=AsyncMock, return_value=False
            ) as mock_refresh:
                result = await get_random_config(max_retries=3)

        assert result is None
        # Exactly one awaited refresh attempt per allowed retry.
        assert mock_refresh.await_count == 3


class TestGetSessionWithRetry:
    """Tests for get_session_with_retry function (T-022 compat)."""

    @pytest.mark.asyncio
    async def test_get_session_success(self):
        """Test successful session retrieval."""
        pool = SessionPool()
        pool._configs = [
            CookieConfig(
                brand_id="533661",
                aadvid="1648829117232140",
                auth_token="sessionid=session_1",
                industry_id=20,
                brand_name="Brand1",
            ),
        ]

        with patch("app.services.session_pool.session_pool", pool):
            result = await get_session_with_retry()

        # The compat API yields the bare session id, without the "sessionid=" prefix.
        assert result == "session_1"

    @pytest.mark.asyncio
    async def test_get_session_refresh_on_empty(self):
        """Test that pool is refreshed when empty."""
        pool = SessionPool()

        async def refresh_side_effect():
            # Simulate a successful refresh that populates the (patched) global pool.
            pool._configs = [
                CookieConfig(
                    brand_id="123",
                    aadvid="456",
                    auth_token="sessionid=new_session",
                    industry_id=20,
                    brand_name="New",
                ),
            ]
            return True

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(
                pool, "refresh", new_callable=AsyncMock, side_effect=refresh_side_effect
            ) as mock_refresh:
                result = await get_session_with_retry()

        # assert_awaited() (unlike .called) fails if refresh() was called but never awaited.
        mock_refresh.assert_awaited()
        assert result == "new_session"

    @pytest.mark.asyncio
    async def test_get_session_retry_on_refresh_failure(self):
        """Test retry behavior when refresh fails."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(
                pool, "refresh", new_callable=AsyncMock, return_value=False
            ) as mock_refresh:
                result = await get_session_with_retry(max_retries=3)

        assert result is None
        assert mock_refresh.await_count == 3

    @pytest.mark.asyncio
    async def test_get_session_max_retries(self):
        """Test max retries limit."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(
                pool, "refresh", new_callable=AsyncMock, return_value=False
            ) as mock_refresh:
                result = await get_session_with_retry(max_retries=5)

        assert result is None
        # The number of awaited refresh attempts follows max_retries exactly.
        assert mock_refresh.await_count == 5


class TestSessionPoolIntegration:
    """Integration tests for session pool."""

    @staticmethod
    def _client_returning(payload):
        """Build an ``httpx.AsyncClient`` stand-in whose ``get`` returns a 200 JSON response.

        The returned AsyncMock supports ``async with`` (``__aenter__`` yields
        itself), and ``response.json()`` returns ``payload``.
        """
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload

        client = AsyncMock()
        client.get.return_value = response
        client.__aenter__.return_value = client
        client.__aexit__.return_value = None
        return client

    @pytest.mark.asyncio
    async def test_refresh_filters_invalid_items(self):
        """Test that refresh filters out invalid items (T-027 format)."""
        pool = SessionPool()
        client = self._client_returning(
            {
                "data": [
                    {
                        "brand_id": "533661",
                        "aadvid": "1648829117232140",
                        "auth_token": "sessionid=valid_session",
                        "industry_id": 20,
                        "brand_name": "Valid1",
                    },
                    {"no_auth_token": "missing"},
                    None,
                    {
                        "brand_id": "10186612",
                        "aadvid": "",  # Empty aadvid should be filtered
                        "auth_token": "sessionid=xxx",
                        "industry_id": 30,
                        "brand_name": "Invalid",
                    },
                    {
                        "brand_id": "789012",
                        "aadvid": "9876543210",
                        "auth_token": "sessionid=another_valid",
                        "industry_id": 40,
                        "brand_name": "Valid2",
                    },
                ]
            }
        )

        with patch("httpx.AsyncClient", return_value=client):
            result = await pool.refresh()

        assert result is True
        assert pool.size == 2
        # Check *which* items survived, not just how many: requesting as many
        # distinct configs as the pool holds returns every config in it.
        kept_tokens = {c["auth_token"] for c in pool.get_distinct_configs(pool.size)}
        assert kept_tokens == {"sessionid=valid_session", "sessionid=another_valid"}

    @pytest.mark.asyncio
    async def test_refresh_handles_non_dict_data(self):
        """Test refresh with non-dict response."""
        pool = SessionPool()
        client = self._client_returning(["not", "a", "dict"])

        with patch("httpx.AsyncClient", return_value=client):
            result = await pool.refresh()

        assert result is False
        # A rejected payload must not leave partial data behind.
        assert pool.is_empty


def _make_configs(count: int) -> list[CookieConfig]:
    """Create ``count`` distinct CookieConfig fixtures for tests.

    The config at 0-based index ``i`` has brand_id ``brand_{i}``, aadvid
    ``aadvid_{i}``, auth_token ``sessionid=session_{i}``, industry_id
    ``20 + i`` and brand_name ``Brand{i}``, so every auth_token is unique.
    A non-positive ``count`` yields an empty list.
    """

    def build(index: int) -> CookieConfig:
        # One fixture per index; every field is derived from the index.
        return CookieConfig(
            brand_id=f"brand_{index}",
            aadvid=f"aadvid_{index}",
            auth_token=f"sessionid=session_{index}",
            industry_id=20 + index,
            brand_name=f"Brand{index}",
        )

    return list(map(build, range(count)))


class TestGetDistinctConfigs:
    """Tests for SessionPool.get_distinct_configs and module-level get_distinct_configs."""

    @staticmethod
    def _pool_tokens(count):
        """Return the auth_tokens that ``_make_configs(count)`` puts into a pool."""
        return {f"sessionid=session_{i}" for i in range(count)}

    def test_enough_configs_returns_distinct(self):
        """Pool holds >= count configs -> returns ``count`` non-repeating configs."""
        pool = SessionPool()
        pool._configs = _make_configs(5)

        result = pool.get_distinct_configs(3)

        assert len(result) == 3
        tokens = [r["auth_token"] for r in result]
        assert len(set(tokens)) == 3
        # Every returned config must come from the pool.
        assert set(tokens) <= self._pool_tokens(5)

    def test_exact_count(self):
        """Pool holds exactly ``count`` configs -> every config is returned."""
        pool = SessionPool()
        pool._configs = _make_configs(3)

        result = pool.get_distinct_configs(3)

        assert len(result) == 3
        tokens = {r["auth_token"] for r in result}
        assert tokens == self._pool_tokens(3)

    def test_fewer_configs_wraps_around(self):
        """Pool holds < count configs -> configs are reused to fill the request."""
        pool = SessionPool()
        pool._configs = _make_configs(2)

        result = pool.get_distinct_configs(5)

        assert len(result) == 5
        # The first 2 entries must be distinct.
        first_two_tokens = {result[0]["auth_token"], result[1]["auth_token"]}
        assert len(first_two_tokens) == 2
        # Padding reuses existing pool configs; it never invents new ones.
        assert {r["auth_token"] for r in result} == self._pool_tokens(2)

    def test_empty_pool_returns_empty(self):
        """Empty pool -> empty list."""
        pool = SessionPool()

        result = pool.get_distinct_configs(3)

        assert result == []

    def test_zero_count_returns_empty(self):
        """count=0 -> empty list."""
        pool = SessionPool()
        pool._configs = _make_configs(3)

        result = pool.get_distinct_configs(0)

        assert result == []

    def test_result_contains_all_fields(self):
        """Each returned dict carries every CookieConfig field."""
        pool = SessionPool()
        pool._configs = _make_configs(1)

        result = pool.get_distinct_configs(1)

        assert len(result) == 1
        item = result[0]
        for field in ("brand_id", "aadvid", "auth_token", "industry_id", "brand_name"):
            assert field in item, field

    @pytest.mark.asyncio
    async def test_module_level_get_distinct_configs(self):
        """The module-level get_distinct_configs coroutine delegates to the global pool."""
        pool = SessionPool()
        pool._configs = _make_configs(3)

        with patch("app.services.session_pool.session_pool", pool):
            result = await get_distinct_configs(2)

        assert len(result) == 2
        assert len({r["auth_token"] for r in result}) == 2

    @pytest.mark.asyncio
    async def test_module_level_refreshes_on_empty(self):
        """An empty pool is refreshed automatically."""
        pool = SessionPool()

        async def refresh_side_effect():
            # Simulate a successful refresh that populates the (patched) global pool.
            pool._configs = _make_configs(3)
            return True

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(
                pool, "refresh", new_callable=AsyncMock, side_effect=refresh_side_effect
            ) as mock_refresh:
                result = await get_distinct_configs(2)

        # assert_awaited() (unlike .called) fails if refresh() was called but never awaited.
        mock_refresh.assert_awaited()
        assert len(result) == 2

    @pytest.mark.asyncio
    async def test_module_level_returns_empty_on_refresh_failure(self):
        """Refresh keeps failing -> empty list after ``max_retries`` attempts."""
        pool = SessionPool()

        with patch("app.services.session_pool.session_pool", pool):
            with patch.object(
                pool, "refresh", new_callable=AsyncMock, return_value=False
            ) as mock_refresh:
                result = await get_distinct_configs(2, max_retries=2)

        assert result == []
        assert mock_refresh.await_count == 2