feat(backend): 视频分析模块增加缓存优先策略和并发API调用

- SessionPool 新增 get_distinct_configs 方法,支持获取不同配置用于并发调用
- video_analysis 重构为缓存优先策略:数据库有 A3/Cost 数据时直接使用
- 并发 API 调用预分配不同 cookie,避免 session 冲突
- API 数据写回数据库,实现下次查询缓存命中
- 新增 heated_cost 字段追踪
- 测试全面重写,覆盖缓存/API/混合/降级场景
This commit is contained in:
zfc 2026-01-29 18:21:50 +08:00
parent c53b5008df
commit 376f0be6b4
4 changed files with 983 additions and 326 deletions

View File

@ -11,8 +11,8 @@ T-027 修复:
import asyncio
import logging
import random
from typing import Dict, Optional, Any, List
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import httpx
@ -150,6 +150,46 @@ class SessionPool:
"""检查池是否为空"""
return len(self._configs) == 0
def get_distinct_configs(self, count: int) -> List[Dict[str, Any]]:
"""
获取 count 个不同的配置用于并发调用
- 池中配置 >= count随机抽样 count 个不重复的
- 池中配置 < count全部取出循环复用补足
- 池为空返回空列表
Args:
count: 需要的配置数量
Returns:
List[Dict]: 配置字典列表
"""
if not self._configs or count <= 0:
return []
def _to_dict(config: CookieConfig) -> Dict[str, Any]:
return {
"brand_id": config.brand_id,
"aadvid": config.aadvid,
"auth_token": config.auth_token,
"industry_id": config.industry_id,
"brand_name": config.brand_name,
}
if len(self._configs) >= count:
sampled = random.sample(self._configs, count)
return [_to_dict(c) for c in sampled]
# 池中配置不足,全部取出后循环复用
result = [_to_dict(c) for c in self._configs]
shuffled = list(self._configs)
random.shuffle(shuffled)
idx = 0
while len(result) < count:
result.append(_to_dict(shuffled[idx % len(shuffled)]))
idx += 1
return result
# 兼容旧接口
def get_random(self) -> Optional[str]:
"""兼容旧接口:随机获取一个 SessionID"""
@ -218,6 +258,32 @@ async def get_session_with_retry(max_retries: int = 3) -> Optional[str]:
return None
async def get_distinct_configs(count: int, max_retries: int = 3) -> List[Dict[str, Any]]:
"""
获取 count 个不同的配置必要时刷新池
Args:
count: 需要的配置数量
max_retries: 最大重试次数
Returns:
List[Dict]: 配置字典列表
"""
for attempt in range(max_retries):
if session_pool.is_empty:
success = await session_pool.refresh()
if not success:
logger.warning(f"Session pool refresh failed, attempt {attempt + 1}")
continue
configs = session_pool.get_distinct_configs(count)
if configs:
return configs
logger.error("Failed to get distinct configs after all retries")
return []
async def get_config_for_brand(brand_id: str, max_retries: int = 3) -> Optional[Any]:
"""
兼容旧接口获取品牌对应的配置

View File

@ -4,25 +4,43 @@
实现视频分析数据获取和成本指标计算
"""
import asyncio
import logging
from datetime import datetime
from typing import Dict, Optional, Any
from typing import Any, Dict, Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import update
from app.models.kol_video import KolVideo
from app.services.session_pool import (
get_distinct_configs,
get_random_config,
session_pool,
)
from app.services.yuntu_api import (
SessionInvalidError,
call_yuntu_api,
parse_analysis_response,
)
from app.services.yuntu_api import (
get_video_analysis as fetch_yuntu_analysis,
parse_analysis_response,
YuntuAPIError,
)
logger = logging.getLogger(__name__)
def _needs_api_call(video: KolVideo) -> bool:
"""
判断是否需要调用 Yuntu API 获取 A3/Cost 数据
如果数据库中已有 A3 Cost 数据直接使用数据库数据不调 API
"""
has_a3 = (video.total_new_a3_cnt or 0) > 0
has_cost = (video.total_cost or 0) > 0
return not (has_a3 or has_cost)
def calculate_cost_metrics(
cost: float,
natural_play_cnt: int,
@ -151,33 +169,58 @@ async def get_video_analysis_data(
brand_map = await get_brand_names([video.brand_id])
brand_name = brand_map.get(video.brand_id, video.brand_id)
# 3. 调用巨量云图API获取实时 A3 数据和 cost
# 3. 获取 A3 数据和 cost(缓存优先策略)
a3_increase_cnt = 0
ad_a3_increase_cnt = 0
natural_a3_increase_cnt = 0
api_cost = 0.0
ad_cost = 0.0
try:
publish_time = video.publish_time or datetime.now()
industry_id = video.industry_id or ""
api_response = await fetch_yuntu_analysis(
item_id=item_id,
publish_time=publish_time,
industry_id=industry_id,
)
analysis_data = parse_analysis_response(api_response)
a3_increase_cnt = analysis_data.get("a3_increase_cnt", 0)
ad_a3_increase_cnt = analysis_data.get("ad_a3_increase_cnt", 0)
natural_a3_increase_cnt = analysis_data.get("natural_a3_increase_cnt", 0)
api_cost = analysis_data.get("cost", 0)
except Exception as e:
logger.warning(f"API failed for {item_id}: {e}, using DB data")
if not _needs_api_call(video):
# 数据库已有数据,直接使用
logger.info(f"Using DB data for {item_id} (A3/Cost already cached)")
a3_increase_cnt = video.total_new_a3_cnt or 0
ad_a3_increase_cnt = video.heated_new_a3_cnt or 0
natural_a3_increase_cnt = video.natural_new_a3_cnt or 0
api_cost = video.total_cost or 0.0
ad_cost = video.heated_cost or 0.0
else:
# 需要调用 API 获取数据
try:
publish_time = video.publish_time or datetime.now()
industry_id = video.industry_id or ""
api_response = await fetch_yuntu_analysis(
item_id=item_id,
publish_time=publish_time,
industry_id=industry_id,
)
analysis_data = parse_analysis_response(api_response)
a3_increase_cnt = analysis_data.get("a3_increase_cnt", 0)
ad_a3_increase_cnt = analysis_data.get("ad_a3_increase_cnt", 0)
natural_a3_increase_cnt = analysis_data.get("natural_a3_increase_cnt", 0)
api_cost = analysis_data.get("cost", 0)
ad_cost = analysis_data.get("ad_cost", 0)
# 写回数据库
await update_video_a3_metrics(
session=session,
item_id=item_id,
total_new_a3_cnt=int(a3_increase_cnt),
heated_new_a3_cnt=int(ad_a3_increase_cnt),
natural_new_a3_cnt=int(natural_a3_increase_cnt),
total_cost=float(api_cost),
heated_cost=float(ad_cost),
)
logger.info(f"API data fetched and saved to DB for {item_id}")
except Exception as e:
logger.warning(f"API failed for {item_id}: {e}, using DB data")
a3_increase_cnt = video.total_new_a3_cnt or 0
ad_a3_increase_cnt = video.heated_new_a3_cnt or 0
natural_a3_increase_cnt = video.natural_new_a3_cnt or 0
api_cost = video.total_cost or 0.0
ad_cost = video.heated_cost or 0.0
# 4. 数据库字段
estimated_video_cost = video.estimated_video_cost or 0.0
@ -187,8 +230,7 @@ async def get_video_analysis_data(
after_view_search_uv = video.after_view_search_uv or 0
# 5. 计算成本指标
# 预估加热费用 = max(total_cost - estimated_video_cost, 0)
heated_cost = max(api_cost - estimated_video_cost, 0) if api_cost > estimated_video_cost else 0
heated_cost = ad_cost
# 预估自然看后搜人数
estimated_natural_search_uv = None
@ -271,9 +313,10 @@ async def update_video_a3_metrics(
heated_new_a3_cnt: int,
natural_new_a3_cnt: int,
total_cost: float,
heated_cost: float = 0.0,
) -> bool:
"""
更新数据库中的A3指标 (T-025)
更新数据库中的A3指标和费用数据 (T-025)
Args:
session: 数据库会话
@ -281,7 +324,8 @@ async def update_video_a3_metrics(
total_new_a3_cnt: 总新增A3
heated_new_a3_cnt: 加热新增A3
natural_new_a3_cnt: 自然新增A3
total_cost: 总花费
total_cost: 预估总费用
heated_cost: 预估加热费用
Returns:
bool: 更新是否成功
@ -295,6 +339,7 @@ async def update_video_a3_metrics(
heated_new_a3_cnt=heated_new_a3_cnt,
natural_new_a3_cnt=natural_new_a3_cnt,
total_cost=total_cost,
heated_cost=heated_cost,
)
)
result = await session.execute(stmt)
@ -313,39 +358,6 @@ async def update_video_a3_metrics(
return False
async def get_and_update_video_analysis(
session: AsyncSession, item_id: str
) -> Dict[str, Any]:
"""
获取视频分析数据并更新数据库中的A3指标 (T-024 + T-025 组合)
Args:
session: 数据库会话
item_id: 视频ID
Returns:
Dict: 完整的视频分析数据
"""
# 获取分析数据
result = await get_video_analysis_data(session, item_id)
# 提取A3指标
a3_metrics = result.get("a3_metrics", {})
cost_raw = result.get("cost_metrics_raw", {})
# 更新数据库
await update_video_a3_metrics(
session=session,
item_id=item_id,
total_new_a3_cnt=a3_metrics.get("a3_increase_cnt", 0),
heated_new_a3_cnt=a3_metrics.get("ad_a3_increase_cnt", 0),
natural_new_a3_cnt=a3_metrics.get("natural_a3_increase_cnt", 0),
total_cost=cost_raw.get("cost", 0),
)
return result
async def search_videos_by_star_id(
session: AsyncSession, star_id: str
) -> list[KolVideo]:
@ -373,11 +385,60 @@ async def search_videos_by_nickname(
return list(result.scalars().all())
def _build_video_list_item(
video: KolVideo,
a3_increase_cnt: int,
ad_a3_increase_cnt: int,
natural_a3_increase_cnt: int,
api_cost: float,
brand_name: str,
) -> Dict[str, Any]:
"""构建视频列表项的结果字典。"""
estimated_video_cost = video.estimated_video_cost or 0.0
natural_play_cnt = video.natural_play_cnt or 0
total_play_cnt = video.total_play_cnt or 0
after_view_search_uv = video.after_view_search_uv or 0
estimated_natural_search_uv = None
if total_play_cnt > 0 and after_view_search_uv > 0:
estimated_natural_search_uv = (natural_play_cnt / total_play_cnt) * after_view_search_uv
estimated_natural_cpm = round((estimated_video_cost / natural_play_cnt) * 1000, 2) if natural_play_cnt > 0 else None
estimated_cp_a3 = round(api_cost / a3_increase_cnt, 2) if a3_increase_cnt > 0 else None
estimated_natural_cp_a3 = round(estimated_video_cost / natural_a3_increase_cnt, 2) if natural_a3_increase_cnt > 0 else None
estimated_cp_search = round(api_cost / after_view_search_uv, 2) if after_view_search_uv > 0 else None
estimated_natural_cp_search = round(estimated_video_cost / estimated_natural_search_uv, 2) if estimated_natural_search_uv and estimated_natural_search_uv > 0 else None
return {
"item_id": video.item_id,
"star_nickname": video.star_nickname or "",
"title": video.title or "",
"video_url": video.video_url or "",
"create_date": video.publish_time.isoformat() if video.publish_time else None,
"hot_type": video.viral_type or "",
"industry_id": video.industry_id or "",
"brand_id": video.brand_id or "",
"brand_name": brand_name,
"total_new_a3_cnt": a3_increase_cnt,
"heated_new_a3_cnt": ad_a3_increase_cnt,
"natural_new_a3_cnt": natural_a3_increase_cnt,
"estimated_natural_cpm": estimated_natural_cpm,
"estimated_cp_a3": estimated_cp_a3,
"estimated_natural_cp_a3": estimated_natural_cp_a3,
"estimated_cp_search": estimated_cp_search,
"estimated_natural_cp_search": estimated_natural_cp_search,
}
async def get_video_list_with_a3(
session: AsyncSession, videos: list[KolVideo]
) -> list[Dict[str, Any]]:
"""
获取视频列表的摘要数据实时调用云图API获取A3数据
获取视频列表的摘要数据
缓存优先策略
- 数据库有 A3/Cost 数据 直接使用
- 数据库无数据 并发调用云图 API预分配不同 cookie 写回数据库
"""
from app.services.brand_api import get_brand_names
@ -385,73 +446,149 @@ async def get_video_list_with_a3(
brand_ids = [video.brand_id for video in videos if video.brand_id]
brand_map = await get_brand_names(brand_ids) if brand_ids else {}
result = []
for video in videos:
# 实时调用云图 API 获取 A3 数据和 cost
a3_increase_cnt = 0
ad_a3_increase_cnt = 0
natural_a3_increase_cnt = 0
api_cost = 0.0
# 分组:已有数据 vs 需要 API 调用
cached_videos: list[tuple[int, KolVideo]] = [] # (原始索引, video)
api_videos: list[tuple[int, KolVideo]] = [] # (原始索引, video)
try:
publish_time = video.publish_time or datetime.now()
industry_id = video.industry_id or ""
for idx, video in enumerate(videos):
if _needs_api_call(video):
api_videos.append((idx, video))
else:
cached_videos.append((idx, video))
api_response = await fetch_yuntu_analysis(
item_id=video.item_id,
publish_time=publish_time,
industry_id=industry_id,
logger.info(
f"Video list: {len(cached_videos)} cached, {len(api_videos)} need API"
)
# 结果数组(按原始索引填充)
results: list[Optional[Dict[str, Any]]] = [None] * len(videos)
# 组 A直接用数据库数据
for idx, video in cached_videos:
brand_name = brand_map.get(video.brand_id, video.brand_id or "") if video.brand_id else ""
results[idx] = _build_video_list_item(
video=video,
a3_increase_cnt=video.total_new_a3_cnt or 0,
ad_a3_increase_cnt=video.heated_new_a3_cnt or 0,
natural_a3_increase_cnt=video.natural_new_a3_cnt or 0,
api_cost=video.total_cost or 0.0,
brand_name=brand_name,
)
# 组 B并发调用 API预分配不同 cookie
if api_videos:
configs = await get_distinct_configs(len(api_videos))
semaphore = asyncio.Semaphore(5)
# 收集需要写回 DB 的数据(避免并发 session 操作)
pending_updates: list[Dict[str, Any]] = []
async def _fetch_single(
idx: int, video: KolVideo, config: Dict[str, Any]
) -> None:
a3_increase_cnt = 0
ad_a3_increase_cnt = 0
natural_a3_increase_cnt = 0
api_cost = 0.0
ad_cost_val = 0.0
api_success = False
async with semaphore:
try:
publish_time = video.publish_time or datetime.now()
industry_id = video.industry_id or ""
api_response = await call_yuntu_api(
item_id=video.item_id,
publish_time=publish_time,
industry_id=industry_id,
aadvid=config["aadvid"],
auth_token=config["auth_token"],
)
api_data = parse_analysis_response(api_response)
a3_increase_cnt = api_data.get("a3_increase_cnt", 0)
ad_a3_increase_cnt = api_data.get("ad_a3_increase_cnt", 0)
natural_a3_increase_cnt = api_data.get("natural_a3_increase_cnt", 0)
api_cost = api_data.get("cost", 0)
ad_cost_val = api_data.get("ad_cost", 0)
api_success = True
except SessionInvalidError:
# Session 失效,从池中移除,重新获取随机 config 重试
session_pool.remove_by_auth_token(config["auth_token"])
logger.warning(f"Session invalid for {video.item_id}, retrying")
retry_config = await get_random_config()
if retry_config:
try:
publish_time = video.publish_time or datetime.now()
industry_id = video.industry_id or ""
api_response = await call_yuntu_api(
item_id=video.item_id,
publish_time=publish_time,
industry_id=industry_id,
aadvid=retry_config["aadvid"],
auth_token=retry_config["auth_token"],
)
api_data = parse_analysis_response(api_response)
a3_increase_cnt = api_data.get("a3_increase_cnt", 0)
ad_a3_increase_cnt = api_data.get("ad_a3_increase_cnt", 0)
natural_a3_increase_cnt = api_data.get("natural_a3_increase_cnt", 0)
api_cost = api_data.get("cost", 0)
ad_cost_val = api_data.get("ad_cost", 0)
api_success = True
except Exception as e2:
logger.warning(f"Retry failed for {video.item_id}: {e2}")
a3_increase_cnt = video.total_new_a3_cnt or 0
ad_a3_increase_cnt = video.heated_new_a3_cnt or 0
natural_a3_increase_cnt = video.natural_new_a3_cnt or 0
api_cost = video.total_cost or 0.0
except Exception as e:
logger.warning(f"API failed for {video.item_id}: {e}")
a3_increase_cnt = video.total_new_a3_cnt or 0
ad_a3_increase_cnt = video.heated_new_a3_cnt or 0
natural_a3_increase_cnt = video.natural_new_a3_cnt or 0
api_cost = video.total_cost or 0.0
# 收集待写回 DB 的数据(不在并发中操作 session
if api_success:
pending_updates.append({
"item_id": video.item_id,
"total_new_a3_cnt": int(a3_increase_cnt),
"heated_new_a3_cnt": int(ad_a3_increase_cnt),
"natural_new_a3_cnt": int(natural_a3_increase_cnt),
"total_cost": float(api_cost),
"heated_cost": float(ad_cost_val),
})
brand_name = brand_map.get(video.brand_id, video.brand_id or "") if video.brand_id else ""
results[idx] = _build_video_list_item(
video=video,
a3_increase_cnt=a3_increase_cnt,
ad_a3_increase_cnt=ad_a3_increase_cnt,
natural_a3_increase_cnt=natural_a3_increase_cnt,
api_cost=api_cost,
brand_name=brand_name,
)
api_data = parse_analysis_response(api_response)
a3_increase_cnt = api_data.get("a3_increase_cnt", 0)
ad_a3_increase_cnt = api_data.get("ad_a3_increase_cnt", 0)
natural_a3_increase_cnt = api_data.get("natural_a3_increase_cnt", 0)
api_cost = api_data.get("cost", 0)
except Exception as e:
logger.warning(f"API failed for {video.item_id}: {e}")
a3_increase_cnt = video.total_new_a3_cnt or 0
ad_a3_increase_cnt = video.heated_new_a3_cnt or 0
natural_a3_increase_cnt = video.natural_new_a3_cnt or 0
api_cost = video.total_cost or 0.0
# 为每个视频分配一个独立的 config并发执行
tasks = []
for i, (idx, video) in enumerate(api_videos):
config = configs[i] if i < len(configs) else configs[i % len(configs)] if configs else {}
tasks.append(_fetch_single(idx, video, config))
# 数据库字段
estimated_video_cost = video.estimated_video_cost or 0.0
natural_play_cnt = video.natural_play_cnt or 0
total_play_cnt = video.total_play_cnt or 0
after_view_search_uv = video.after_view_search_uv or 0
await asyncio.gather(*tasks)
# 计算成本指标
estimated_natural_search_uv = None
if total_play_cnt > 0 and after_view_search_uv > 0:
estimated_natural_search_uv = (natural_play_cnt / total_play_cnt) * after_view_search_uv
# 顺序写回 DB避免并发 session 操作导致状态损坏)
for upd in pending_updates:
await update_video_a3_metrics(
session=session,
item_id=upd["item_id"],
total_new_a3_cnt=upd["total_new_a3_cnt"],
heated_new_a3_cnt=upd["heated_new_a3_cnt"],
natural_new_a3_cnt=upd["natural_new_a3_cnt"],
total_cost=upd["total_cost"],
heated_cost=upd["heated_cost"],
)
estimated_natural_cpm = round((estimated_video_cost / natural_play_cnt) * 1000, 2) if natural_play_cnt > 0 else None
estimated_cp_a3 = round(api_cost / a3_increase_cnt, 2) if a3_increase_cnt > 0 else None
estimated_natural_cp_a3 = round(estimated_video_cost / natural_a3_increase_cnt, 2) if natural_a3_increase_cnt > 0 else None
estimated_cp_search = round(api_cost / after_view_search_uv, 2) if after_view_search_uv > 0 else None
estimated_natural_cp_search = round(estimated_video_cost / estimated_natural_search_uv, 2) if estimated_natural_search_uv and estimated_natural_search_uv > 0 else None
brand_name = brand_map.get(video.brand_id, video.brand_id) if video.brand_id else ""
result.append({
"item_id": video.item_id,
"star_nickname": video.star_nickname or "",
"title": video.title or "",
"video_url": video.video_url or "",
"create_date": video.publish_time.isoformat() if video.publish_time else None,
"hot_type": video.viral_type or "",
"industry_id": video.industry_id or "",
"brand_id": video.brand_id or "",
"brand_name": brand_name,
"total_new_a3_cnt": a3_increase_cnt,
"heated_new_a3_cnt": ad_a3_increase_cnt,
"natural_new_a3_cnt": natural_a3_increase_cnt,
"estimated_natural_cpm": estimated_natural_cpm,
"estimated_cp_a3": estimated_cp_a3,
"estimated_natural_cp_a3": estimated_natural_cp_a3,
"estimated_cp_search": estimated_cp_search,
"estimated_natural_cp_search": estimated_natural_cp_search,
})
return result
# 过滤 None不应发生防御性编程
return [r for r in results if r is not None]

View File

@ -17,6 +17,7 @@ from app.services.session_pool import (
session_pool,
get_session_with_retry,
get_random_config,
get_distinct_configs,
)
@ -571,3 +572,130 @@ class TestSessionPoolIntegration:
result = await pool.refresh()
assert result is False
def _make_configs(count: int) -> list[CookieConfig]:
"""创建 count 个不同的 CookieConfig 用于测试。"""
return [
CookieConfig(
brand_id=f"brand_{i}",
aadvid=f"aadvid_{i}",
auth_token=f"sessionid=session_{i}",
industry_id=20 + i,
brand_name=f"Brand{i}",
)
for i in range(count)
]
class TestGetDistinctConfigs:
"""Tests for SessionPool.get_distinct_configs and module-level get_distinct_configs."""
def test_enough_configs_returns_distinct(self):
"""池中配置 >= count → 返回不重复的"""
pool = SessionPool()
pool._configs = _make_configs(5)
result = pool.get_distinct_configs(3)
assert len(result) == 3
tokens = [r["auth_token"] for r in result]
assert len(set(tokens)) == 3
def test_exact_count(self):
"""池中配置 == count → 全部返回"""
pool = SessionPool()
pool._configs = _make_configs(3)
result = pool.get_distinct_configs(3)
assert len(result) == 3
tokens = {r["auth_token"] for r in result}
assert len(tokens) == 3
def test_fewer_configs_wraps_around(self):
"""池中配置 < count → 循环复用补足"""
pool = SessionPool()
pool._configs = _make_configs(2)
result = pool.get_distinct_configs(5)
assert len(result) == 5
# 前 2 个一定不重复
first_two_tokens = {result[0]["auth_token"], result[1]["auth_token"]}
assert len(first_two_tokens) == 2
def test_empty_pool_returns_empty(self):
"""空池 → 返回空列表"""
pool = SessionPool()
result = pool.get_distinct_configs(3)
assert result == []
def test_zero_count_returns_empty(self):
"""count=0 → 返回空列表"""
pool = SessionPool()
pool._configs = _make_configs(3)
result = pool.get_distinct_configs(0)
assert result == []
def test_result_contains_all_fields(self):
"""验证返回的 dict 包含所有必要字段"""
pool = SessionPool()
pool._configs = _make_configs(1)
result = pool.get_distinct_configs(1)
assert len(result) == 1
item = result[0]
assert "brand_id" in item
assert "aadvid" in item
assert "auth_token" in item
assert "industry_id" in item
assert "brand_name" in item
@pytest.mark.asyncio
async def test_module_level_get_distinct_configs(self):
"""测试模块级 get_distinct_configs 异步函数"""
pool = SessionPool()
pool._configs = _make_configs(3)
with patch("app.services.session_pool.session_pool", pool):
result = await get_distinct_configs(2)
assert len(result) == 2
@pytest.mark.asyncio
async def test_module_level_refreshes_on_empty(self):
"""池为空时自动刷新"""
pool = SessionPool()
with patch("app.services.session_pool.session_pool", pool):
with patch.object(pool, "refresh") as mock_refresh:
async def refresh_side_effect():
pool._configs = _make_configs(3)
return True
mock_refresh.side_effect = refresh_side_effect
result = await get_distinct_configs(2)
assert mock_refresh.called
assert len(result) == 2
@pytest.mark.asyncio
async def test_module_level_returns_empty_on_refresh_failure(self):
"""刷新失败 → 返回空列表"""
pool = SessionPool()
with patch("app.services.session_pool.session_pool", pool):
with patch.object(pool, "refresh") as mock_refresh:
mock_refresh.return_value = False
result = await get_distinct_configs(2, max_retries=2)
assert result == []
assert mock_refresh.call_count == 2

View File

@ -1,5 +1,12 @@
"""
Tests for Video Analysis Service (T-024)
覆盖
- calculate_cost_metrics 计算
- _needs_api_call 缓存判断
- get_video_analysis_data 详情页缓存命中 / API 调用 / API 失败降级
- update_video_a3_metrics heated_cost
- get_video_list_with_a3 列表页混合缓存 + 并发 API
"""
import pytest
@ -7,20 +14,98 @@ from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from app.services.video_analysis import (
_build_video_list_item,
_needs_api_call,
calculate_cost_metrics,
get_video_base_info,
get_video_analysis_data,
get_video_list_with_a3,
update_video_a3_metrics,
get_and_update_video_analysis,
)
from app.services.yuntu_api import YuntuAPIError
def _make_mock_video(**overrides):
"""创建标准 mock video 对象,带合理默认值。"""
defaults = {
"item_id": "video_123",
"title": "测试视频",
"video_url": "https://example.com/video",
"vid": "vid_123",
"star_id": "star_001",
"star_unique_id": "unique_001",
"star_nickname": "测试达人",
"star_uid": "uid_001",
"star_fans_cnt": 100000,
"star_mcn": "MCN1",
"publish_time": datetime(2025, 1, 15),
"create_date": datetime(2025, 1, 15),
"industry_name": "母婴",
"industry_id": "20",
"brand_id": "brand_001",
"hot_type": "爆款",
"viral_type": "爆款",
"is_hot": True,
"has_cart": False,
"total_play_cnt": 50000,
"natural_play_cnt": 40000,
"heated_play_cnt": 10000,
"total_interaction_cnt": 5000,
"total_interact": 5000,
"natural_interaction_cnt": 3000,
"heated_interaction_cnt": 2000,
"digg_cnt": 3000,
"like_cnt": 3000,
"share_cnt": 1000,
"comment_cnt": 1000,
"play_over_cnt": 20000,
"play_over_rate": 0.4,
"after_view_search_uv": 1000,
"after_view_search_cnt": 1200,
"after_view_search_rate": 0.02,
"back_search_cnt": 50,
"back_search_uv": 50,
"return_search_cnt": 50,
"new_a3_rate": 0.05,
"total_new_a3_cnt": 0,
"heated_new_a3_cnt": 0,
"natural_new_a3_cnt": 0,
"total_cost": 0.0,
"heated_cost": 0.0,
"star_task_cost": 0.0,
"search_cost": 0.0,
"ad_hot_roi": 0.0,
"estimated_video_cost": 10000.0,
"order_id": None,
"content_type": None,
"industry_tags": None,
"ad_hot_type": None,
"trend": None,
"trend_daily": None,
"trend_total": None,
"component_metric_list": None,
"key_word_after_search_infos": None,
"index_map": None,
"search_keywords": None,
"keywords": None,
"price_under_20s": None,
"price_20_60s": None,
"price_over_60s": None,
"video_duration": None,
"data_date": None,
"created_at": None,
"updated_at": None,
}
defaults.update(overrides)
mock = MagicMock()
for k, v in defaults.items():
setattr(mock, k, v)
return mock
class TestCalculateCostMetrics:
"""Tests for calculate_cost_metrics function."""
def test_all_metrics_calculated(self):
"""Test calculation of all cost metrics."""
result = calculate_cost_metrics(
cost=10000,
natural_play_cnt=40000,
@ -30,29 +115,15 @@ class TestCalculateCostMetrics:
total_play_cnt=50000,
)
# CPM = 10000 / 50000 * 1000 = 200
assert result["cpm"] == 200.0
# 自然CPM = 10000 / 40000 * 1000 = 250
assert result["natural_cpm"] == 250.0
# CPA3 = 10000 / 500 = 20
assert result["cpa3"] == 20.0
# 自然CPA3 = 10000 / 400 = 25
assert result["natural_cpa3"] == 25.0
# CPsearch = 10000 / 1000 = 10
assert result["cp_search"] == 10.0
# 预估自然看后搜人数 = 40000 / 50000 * 1000 = 800
assert result["estimated_natural_search_uv"] == 800.0
# 自然CPsearch = 10000 / 800 = 12.5
assert result["natural_cp_search"] == 12.5
def test_zero_total_play_cnt(self):
"""Test with zero total_play_cnt (division by zero)."""
result = calculate_cost_metrics(
cost=10000,
natural_play_cnt=0,
@ -68,7 +139,6 @@ class TestCalculateCostMetrics:
assert result["natural_cp_search"] is None
def test_zero_a3_counts(self):
"""Test with zero A3 counts."""
result = calculate_cost_metrics(
cost=10000,
natural_play_cnt=40000,
@ -80,11 +150,9 @@ class TestCalculateCostMetrics:
assert result["cpa3"] is None
assert result["natural_cpa3"] is None
# 其他指标应该正常计算
assert result["cpm"] == 200.0
def test_zero_search_uv(self):
"""Test with zero after_view_search_uv."""
result = calculate_cost_metrics(
cost=10000,
natural_play_cnt=40000,
@ -95,12 +163,10 @@ class TestCalculateCostMetrics:
)
assert result["cp_search"] is None
# 当 after_view_search_uv=0 时,预估自然看后搜人数也应为 None无意义
assert result["estimated_natural_search_uv"] is None
assert result["natural_cp_search"] is None
def test_all_zeros(self):
"""Test with all zero values."""
result = calculate_cost_metrics(
cost=0,
natural_play_cnt=0,
@ -119,7 +185,6 @@ class TestCalculateCostMetrics:
assert result["natural_cp_search"] is None
def test_decimal_precision(self):
"""Test that results are rounded to 2 decimal places."""
result = calculate_cost_metrics(
cost=10000,
natural_play_cnt=30000,
@ -129,104 +194,143 @@ class TestCalculateCostMetrics:
total_play_cnt=70000,
)
# 验证都是2位小数
assert isinstance(result["cpm"], float)
assert len(str(result["cpm"]).split(".")[-1]) <= 2
class TestNeedsApiCall:
"""Tests for _needs_api_call helper."""
def test_needs_call_when_no_data(self):
"""A3=0 且 cost=0 → 需要调 API"""
video = _make_mock_video(total_new_a3_cnt=0, total_cost=0.0)
assert _needs_api_call(video) is True
def test_needs_call_when_none_values(self):
"""A3=None 且 cost=None → 需要调 API"""
video = _make_mock_video(total_new_a3_cnt=None, total_cost=None)
assert _needs_api_call(video) is True
def test_no_call_when_a3_exists(self):
"""有 A3 数据 → 不需要调 API"""
video = _make_mock_video(total_new_a3_cnt=500, total_cost=0.0)
assert _needs_api_call(video) is False
def test_no_call_when_cost_exists(self):
"""有 cost 数据 → 不需要调 API"""
video = _make_mock_video(total_new_a3_cnt=0, total_cost=10000.0)
assert _needs_api_call(video) is False
def test_no_call_when_both_exist(self):
"""A3 和 cost 都有 → 不需要调 API"""
video = _make_mock_video(total_new_a3_cnt=500, total_cost=10000.0)
assert _needs_api_call(video) is False
class TestGetVideoAnalysisData:
"""Tests for get_video_analysis_data function."""
async def test_success_with_api_data(self):
"""Test successful data retrieval with API data."""
# Mock database video
mock_video = MagicMock()
mock_video.item_id = "video_123"
mock_video.title = "测试视频"
mock_video.video_url = "https://example.com/video"
mock_video.star_id = "star_001"
mock_video.star_unique_id = "unique_001"
mock_video.star_nickname = "测试达人"
mock_video.publish_time = datetime(2025, 1, 15)
mock_video.industry_name = "母婴"
mock_video.industry_id = "20"
mock_video.total_play_cnt = 50000
mock_video.natural_play_cnt = 40000
mock_video.heated_play_cnt = 10000
mock_video.after_view_search_uv = 1000
mock_video.return_search_cnt = 50
mock_video.estimated_video_cost = 10000
@pytest.mark.asyncio
async def test_uses_db_when_cached(self):
"""数据库已有 A3/Cost → 直接使用,不调 API"""
mock_video = _make_mock_video(
total_new_a3_cnt=500,
heated_new_a3_cnt=100,
natural_new_a3_cnt=400,
total_cost=10000.0,
heated_cost=5000.0,
)
# Mock session
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_video
mock_session.execute.return_value = mock_result
# Mock API response
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {"brand_001": "品牌A"}
with patch("app.services.video_analysis.fetch_yuntu_analysis") as mock_api:
result = await get_video_analysis_data(mock_session, "video_123")
# API 不应被调用
mock_api.assert_not_called()
# 验证使用了数据库数据
assert result["a3_metrics"]["total_new_a3_cnt"] == 500
assert result["a3_metrics"]["heated_new_a3_cnt"] == 100
assert result["a3_metrics"]["natural_new_a3_cnt"] == 400
assert result["cost_metrics"]["total_cost"] == 10000.0
assert result["cost_metrics"]["heated_cost"] == 5000.0
@pytest.mark.asyncio
async def test_calls_api_and_saves_to_db(self):
"""数据库无数据 → 调 API → 写回 DB"""
mock_video = _make_mock_video(
total_new_a3_cnt=0,
total_cost=0.0,
heated_cost=0.0,
)
mock_session = AsyncMock()
mock_select_result = MagicMock()
mock_select_result.scalar_one_or_none.return_value = mock_video
mock_update_result = MagicMock()
mock_update_result.rowcount = 1
call_count = [0]
async def mock_execute(stmt):
stmt_str = str(stmt)
if "SELECT" in stmt_str.upper() or call_count[0] == 0:
call_count[0] += 1
return mock_select_result
return mock_update_result
mock_session.execute.side_effect = mock_execute
api_response = {
"code": 0,
"data": {
"total_show_cnt": 100000,
"natural_show_cnt": 80000,
"ad_show_cnt": 20000,
"total_play_cnt": 50000,
"natural_play_cnt": 40000,
"ad_play_cnt": 10000,
"effective_play_cnt": 30000,
"a3_increase_cnt": 500,
"ad_a3_increase_cnt": 100,
"natural_a3_increase_cnt": 400,
"after_view_search_uv": 1000,
"after_view_search_pv": 1500,
"brand_search_uv": 200,
"product_search_uv": 300,
"return_search_cnt": 50,
"cost": 10000,
"a3_increase_cnt": "500",
"ad_a3_increase_cnt": "100",
"natural_a3_increase_cnt": "400",
"cost": 15000,
"ad_cost": 8000,
"natural_cost": 0,
"ad_cost": 10000,
},
}
with patch(
"app.services.video_analysis.fetch_yuntu_analysis"
) as mock_api:
mock_api.return_value = api_response
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {"brand_001": "品牌A"}
result = await get_video_analysis_data(mock_session, "video_123")
with patch("app.services.video_analysis.fetch_yuntu_analysis") as mock_api:
mock_api.return_value = api_response
# T-027: 验证使用 industry_id 而不是 brand_id 调用 API
mock_api.assert_called_once_with(
item_id="video_123",
publish_time=datetime(2025, 1, 15),
industry_id="20",
)
result = await get_video_analysis_data(mock_session, "video_123")
# 验证基础信息
assert result["base_info"]["item_id"] == "video_123"
assert result["base_info"]["title"] == "测试视频"
assert result["base_info"]["star_nickname"] == "测试达人"
# API 应被调用
mock_api.assert_called_once_with(
item_id="video_123",
publish_time=datetime(2025, 1, 15),
industry_id="20",
)
# 验证触达指标
assert result["reach_metrics"]["total_show_cnt"] == 100000
assert result["reach_metrics"]["natural_play_cnt"] == 40000
# 验证 A3 数据
assert result["a3_metrics"]["total_new_a3_cnt"] == 500
assert result["a3_metrics"]["heated_new_a3_cnt"] == 100
assert result["a3_metrics"]["natural_new_a3_cnt"] == 400
# 验证A3指标
assert result["a3_metrics"]["a3_increase_cnt"] == 500
assert result["a3_metrics"]["natural_a3_increase_cnt"] == 400
# 验证 cost
assert result["cost_metrics"]["total_cost"] == 15000
assert result["cost_metrics"]["heated_cost"] == 8000
# 验证搜索指标
assert result["search_metrics"]["after_view_search_uv"] == 1000
# 验证费用指标
assert result["cost_metrics_raw"]["cost"] == 10000
# 验证计算指标
assert result["cost_metrics_calculated"]["cpm"] is not None
assert result["cost_metrics_calculated"]["cpa3"] is not None
# 验证计算指标存在
assert "estimated_cpm" in result["calculated_metrics"]
assert "estimated_natural_cpm" in result["calculated_metrics"]
@pytest.mark.asyncio
async def test_video_not_found(self):
"""Test error when video is not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
@ -237,87 +341,103 @@ class TestGetVideoAnalysisData:
assert "not found" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_fallback_on_api_failure(self):
"""Test fallback to database data when API fails."""
# Mock database video
mock_video = MagicMock()
mock_video.item_id = "video_123"
mock_video.title = "测试视频"
mock_video.video_url = None
mock_video.star_id = "star_001"
mock_video.star_unique_id = "unique_001"
mock_video.star_nickname = "测试达人"
mock_video.publish_time = datetime(2025, 1, 15)
mock_video.industry_name = "母婴"
mock_video.industry_id = "20"
mock_video.total_play_cnt = 50000
mock_video.natural_play_cnt = 40000
mock_video.heated_play_cnt = 10000
mock_video.after_view_search_uv = 1000
mock_video.return_search_cnt = 50
mock_video.estimated_video_cost = 10000
mock_video.total_new_a3_cnt = 500
mock_video.heated_new_a3_cnt = 100
mock_video.natural_new_a3_cnt = 400
mock_video.total_cost = 10000
"""API 失败 → 降级使用数据库数据"""
mock_video = _make_mock_video(
total_new_a3_cnt=0,
heated_new_a3_cnt=0,
natural_new_a3_cnt=0,
total_cost=0.0,
heated_cost=0.0,
)
# Mock session
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_video
mock_session.execute.return_value = mock_result
with patch(
"app.services.video_analysis.fetch_yuntu_analysis"
) as mock_api:
mock_api.side_effect = YuntuAPIError("API Error")
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {}
result = await get_video_analysis_data(mock_session, "video_123")
with patch("app.services.video_analysis.fetch_yuntu_analysis") as mock_api:
mock_api.side_effect = YuntuAPIError("API Error")
# 应该使用数据库数据
assert result["reach_metrics"]["total_play_cnt"] == 50000
assert result["reach_metrics"]["natural_play_cnt"] == 40000
assert result["search_metrics"]["after_view_search_uv"] == 1000
result = await get_video_analysis_data(mock_session, "video_123")
# 降级使用 DB 数据(都是 0
assert result["a3_metrics"]["total_new_a3_cnt"] == 0
assert result["cost_metrics"]["total_cost"] == 0.0
# 基础信息仍然正常
assert result["base_info"]["vid"] == "video_123"
assert result["reach_metrics"]["total_play_cnt"] == 50000
@pytest.mark.asyncio
async def test_null_publish_time(self):
"""Test handling of null publish_time."""
mock_video = MagicMock()
mock_video.item_id = "video_123"
mock_video.title = "测试视频"
mock_video.video_url = None
mock_video.star_id = "star_001"
mock_video.star_unique_id = "unique_001"
mock_video.star_nickname = "测试达人"
mock_video.publish_time = None # NULL
mock_video.industry_name = None
mock_video.industry_id = None
mock_video.total_play_cnt = 0
mock_video.natural_play_cnt = 0
mock_video.heated_play_cnt = 0
mock_video.after_view_search_uv = 0
mock_video.return_search_cnt = 0
mock_video.estimated_video_cost = 0
mock_video = _make_mock_video(
publish_time=None,
create_date=None,
total_new_a3_cnt=0,
total_cost=0.0,
total_play_cnt=0,
natural_play_cnt=0,
heated_play_cnt=0,
after_view_search_uv=0,
)
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_video
mock_session.execute.return_value = mock_result
with patch(
"app.services.video_analysis.fetch_yuntu_analysis"
) as mock_api:
mock_api.return_value = {"code": 0, "data": {}}
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {}
with patch("app.services.video_analysis.fetch_yuntu_analysis") as mock_api:
mock_api.return_value = {"code": 0, "data": {}}
result = await get_video_analysis_data(mock_session, "video_123")
assert result["base_info"]["create_date"] is None
@pytest.mark.asyncio
async def test_response_structure(self):
"""验证返回数据包含所有 6 大类"""
mock_video = _make_mock_video(total_new_a3_cnt=500, total_cost=10000.0)
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_video
mock_session.execute.return_value = mock_result
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {}
result = await get_video_analysis_data(mock_session, "video_123")
assert result["base_info"]["publish_time"] is None
assert "base_info" in result
assert "reach_metrics" in result
assert "a3_metrics" in result
assert "search_metrics" in result
assert "cost_metrics" in result
assert "calculated_metrics" in result
# base_info 关键字段
assert "star_nickname" in result["base_info"]
assert "vid" in result["base_info"]
assert "brand_name" in result["base_info"]
# reach_metrics 关键字段
assert "total_play_cnt" in result["reach_metrics"]
assert "natural_play_cnt" in result["reach_metrics"]
class TestUpdateVideoA3Metrics:
"""Tests for update_video_a3_metrics function (T-025)."""
@pytest.mark.asyncio
async def test_update_success(self):
"""Test successful A3 metrics update."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.rowcount = 1
@ -335,8 +455,29 @@ class TestUpdateVideoA3Metrics:
assert result is True
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_with_heated_cost(self):
"""验证 heated_cost 参数正常传递"""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.rowcount = 1
mock_session.execute.return_value = mock_result
result = await update_video_a3_metrics(
session=mock_session,
item_id="video_123",
total_new_a3_cnt=500,
heated_new_a3_cnt=100,
natural_new_a3_cnt=400,
total_cost=15000.0,
heated_cost=8000.0,
)
assert result is True
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_video_not_found(self):
"""Test update when video not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.rowcount = 0
@ -353,8 +494,8 @@ class TestUpdateVideoA3Metrics:
assert result is False
@pytest.mark.asyncio
async def test_update_database_error(self):
"""Test update with database error."""
mock_session = AsyncMock()
mock_session.execute.side_effect = Exception("Database error")
@ -371,64 +512,249 @@ class TestUpdateVideoA3Metrics:
mock_session.rollback.assert_called_once()
class TestGetAndUpdateVideoAnalysis:
"""Tests for get_and_update_video_analysis function (T-024 + T-025)."""
class TestBuildVideoListItem:
"""Tests for _build_video_list_item helper."""
async def test_get_and_update_success(self):
"""Test successful get and update."""
# Mock database video
mock_video = MagicMock()
mock_video.item_id = "video_123"
mock_video.title = "测试视频"
mock_video.video_url = None
mock_video.star_id = "star_001"
mock_video.star_unique_id = "unique_001"
mock_video.star_nickname = "测试达人"
mock_video.publish_time = datetime(2025, 1, 15)
mock_video.industry_name = "母婴"
mock_video.industry_id = "20"
mock_video.total_play_cnt = 50000
mock_video.natural_play_cnt = 40000
mock_video.heated_play_cnt = 10000
mock_video.after_view_search_uv = 1000
mock_video.return_search_cnt = 50
mock_video.estimated_video_cost = 10000
def test_build_item_with_full_data(self):
video = _make_mock_video(
total_play_cnt=50000,
natural_play_cnt=40000,
after_view_search_uv=1000,
estimated_video_cost=10000.0,
)
result = _build_video_list_item(
video=video,
a3_increase_cnt=500,
ad_a3_increase_cnt=100,
natural_a3_increase_cnt=400,
api_cost=15000.0,
brand_name="品牌A",
)
assert result["item_id"] == "video_123"
assert result["brand_name"] == "品牌A"
assert result["total_new_a3_cnt"] == 500
assert result["estimated_natural_cpm"] is not None
assert result["estimated_cp_a3"] == 30.0 # 15000/500
def test_build_item_zero_division(self):
"""分母为 0 时应返回 None"""
video = _make_mock_video(
total_play_cnt=0,
natural_play_cnt=0,
after_view_search_uv=0,
estimated_video_cost=0.0,
)
result = _build_video_list_item(
video=video,
a3_increase_cnt=0,
ad_a3_increase_cnt=0,
natural_a3_increase_cnt=0,
api_cost=0.0,
brand_name="",
)
assert result["estimated_natural_cpm"] is None
assert result["estimated_cp_a3"] is None
assert result["estimated_natural_cp_a3"] is None
assert result["estimated_cp_search"] is None
assert result["estimated_natural_cp_search"] is None
class TestGetVideoListWithA3:
"""Tests for get_video_list_with_a3 function."""
@pytest.mark.asyncio
async def test_all_cached(self):
"""所有视频都有缓存 → 不调 API"""
videos = [
_make_mock_video(
item_id="v1", total_new_a3_cnt=500, total_cost=10000.0, brand_id="b1"
),
_make_mock_video(
item_id="v2", total_new_a3_cnt=300, total_cost=8000.0, brand_id="b2"
),
]
# Mock session
mock_session = AsyncMock()
mock_select_result = MagicMock()
mock_select_result.scalar_one_or_none.return_value = mock_video
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {"b1": "品牌1", "b2": "品牌2"}
with patch("app.services.video_analysis.call_yuntu_api") as mock_api:
result = await get_video_list_with_a3(mock_session, videos)
mock_api.assert_not_called()
assert len(result) == 2
assert result[0]["item_id"] == "v1"
assert result[0]["total_new_a3_cnt"] == 500
assert result[1]["item_id"] == "v2"
assert result[1]["total_new_a3_cnt"] == 300
@pytest.mark.asyncio
async def test_all_need_api(self):
"""所有视频都需要 API → 并发调用 → 首次即返回正确数据 → gather 后顺序写 DB"""
videos = [
_make_mock_video(
item_id="v1", total_new_a3_cnt=0, total_cost=0.0, brand_id="b1"
),
_make_mock_video(
item_id="v2", total_new_a3_cnt=0, total_cost=0.0, brand_id="b2"
),
]
mock_session = AsyncMock()
mock_update_result = MagicMock()
mock_update_result.rowcount = 1
mock_session.execute.return_value = mock_update_result
# 根据不同的SQL语句返回不同的结果
async def mock_execute(stmt):
# 简单判断:如果是 SELECT 返回视频,如果是 UPDATE 返回更新结果
stmt_str = str(stmt)
if "SELECT" in stmt_str.upper():
return mock_select_result
return mock_update_result
mock_session.execute.side_effect = mock_execute
with patch(
"app.services.video_analysis.fetch_yuntu_analysis"
) as mock_api:
mock_api.return_value = {
"code": 0,
"data": {
"a3_increase_cnt": 500,
"ad_a3_increase_cnt": 100,
"natural_a3_increase_cnt": 400,
"cost": 10000,
},
api_response = {
"data": {
"a3_increase_cnt": "200",
"ad_a3_increase_cnt": "50",
"natural_a3_increase_cnt": "150",
"cost": 5000,
"ad_cost": 3000,
}
}
result = await get_and_update_video_analysis(mock_session, "video_123")
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {}
# 验证返回数据
assert result["a3_metrics"]["a3_increase_cnt"] == 500
with patch("app.services.video_analysis.call_yuntu_api", new_callable=AsyncMock) as mock_api:
mock_api.return_value = api_response
# 验证数据库更新被调用
mock_session.commit.assert_called()
with patch("app.services.video_analysis.get_distinct_configs", new_callable=AsyncMock) as mock_configs:
mock_configs.return_value = [
{"aadvid": "aad1", "auth_token": "tok1"},
{"aadvid": "aad2", "auth_token": "tok2"},
]
with patch("app.services.video_analysis.update_video_a3_metrics", new_callable=AsyncMock) as mock_update:
mock_update.return_value = True
result = await get_video_list_with_a3(mock_session, videos)
assert len(result) == 2
assert mock_api.call_count == 2
# 首次查询即返回正确 API 数据(核心:不依赖 DB 写入成功)
assert result[0]["total_new_a3_cnt"] == 200
assert result[1]["total_new_a3_cnt"] == 200
# 验证两个视频用了不同 config
api_calls = mock_api.call_args_list
tokens = {c.kwargs["auth_token"] for c in api_calls}
assert len(tokens) == 2
# DB 写入在 gather 之后顺序执行
assert mock_update.call_count == 2
update_item_ids = [c.kwargs["item_id"] for c in mock_update.call_args_list]
assert "v1" in update_item_ids
assert "v2" in update_item_ids
@pytest.mark.asyncio
async def test_mixed_cached_and_api(self):
"""混合场景:部分缓存,部分需 API → 只对 API 成功的写 DB"""
videos = [
_make_mock_video(
item_id="v1", total_new_a3_cnt=500, total_cost=10000.0, brand_id="b1"
),
_make_mock_video(
item_id="v2", total_new_a3_cnt=0, total_cost=0.0, brand_id="b2"
),
_make_mock_video(
item_id="v3", total_new_a3_cnt=300, total_cost=5000.0, brand_id="b3"
),
]
mock_session = AsyncMock()
api_response = {
"data": {
"a3_increase_cnt": "200",
"ad_a3_increase_cnt": "50",
"natural_a3_increase_cnt": "150",
"cost": 5000,
"ad_cost": 3000,
}
}
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {}
with patch("app.services.video_analysis.call_yuntu_api", new_callable=AsyncMock) as mock_api:
mock_api.return_value = api_response
with patch("app.services.video_analysis.get_distinct_configs", new_callable=AsyncMock) as mock_configs:
mock_configs.return_value = [
{"aadvid": "aad1", "auth_token": "tok1"},
]
with patch("app.services.video_analysis.update_video_a3_metrics", new_callable=AsyncMock) as mock_update:
mock_update.return_value = True
result = await get_video_list_with_a3(mock_session, videos)
# 保持原始排序
assert len(result) == 3
assert result[0]["item_id"] == "v1"
assert result[0]["total_new_a3_cnt"] == 500 # from DB
assert result[1]["item_id"] == "v2"
assert result[1]["total_new_a3_cnt"] == 200 # from API
assert result[2]["item_id"] == "v3"
assert result[2]["total_new_a3_cnt"] == 300 # from DB
# 只有 v2 调了 API
assert mock_api.call_count == 1
# 只对 v2 写回 DB
assert mock_update.call_count == 1
assert mock_update.call_args.kwargs["item_id"] == "v2"
assert mock_update.call_args.kwargs["total_new_a3_cnt"] == 200
@pytest.mark.asyncio
async def test_empty_list(self):
"""空列表 → 返回空"""
mock_session = AsyncMock()
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {}
result = await get_video_list_with_a3(mock_session, [])
assert result == []
@pytest.mark.asyncio
async def test_api_failure_fallback(self):
"""API 调用失败 → 降级使用 DB 数据 → 不写回 DB"""
videos = [
_make_mock_video(
item_id="v1", total_new_a3_cnt=0, total_cost=0.0, brand_id="b1"
),
]
mock_session = AsyncMock()
with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands:
mock_brands.return_value = {}
with patch("app.services.video_analysis.call_yuntu_api", new_callable=AsyncMock) as mock_api:
mock_api.side_effect = YuntuAPIError("API Error")
with patch("app.services.video_analysis.get_distinct_configs", new_callable=AsyncMock) as mock_configs:
mock_configs.return_value = [
{"aadvid": "aad1", "auth_token": "tok1"},
]
with patch("app.services.video_analysis.update_video_a3_metrics", new_callable=AsyncMock) as mock_update:
result = await get_video_list_with_a3(mock_session, videos)
# 降级到 DB 数据
assert len(result) == 1
assert result[0]["total_new_a3_cnt"] == 0
# API 失败不应写回 DB
mock_update.assert_not_called()