diff --git a/backend/app/services/session_pool.py b/backend/app/services/session_pool.py index db1b09d..8a66368 100644 --- a/backend/app/services/session_pool.py +++ b/backend/app/services/session_pool.py @@ -11,8 +11,8 @@ T-027 修复: import asyncio import logging import random -from typing import Dict, Optional, Any, List from dataclasses import dataclass +from typing import Any, Dict, List, Optional import httpx @@ -150,6 +150,46 @@ class SessionPool: """检查池是否为空""" return len(self._configs) == 0 + def get_distinct_configs(self, count: int) -> List[Dict[str, Any]]: + """ + 获取 count 个不同的配置,用于并发调用。 + + - 池中配置 >= count:随机抽样 count 个不重复的 + - 池中配置 < count:全部取出,循环复用补足 + - 池为空:返回空列表 + + Args: + count: 需要的配置数量 + + Returns: + List[Dict]: 配置字典列表 + """ + if not self._configs or count <= 0: + return [] + + def _to_dict(config: CookieConfig) -> Dict[str, Any]: + return { + "brand_id": config.brand_id, + "aadvid": config.aadvid, + "auth_token": config.auth_token, + "industry_id": config.industry_id, + "brand_name": config.brand_name, + } + + if len(self._configs) >= count: + sampled = random.sample(self._configs, count) + return [_to_dict(c) for c in sampled] + + # 池中配置不足,全部取出后循环复用 + result = [_to_dict(c) for c in self._configs] + shuffled = list(self._configs) + random.shuffle(shuffled) + idx = 0 + while len(result) < count: + result.append(_to_dict(shuffled[idx % len(shuffled)])) + idx += 1 + return result + # 兼容旧接口 def get_random(self) -> Optional[str]: """兼容旧接口:随机获取一个 SessionID""" @@ -218,6 +258,32 @@ async def get_session_with_retry(max_retries: int = 3) -> Optional[str]: return None +async def get_distinct_configs(count: int, max_retries: int = 3) -> List[Dict[str, Any]]: + """ + 获取 count 个不同的配置,必要时刷新池。 + + Args: + count: 需要的配置数量 + max_retries: 最大重试次数 + + Returns: + List[Dict]: 配置字典列表 + """ + for attempt in range(max_retries): + if session_pool.is_empty: + success = await session_pool.refresh() + if not success: + logger.warning(f"Session pool refresh failed, attempt {attempt + 1}") + continue + + configs = session_pool.get_distinct_configs(count) + if configs: + return configs + + logger.error("Failed to get distinct configs after all retries") + return [] + + async def get_config_for_brand(brand_id: str, max_retries: int = 3) -> Optional[Any]: """ 兼容旧接口:获取品牌对应的配置。 diff --git a/backend/app/services/video_analysis.py b/backend/app/services/video_analysis.py index b6cf419..bc01240 100644 --- a/backend/app/services/video_analysis.py +++ b/backend/app/services/video_analysis.py @@ -4,25 +4,43 @@ 实现视频分析数据获取和成本指标计算。 """ +import asyncio import logging from datetime import datetime -from typing import Dict, Optional, Any +from typing import Any, Dict, Optional +from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select - -from sqlalchemy import update from app.models.kol_video import KolVideo +from app.services.session_pool import ( + get_distinct_configs, + get_random_config, + session_pool, +) +from app.services.yuntu_api import ( + SessionInvalidError, + call_yuntu_api, + parse_analysis_response, +) from app.services.yuntu_api import ( get_video_analysis as fetch_yuntu_analysis, - parse_analysis_response, - YuntuAPIError, ) logger = logging.getLogger(__name__) +def _needs_api_call(video: KolVideo) -> bool: + """ + 判断是否需要调用 Yuntu API 获取 A3/Cost 数据。 + + 如果数据库中已有 A3 或 Cost 数据,直接使用数据库数据,不调 API。 + """ + has_a3 = (video.total_new_a3_cnt or 0) > 0 + has_cost = (video.total_cost or 0) > 0 + return not (has_a3 or has_cost) + + def calculate_cost_metrics( cost: float, natural_play_cnt: int, @@ -151,33 +169,58 @@ async def get_video_analysis_data( brand_map = await get_brand_names([video.brand_id]) brand_name = brand_map.get(video.brand_id, video.brand_id) - # 3. 调用巨量云图API获取实时 A3 数据和 cost + # 3. 获取 A3 数据和 cost(缓存优先策略) a3_increase_cnt = 0 ad_a3_increase_cnt = 0 natural_a3_increase_cnt = 0 api_cost = 0.0 + ad_cost = 0.0 - try: - publish_time = video.publish_time or datetime.now() - industry_id = video.industry_id or "" - - api_response = await fetch_yuntu_analysis( - item_id=item_id, - publish_time=publish_time, - industry_id=industry_id, - ) - analysis_data = parse_analysis_response(api_response) - a3_increase_cnt = analysis_data.get("a3_increase_cnt", 0) - ad_a3_increase_cnt = analysis_data.get("ad_a3_increase_cnt", 0) - natural_a3_increase_cnt = analysis_data.get("natural_a3_increase_cnt", 0) - api_cost = analysis_data.get("cost", 0) - - except Exception as e: - logger.warning(f"API failed for {item_id}: {e}, using DB data") + if not _needs_api_call(video): + # 数据库已有数据,直接使用 + logger.info(f"Using DB data for {item_id} (A3/Cost already cached)") a3_increase_cnt = video.total_new_a3_cnt or 0 ad_a3_increase_cnt = video.heated_new_a3_cnt or 0 natural_a3_increase_cnt = video.natural_new_a3_cnt or 0 api_cost = video.total_cost or 0.0 + ad_cost = video.heated_cost or 0.0 + else: + # 需要调用 API 获取数据 + try: + publish_time = video.publish_time or datetime.now() + industry_id = video.industry_id or "" + + api_response = await fetch_yuntu_analysis( + item_id=item_id, + publish_time=publish_time, + industry_id=industry_id, + ) + analysis_data = parse_analysis_response(api_response) + a3_increase_cnt = analysis_data.get("a3_increase_cnt", 0) + ad_a3_increase_cnt = analysis_data.get("ad_a3_increase_cnt", 0) + natural_a3_increase_cnt = analysis_data.get("natural_a3_increase_cnt", 0) + api_cost = analysis_data.get("cost", 0) + ad_cost = analysis_data.get("ad_cost", 0) + + # 写回数据库 + await update_video_a3_metrics( + session=session, + item_id=item_id, + total_new_a3_cnt=int(a3_increase_cnt), + heated_new_a3_cnt=int(ad_a3_increase_cnt), + natural_new_a3_cnt=int(natural_a3_increase_cnt), + total_cost=float(api_cost), + heated_cost=float(ad_cost), + ) + logger.info(f"API data fetched and saved to DB for {item_id}") + + except Exception as e: + logger.warning(f"API failed for {item_id}: {e}, using DB data") + a3_increase_cnt = video.total_new_a3_cnt or 0 + ad_a3_increase_cnt = video.heated_new_a3_cnt or 0 + natural_a3_increase_cnt = video.natural_new_a3_cnt or 0 + api_cost = video.total_cost or 0.0 + ad_cost = video.heated_cost or 0.0 # 4. 数据库字段 estimated_video_cost = video.estimated_video_cost or 0.0 @@ -187,8 +230,7 @@ async def get_video_analysis_data( after_view_search_uv = video.after_view_search_uv or 0 # 5. 计算成本指标 - # 预估加热费用 = max(total_cost - estimated_video_cost, 0) - heated_cost = max(api_cost - estimated_video_cost, 0) if api_cost > estimated_video_cost else 0 + heated_cost = ad_cost # 预估自然看后搜人数 estimated_natural_search_uv = None @@ -271,9 +313,10 @@ async def update_video_a3_metrics( heated_new_a3_cnt: int, natural_new_a3_cnt: int, total_cost: float, + heated_cost: float = 0.0, ) -> bool: """ - 更新数据库中的A3指标 (T-025)。 + 更新数据库中的A3指标和费用数据 (T-025)。 Args: session: 数据库会话 @@ -281,7 +324,8 @@ async def update_video_a3_metrics( total_new_a3_cnt: 总新增A3 heated_new_a3_cnt: 加热新增A3 natural_new_a3_cnt: 自然新增A3 - total_cost: 总花费 + total_cost: 预估总费用 + heated_cost: 预估加热费用 Returns: bool: 更新是否成功 @@ -295,6 +339,7 @@ async def update_video_a3_metrics( heated_new_a3_cnt=heated_new_a3_cnt, natural_new_a3_cnt=natural_new_a3_cnt, total_cost=total_cost, + heated_cost=heated_cost, ) ) result = await session.execute(stmt) @@ -313,39 +358,6 @@ async def update_video_a3_metrics( return False -async def get_and_update_video_analysis( - session: AsyncSession, item_id: str -) -> Dict[str, Any]: - """ - 获取视频分析数据并更新数据库中的A3指标 (T-024 + T-025 组合)。 - - Args: - session: 数据库会话 - item_id: 视频ID - - Returns: - Dict: 完整的视频分析数据 - """ - # 获取分析数据 - result = await get_video_analysis_data(session, item_id) - - # 提取A3指标 - a3_metrics = result.get("a3_metrics", {}) - cost_raw = result.get("cost_metrics_raw", {}) - - # 更新数据库 - await update_video_a3_metrics( - session=session, - item_id=item_id, - total_new_a3_cnt=a3_metrics.get("a3_increase_cnt", 0), - heated_new_a3_cnt=a3_metrics.get("ad_a3_increase_cnt", 0), - natural_new_a3_cnt=a3_metrics.get("natural_a3_increase_cnt", 0), - total_cost=cost_raw.get("cost", 0), - ) - - return result - - async def search_videos_by_star_id( session: AsyncSession, star_id: str ) -> list[KolVideo]: @@ -373,11 +385,60 @@ async def search_videos_by_nickname( return list(result.scalars().all()) +def _build_video_list_item( + video: KolVideo, + a3_increase_cnt: int, + ad_a3_increase_cnt: int, + natural_a3_increase_cnt: int, + api_cost: float, + brand_name: str, +) -> Dict[str, Any]: + """构建视频列表项的结果字典。""" + estimated_video_cost = video.estimated_video_cost or 0.0 + natural_play_cnt = video.natural_play_cnt or 0 + total_play_cnt = video.total_play_cnt or 0 + after_view_search_uv = video.after_view_search_uv or 0 + + estimated_natural_search_uv = None + if total_play_cnt > 0 and after_view_search_uv > 0: + estimated_natural_search_uv = (natural_play_cnt / total_play_cnt) * after_view_search_uv + + estimated_natural_cpm = round((estimated_video_cost / natural_play_cnt) * 1000, 2) if natural_play_cnt > 0 else None + estimated_cp_a3 = round(api_cost / a3_increase_cnt, 2) if a3_increase_cnt > 0 else None + estimated_natural_cp_a3 = round(estimated_video_cost / natural_a3_increase_cnt, 2) if natural_a3_increase_cnt > 0 else None + estimated_cp_search = round(api_cost / after_view_search_uv, 2) if after_view_search_uv > 0 else None + estimated_natural_cp_search = round(estimated_video_cost / estimated_natural_search_uv, 2) if estimated_natural_search_uv and estimated_natural_search_uv > 0 else None + + return { + "item_id": video.item_id, + "star_nickname": video.star_nickname or "", + "title": video.title or "", + "video_url": video.video_url or "", + "create_date": video.publish_time.isoformat() if video.publish_time else None, + "hot_type": video.viral_type or "", + "industry_id": video.industry_id or "", + "brand_id": video.brand_id or "", + "brand_name": brand_name, + "total_new_a3_cnt": a3_increase_cnt, + "heated_new_a3_cnt": ad_a3_increase_cnt, + "natural_new_a3_cnt": natural_a3_increase_cnt, + "estimated_natural_cpm": estimated_natural_cpm, + "estimated_cp_a3": estimated_cp_a3, + "estimated_natural_cp_a3": estimated_natural_cp_a3, + "estimated_cp_search": estimated_cp_search, + "estimated_natural_cp_search": estimated_natural_cp_search, + } + + async def get_video_list_with_a3( session: AsyncSession, videos: list[KolVideo] ) -> list[Dict[str, Any]]: """ - 获取视频列表的摘要数据(实时调用云图API获取A3数据)。 + 获取视频列表的摘要数据。 + + 缓存优先策略: + - 数据库有 A3/Cost 数据 → 直接使用 + - 数据库无数据 → 并发调用云图 API(预分配不同 cookie)→ 写回数据库 """ from app.services.brand_api import get_brand_names @@ -385,73 +446,149 @@ async def get_video_list_with_a3( brand_ids = [video.brand_id for video in videos if video.brand_id] brand_map = await get_brand_names(brand_ids) if brand_ids else {} - result = [] - for video in videos: - # 实时调用云图 API 获取 A3 数据和 cost - a3_increase_cnt = 0 - ad_a3_increase_cnt = 0 - natural_a3_increase_cnt = 0 - api_cost = 0.0 + # 分组:已有数据 vs 需要 API 调用 + cached_videos: list[tuple[int, KolVideo]] = [] # (原始索引, video) + api_videos: list[tuple[int, KolVideo]] = [] # (原始索引, video) - try: - publish_time = video.publish_time or datetime.now() - industry_id = video.industry_id or "" + for idx, video in enumerate(videos): + if _needs_api_call(video): + api_videos.append((idx, video)) + else: + cached_videos.append((idx, video)) - api_response = await fetch_yuntu_analysis( - item_id=video.item_id, - publish_time=publish_time, - industry_id=industry_id, + logger.info( + f"Video list: {len(cached_videos)} cached, {len(api_videos)} need API" + ) + + # 结果数组(按原始索引填充) + results: list[Optional[Dict[str, Any]]] = [None] * len(videos) + + # 组 A:直接用数据库数据 + for idx, video in cached_videos: + brand_name = brand_map.get(video.brand_id, video.brand_id or "") if video.brand_id else "" + results[idx] = _build_video_list_item( + video=video, + a3_increase_cnt=video.total_new_a3_cnt or 0, + ad_a3_increase_cnt=video.heated_new_a3_cnt or 0, + natural_a3_increase_cnt=video.natural_new_a3_cnt or 0, + api_cost=video.total_cost or 0.0, + brand_name=brand_name, + ) + + # 组 B:并发调用 API(预分配不同 cookie) + if api_videos: + configs = await get_distinct_configs(len(api_videos)) + semaphore = asyncio.Semaphore(5) + # 收集需要写回 DB 的数据(避免并发 session 操作) + pending_updates: list[Dict[str, Any]] = [] + + async def _fetch_single( + idx: int, video: KolVideo, config: Dict[str, Any] + ) -> None: + a3_increase_cnt = 0 + ad_a3_increase_cnt = 0 + natural_a3_increase_cnt = 0 + api_cost = 0.0 + ad_cost_val = 0.0 + api_success = False + + async with semaphore: + try: + publish_time = video.publish_time or datetime.now() + industry_id = video.industry_id or "" + + api_response = await call_yuntu_api( + item_id=video.item_id, + publish_time=publish_time, + industry_id=industry_id, + aadvid=config["aadvid"], + auth_token=config["auth_token"], + ) + api_data = parse_analysis_response(api_response) + a3_increase_cnt = api_data.get("a3_increase_cnt", 0) + ad_a3_increase_cnt = api_data.get("ad_a3_increase_cnt", 0) + natural_a3_increase_cnt = api_data.get("natural_a3_increase_cnt", 0) + api_cost = api_data.get("cost", 0) + ad_cost_val = api_data.get("ad_cost", 0) + api_success = True + + except SessionInvalidError: + # Session 失效,从池中移除,重新获取随机 config 重试 + session_pool.remove_by_auth_token(config["auth_token"]) + logger.warning(f"Session invalid for {video.item_id}, retrying") + retry_config = await get_random_config() + if retry_config: + try: + publish_time = video.publish_time or datetime.now() + industry_id = video.industry_id or "" + api_response = await call_yuntu_api( + item_id=video.item_id, + publish_time=publish_time, + industry_id=industry_id, + aadvid=retry_config["aadvid"], + auth_token=retry_config["auth_token"], + ) + api_data = parse_analysis_response(api_response) + a3_increase_cnt = api_data.get("a3_increase_cnt", 0) + ad_a3_increase_cnt = api_data.get("ad_a3_increase_cnt", 0) + natural_a3_increase_cnt = api_data.get("natural_a3_increase_cnt", 0) + api_cost = api_data.get("cost", 0) + ad_cost_val = api_data.get("ad_cost", 0) + api_success = True + except Exception as e2: + logger.warning(f"Retry failed for {video.item_id}: {e2}") + a3_increase_cnt = video.total_new_a3_cnt or 0 + ad_a3_increase_cnt = video.heated_new_a3_cnt or 0 + natural_a3_increase_cnt = video.natural_new_a3_cnt or 0 + api_cost = video.total_cost or 0.0 + + except Exception as e: + logger.warning(f"API failed for {video.item_id}: {e}") + a3_increase_cnt = video.total_new_a3_cnt or 0 + ad_a3_increase_cnt = video.heated_new_a3_cnt or 0 + natural_a3_increase_cnt = video.natural_new_a3_cnt or 0 + api_cost = video.total_cost or 0.0 + + # 收集待写回 DB 的数据(不在并发中操作 session) + if api_success: + pending_updates.append({ + "item_id": video.item_id, + "total_new_a3_cnt": int(a3_increase_cnt), + "heated_new_a3_cnt": int(ad_a3_increase_cnt), + "natural_new_a3_cnt": int(natural_a3_increase_cnt), + "total_cost": float(api_cost), + "heated_cost": float(ad_cost_val), + }) + + brand_name = brand_map.get(video.brand_id, video.brand_id or "") if video.brand_id else "" + results[idx] = _build_video_list_item( + video=video, + a3_increase_cnt=a3_increase_cnt, + ad_a3_increase_cnt=ad_a3_increase_cnt, + natural_a3_increase_cnt=natural_a3_increase_cnt, + api_cost=api_cost, + brand_name=brand_name, ) - api_data = parse_analysis_response(api_response) - a3_increase_cnt = api_data.get("a3_increase_cnt", 0) - ad_a3_increase_cnt = api_data.get("ad_a3_increase_cnt", 0) - natural_a3_increase_cnt = api_data.get("natural_a3_increase_cnt", 0) - api_cost = api_data.get("cost", 0) - except Exception as e: - logger.warning(f"API failed for {video.item_id}: {e}") - a3_increase_cnt = video.total_new_a3_cnt or 0 - ad_a3_increase_cnt = video.heated_new_a3_cnt or 0 - natural_a3_increase_cnt = video.natural_new_a3_cnt or 0 - api_cost = video.total_cost or 0.0 + # 为每个视频分配一个独立的 config,并发执行 + tasks = [] + for i, (idx, video) in enumerate(api_videos): + config = configs[i] if i < len(configs) else configs[i % len(configs)] if configs else {} + tasks.append(_fetch_single(idx, video, config)) - # 数据库字段 - estimated_video_cost = video.estimated_video_cost or 0.0 - natural_play_cnt = video.natural_play_cnt or 0 - total_play_cnt = video.total_play_cnt or 0 - after_view_search_uv = video.after_view_search_uv or 0 + await asyncio.gather(*tasks) - # 计算成本指标 - estimated_natural_search_uv = None - if total_play_cnt > 0 and after_view_search_uv > 0: - estimated_natural_search_uv = (natural_play_cnt / total_play_cnt) * after_view_search_uv + # 顺序写回 DB(避免并发 session 操作导致状态损坏) + for upd in pending_updates: + await update_video_a3_metrics( + session=session, + item_id=upd["item_id"], + total_new_a3_cnt=upd["total_new_a3_cnt"], + heated_new_a3_cnt=upd["heated_new_a3_cnt"], + natural_new_a3_cnt=upd["natural_new_a3_cnt"], + total_cost=upd["total_cost"], + heated_cost=upd["heated_cost"], + ) - estimated_natural_cpm = round((estimated_video_cost / natural_play_cnt) * 1000, 2) if natural_play_cnt > 0 else None - estimated_cp_a3 = round(api_cost / a3_increase_cnt, 2) if a3_increase_cnt > 0 else None - estimated_natural_cp_a3 = round(estimated_video_cost / natural_a3_increase_cnt, 2) if natural_a3_increase_cnt > 0 else None - estimated_cp_search = round(api_cost / after_view_search_uv, 2) if after_view_search_uv > 0 else None - estimated_natural_cp_search = round(estimated_video_cost / estimated_natural_search_uv, 2) if estimated_natural_search_uv and estimated_natural_search_uv > 0 else None - - brand_name = brand_map.get(video.brand_id, video.brand_id) if video.brand_id else "" - - result.append({ - "item_id": video.item_id, - "star_nickname": video.star_nickname or "", - "title": video.title or "", - "video_url": video.video_url or "", - "create_date": video.publish_time.isoformat() if video.publish_time else None, - "hot_type": video.viral_type or "", - "industry_id": video.industry_id or "", - "brand_id": video.brand_id or "", - "brand_name": brand_name, - "total_new_a3_cnt": a3_increase_cnt, - "heated_new_a3_cnt": ad_a3_increase_cnt, - "natural_new_a3_cnt": natural_a3_increase_cnt, - "estimated_natural_cpm": estimated_natural_cpm, - "estimated_cp_a3": estimated_cp_a3, - "estimated_natural_cp_a3": estimated_natural_cp_a3, - "estimated_cp_search": estimated_cp_search, - "estimated_natural_cp_search": estimated_natural_cp_search, - }) - - return result + # 过滤 None(不应发生,防御性编程) + return [r for r in results if r is not None] diff --git a/backend/tests/test_session_pool.py b/backend/tests/test_session_pool.py index a756367..e81dd24 100644 --- a/backend/tests/test_session_pool.py +++ b/backend/tests/test_session_pool.py @@ -17,6 +17,7 @@ from app.services.session_pool import ( session_pool, get_session_with_retry, get_random_config, + get_distinct_configs, ) @@ -571,3 +572,130 @@ class TestSessionPoolIntegration: result = await pool.refresh() assert result is False + + +def _make_configs(count: int) -> list[CookieConfig]: + """创建 count 个不同的 CookieConfig 用于测试。""" + return [ + CookieConfig( + brand_id=f"brand_{i}", + aadvid=f"aadvid_{i}", + auth_token=f"sessionid=session_{i}", + industry_id=20 + i, + brand_name=f"Brand{i}", + ) + for i in range(count) + ] + + +class TestGetDistinctConfigs: + """Tests for SessionPool.get_distinct_configs and module-level get_distinct_configs.""" + + def test_enough_configs_returns_distinct(self): + """池中配置 >= count → 返回不重复的""" + pool = SessionPool() + pool._configs = _make_configs(5) + + result = pool.get_distinct_configs(3) + + assert len(result) == 3 + tokens = [r["auth_token"] for r in result] + assert len(set(tokens)) == 3 + + def test_exact_count(self): + """池中配置 == count → 全部返回""" + pool = SessionPool() + pool._configs = _make_configs(3) + + result = pool.get_distinct_configs(3) + + assert len(result) == 3 + tokens = {r["auth_token"] for r in result} + assert len(tokens) == 3 + + def test_fewer_configs_wraps_around(self): + """池中配置 < count → 循环复用补足""" + pool = SessionPool() + pool._configs = _make_configs(2) + + result = pool.get_distinct_configs(5) + + assert len(result) == 5 + # 前 2 个一定不重复 + first_two_tokens = {result[0]["auth_token"], result[1]["auth_token"]} + assert len(first_two_tokens) == 2 + + def test_empty_pool_returns_empty(self): + """空池 → 返回空列表""" + pool = SessionPool() + + result = pool.get_distinct_configs(3) + + assert result == [] + + def test_zero_count_returns_empty(self): + """count=0 → 返回空列表""" + pool = SessionPool() + pool._configs = _make_configs(3) + + result = pool.get_distinct_configs(0) + + assert result == [] + + def test_result_contains_all_fields(self): + """验证返回的 dict 包含所有必要字段""" + pool = SessionPool() + pool._configs = _make_configs(1) + + result = pool.get_distinct_configs(1) + + assert len(result) == 1 + item = result[0] + assert "brand_id" in item + assert "aadvid" in item + assert "auth_token" in item + assert "industry_id" in item + assert "brand_name" in item + + @pytest.mark.asyncio + async def test_module_level_get_distinct_configs(self): + """测试模块级 get_distinct_configs 异步函数""" + pool = SessionPool() + pool._configs = _make_configs(3) + + with patch("app.services.session_pool.session_pool", pool): + result = await get_distinct_configs(2) + + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_module_level_refreshes_on_empty(self): + """池为空时自动刷新""" + pool = SessionPool() + + with patch("app.services.session_pool.session_pool", pool): + with patch.object(pool, "refresh") as mock_refresh: + async def refresh_side_effect(): + pool._configs = _make_configs(3) + return True + + mock_refresh.side_effect = refresh_side_effect + + result = await get_distinct_configs(2) + + assert mock_refresh.called + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_module_level_returns_empty_on_refresh_failure(self): + """刷新失败 → 返回空列表""" + pool = SessionPool() + + with patch("app.services.session_pool.session_pool", pool): + with patch.object(pool, "refresh") as mock_refresh: + mock_refresh.return_value = False + + result = await get_distinct_configs(2, max_retries=2) + + assert result == [] + assert mock_refresh.call_count == 2 diff --git a/backend/tests/test_video_analysis.py b/backend/tests/test_video_analysis.py index 0b6c01b..6007205 100644 --- a/backend/tests/test_video_analysis.py +++ b/backend/tests/test_video_analysis.py @@ -1,5 +1,12 @@ """ Tests for Video Analysis Service (T-024) + +覆盖: +- calculate_cost_metrics 计算 +- _needs_api_call 缓存判断 +- get_video_analysis_data 详情页(缓存命中 / API 调用 / API 失败降级) +- update_video_a3_metrics(含 heated_cost) +- get_video_list_with_a3 列表页(混合缓存 + 并发 API) """ import pytest @@ -7,20 +14,98 @@ from datetime import datetime from unittest.mock import AsyncMock, patch, MagicMock from app.services.video_analysis import ( + _build_video_list_item, + _needs_api_call, calculate_cost_metrics, - get_video_base_info, get_video_analysis_data, + get_video_list_with_a3, update_video_a3_metrics, - get_and_update_video_analysis, ) from app.services.yuntu_api import YuntuAPIError +def _make_mock_video(**overrides): + """创建标准 mock video 对象,带合理默认值。""" + defaults = { + "item_id": "video_123", + "title": "测试视频", + "video_url": "https://example.com/video", + "vid": "vid_123", + "star_id": "star_001", + "star_unique_id": "unique_001", + "star_nickname": "测试达人", + "star_uid": "uid_001", + "star_fans_cnt": 100000, + "star_mcn": "MCN1", + "publish_time": datetime(2025, 1, 15), + "create_date": datetime(2025, 1, 15), + "industry_name": "母婴", + "industry_id": "20", + "brand_id": "brand_001", + "hot_type": "爆款", + "viral_type": "爆款", + "is_hot": True, + "has_cart": False, + "total_play_cnt": 50000, + "natural_play_cnt": 40000, + "heated_play_cnt": 10000, + "total_interaction_cnt": 5000, + "total_interact": 5000, + "natural_interaction_cnt": 3000, + "heated_interaction_cnt": 2000, + "digg_cnt": 3000, + "like_cnt": 3000, + "share_cnt": 1000, + "comment_cnt": 1000, + "play_over_cnt": 20000, + "play_over_rate": 0.4, + "after_view_search_uv": 1000, + "after_view_search_cnt": 1200, + "after_view_search_rate": 0.02, + "back_search_cnt": 50, + "back_search_uv": 50, + "return_search_cnt": 50, + "new_a3_rate": 0.05, + "total_new_a3_cnt": 0, + "heated_new_a3_cnt": 0, + "natural_new_a3_cnt": 0, + "total_cost": 0.0, + "heated_cost": 0.0, + "star_task_cost": 0.0, + "search_cost": 0.0, + "ad_hot_roi": 0.0, + "estimated_video_cost": 10000.0, + "order_id": None, + "content_type": None, + "industry_tags": None, + "ad_hot_type": None, + "trend": None, + "trend_daily": None, + "trend_total": None, + "component_metric_list": None, + "key_word_after_search_infos": None, + "index_map": None, + "search_keywords": None, + "keywords": None, + "price_under_20s": None, + "price_20_60s": None, + "price_over_60s": None, + "video_duration": None, + "data_date": None, + "created_at": None, + "updated_at": None, + } + defaults.update(overrides) + mock = MagicMock() + for k, v in defaults.items(): + setattr(mock, k, v) + return mock + + class TestCalculateCostMetrics: """Tests for calculate_cost_metrics function.""" def test_all_metrics_calculated(self): - """Test calculation of all cost metrics.""" result = calculate_cost_metrics( cost=10000, natural_play_cnt=40000, @@ -30,29 +115,15 @@ class TestCalculateCostMetrics: total_play_cnt=50000, ) - # CPM = 10000 / 50000 * 1000 = 200 assert result["cpm"] == 200.0 - - # 自然CPM = 10000 / 40000 * 1000 = 250 assert result["natural_cpm"] == 250.0 - - # CPA3 = 10000 / 500 = 20 assert result["cpa3"] == 20.0 - - # 自然CPA3 = 10000 / 400 = 25 assert result["natural_cpa3"] == 25.0 - - # CPsearch = 10000 / 1000 = 10 assert result["cp_search"] == 10.0 - - # 预估自然看后搜人数 = 40000 / 50000 * 1000 = 800 assert result["estimated_natural_search_uv"] == 800.0 - - # 自然CPsearch = 10000 / 800 = 12.5 assert result["natural_cp_search"] == 12.5 def test_zero_total_play_cnt(self): - """Test with zero total_play_cnt (division by zero).""" result = calculate_cost_metrics( cost=10000, natural_play_cnt=0, @@ -68,7 +139,6 @@ class TestCalculateCostMetrics: assert result["natural_cp_search"] is None def test_zero_a3_counts(self): - """Test with zero A3 counts.""" result = calculate_cost_metrics( cost=10000, natural_play_cnt=40000, @@ -80,11 +150,9 @@ class TestCalculateCostMetrics: assert result["cpa3"] is None assert result["natural_cpa3"] is None - # 其他指标应该正常计算 assert result["cpm"] == 200.0 def test_zero_search_uv(self): - """Test with zero after_view_search_uv.""" result = calculate_cost_metrics( cost=10000, natural_play_cnt=40000, @@ -95,12 +163,10 @@ class TestCalculateCostMetrics: ) assert result["cp_search"] is None - # 当 after_view_search_uv=0 时,预估自然看后搜人数也应为 None(无意义) assert result["estimated_natural_search_uv"] is None assert result["natural_cp_search"] is None def test_all_zeros(self): - """Test with all zero values.""" result = calculate_cost_metrics( cost=0, natural_play_cnt=0, @@ -119,7 +185,6 @@ class TestCalculateCostMetrics: assert result["natural_cp_search"] is None def test_decimal_precision(self): - """Test that results are rounded to 2 decimal places.""" result = calculate_cost_metrics( cost=10000, natural_play_cnt=30000, @@ -129,104 +194,143 @@ class TestCalculateCostMetrics: total_play_cnt=70000, ) - # 验证都是2位小数 assert isinstance(result["cpm"], float) assert len(str(result["cpm"]).split(".")[-1]) <= 2 +class TestNeedsApiCall: + """Tests for _needs_api_call helper.""" + + def test_needs_call_when_no_data(self): + """A3=0 且 cost=0 → 需要调 API""" + video = _make_mock_video(total_new_a3_cnt=0, total_cost=0.0) + assert _needs_api_call(video) is True + + def test_needs_call_when_none_values(self): + """A3=None 且 cost=None → 需要调 API""" + video = _make_mock_video(total_new_a3_cnt=None, total_cost=None) + assert _needs_api_call(video) is True + + def test_no_call_when_a3_exists(self): + """有 A3 数据 → 不需要调 API""" + video = _make_mock_video(total_new_a3_cnt=500, total_cost=0.0) + assert _needs_api_call(video) is False + + def test_no_call_when_cost_exists(self): + """有 cost 数据 → 不需要调 API""" + video = _make_mock_video(total_new_a3_cnt=0, total_cost=10000.0) + assert _needs_api_call(video) is False + + def test_no_call_when_both_exist(self): + """A3 和 cost 都有 → 不需要调 API""" + video = _make_mock_video(total_new_a3_cnt=500, total_cost=10000.0) + assert _needs_api_call(video) is False + + class TestGetVideoAnalysisData: """Tests for get_video_analysis_data function.""" - async def test_success_with_api_data(self): - """Test successful data retrieval with API data.""" - # Mock database video - mock_video = MagicMock() - mock_video.item_id = "video_123" - mock_video.title = "测试视频" - mock_video.video_url = "https://example.com/video" - mock_video.star_id = "star_001" - mock_video.star_unique_id = "unique_001" - mock_video.star_nickname = "测试达人" - mock_video.publish_time = datetime(2025, 1, 15) - mock_video.industry_name = "母婴" - mock_video.industry_id = "20" - mock_video.total_play_cnt = 50000 - mock_video.natural_play_cnt = 40000 - mock_video.heated_play_cnt = 10000 - mock_video.after_view_search_uv = 1000 - mock_video.return_search_cnt = 50 - mock_video.estimated_video_cost = 10000 + @pytest.mark.asyncio + async def test_uses_db_when_cached(self): + """数据库已有 A3/Cost → 直接使用,不调 API""" + mock_video = _make_mock_video( + total_new_a3_cnt=500, + heated_new_a3_cnt=100, + natural_new_a3_cnt=400, + total_cost=10000.0, + heated_cost=5000.0, + ) - # Mock session mock_session = AsyncMock() mock_result = MagicMock() mock_result.scalar_one_or_none.return_value = mock_video mock_session.execute.return_value = mock_result - # Mock API response + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {"brand_001": "品牌A"} + + with patch("app.services.video_analysis.fetch_yuntu_analysis") as mock_api: + result = await get_video_analysis_data(mock_session, "video_123") + + # API 不应被调用 + mock_api.assert_not_called() + + # 验证使用了数据库数据 + assert result["a3_metrics"]["total_new_a3_cnt"] == 500 + assert result["a3_metrics"]["heated_new_a3_cnt"] == 100 + assert result["a3_metrics"]["natural_new_a3_cnt"] == 400 + assert result["cost_metrics"]["total_cost"] == 10000.0 + assert result["cost_metrics"]["heated_cost"] == 5000.0 + + @pytest.mark.asyncio + async def test_calls_api_and_saves_to_db(self): + """数据库无数据 → 调 API → 写回 DB""" + mock_video = _make_mock_video( + total_new_a3_cnt=0, + total_cost=0.0, + heated_cost=0.0, + ) + + mock_session = AsyncMock() + mock_select_result = MagicMock() + mock_select_result.scalar_one_or_none.return_value = mock_video + + mock_update_result = MagicMock() + mock_update_result.rowcount = 1 + + call_count = [0] + + async def mock_execute(stmt): + stmt_str = str(stmt) + if "SELECT" in stmt_str.upper() or call_count[0] == 0: + call_count[0] += 1 + return mock_select_result + return mock_update_result + + mock_session.execute.side_effect = mock_execute + api_response = { "code": 0, "data": { - "total_show_cnt": 100000, - "natural_show_cnt": 80000, - "ad_show_cnt": 20000, - "total_play_cnt": 50000, - "natural_play_cnt": 40000, - "ad_play_cnt": 10000, - "effective_play_cnt": 30000, - "a3_increase_cnt": 500, - "ad_a3_increase_cnt": 100, - "natural_a3_increase_cnt": 400, - "after_view_search_uv": 1000, - "after_view_search_pv": 1500, - "brand_search_uv": 200, - "product_search_uv": 300, - "return_search_cnt": 50, - "cost": 10000, + "a3_increase_cnt": "500", + "ad_a3_increase_cnt": "100", + "natural_a3_increase_cnt": "400", + "cost": 15000, + "ad_cost": 8000, "natural_cost": 0, - "ad_cost": 10000, }, } - with patch( - "app.services.video_analysis.fetch_yuntu_analysis" - ) as mock_api: - mock_api.return_value = api_response + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {"brand_001": "品牌A"} - result = await get_video_analysis_data(mock_session, "video_123") + with patch("app.services.video_analysis.fetch_yuntu_analysis") as mock_api: + mock_api.return_value = api_response - # T-027: 验证使用 industry_id 而不是 brand_id 调用 API - mock_api.assert_called_once_with( - item_id="video_123", - publish_time=datetime(2025, 1, 15), - industry_id="20", - ) + result = await get_video_analysis_data(mock_session, "video_123") - # 验证基础信息 - assert result["base_info"]["item_id"] == "video_123" - assert result["base_info"]["title"] == "测试视频" - assert result["base_info"]["star_nickname"] == "测试达人" + # API 应被调用 + mock_api.assert_called_once_with( + item_id="video_123", + publish_time=datetime(2025, 1, 15), + industry_id="20", + ) - # 验证触达指标 - assert result["reach_metrics"]["total_show_cnt"] == 100000 - assert result["reach_metrics"]["natural_play_cnt"] == 40000 + # 验证 A3 数据 + assert result["a3_metrics"]["total_new_a3_cnt"] == 500 + assert result["a3_metrics"]["heated_new_a3_cnt"] == 100 + assert result["a3_metrics"]["natural_new_a3_cnt"] == 400 - # 验证A3指标 - assert result["a3_metrics"]["a3_increase_cnt"] == 500 - assert result["a3_metrics"]["natural_a3_increase_cnt"] == 400 + # 验证 cost + assert result["cost_metrics"]["total_cost"] == 15000 + assert result["cost_metrics"]["heated_cost"] == 8000 - # 验证搜索指标 - assert result["search_metrics"]["after_view_search_uv"] == 1000 - - # 验证费用指标 - assert result["cost_metrics_raw"]["cost"] == 10000 - - # 验证计算指标 - assert result["cost_metrics_calculated"]["cpm"] is not None - assert result["cost_metrics_calculated"]["cpa3"] is not None + # 验证计算指标存在 + assert "estimated_cpm" in result["calculated_metrics"] + assert "estimated_natural_cpm" in result["calculated_metrics"] + @pytest.mark.asyncio async def test_video_not_found(self): - """Test error when video is not found.""" mock_session = AsyncMock() mock_result = MagicMock() mock_result.scalar_one_or_none.return_value = None @@ -237,87 +341,103 @@ class TestGetVideoAnalysisData: assert "not found" in str(exc_info.value).lower() + @pytest.mark.asyncio async def test_fallback_on_api_failure(self): - """Test fallback to database data when API fails.""" - # Mock database video - mock_video = MagicMock() - mock_video.item_id = "video_123" - mock_video.title = "测试视频" - mock_video.video_url = None - mock_video.star_id = "star_001" - mock_video.star_unique_id = "unique_001" - mock_video.star_nickname = "测试达人" - mock_video.publish_time = datetime(2025, 1, 15) - mock_video.industry_name = "母婴" - mock_video.industry_id = "20" - mock_video.total_play_cnt = 50000 - mock_video.natural_play_cnt = 40000 - mock_video.heated_play_cnt = 10000 - mock_video.after_view_search_uv = 1000 - mock_video.return_search_cnt = 50 - mock_video.estimated_video_cost = 10000 - mock_video.total_new_a3_cnt = 500 - mock_video.heated_new_a3_cnt = 100 - mock_video.natural_new_a3_cnt = 400 - mock_video.total_cost = 10000 + """API 失败 → 降级使用数据库数据""" + mock_video = _make_mock_video( + total_new_a3_cnt=0, + heated_new_a3_cnt=0, + natural_new_a3_cnt=0, + total_cost=0.0, + heated_cost=0.0, + ) - # Mock session mock_session = AsyncMock() mock_result = MagicMock() mock_result.scalar_one_or_none.return_value = mock_video mock_session.execute.return_value = mock_result - with patch( - "app.services.video_analysis.fetch_yuntu_analysis" - ) as mock_api: - mock_api.side_effect = YuntuAPIError("API Error") + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {} - result = await get_video_analysis_data(mock_session, "video_123") + with patch("app.services.video_analysis.fetch_yuntu_analysis") as mock_api: + mock_api.side_effect = YuntuAPIError("API Error") - # 应该使用数据库数据 - assert result["reach_metrics"]["total_play_cnt"] == 50000 - assert result["reach_metrics"]["natural_play_cnt"] == 40000 - assert result["search_metrics"]["after_view_search_uv"] == 1000 + result = await get_video_analysis_data(mock_session, "video_123") + # 降级使用 DB 数据(都是 0) + assert result["a3_metrics"]["total_new_a3_cnt"] == 0 + assert result["cost_metrics"]["total_cost"] == 0.0 + + # 基础信息仍然正常 + assert result["base_info"]["vid"] == "video_123" + assert result["reach_metrics"]["total_play_cnt"] == 50000 + + @pytest.mark.asyncio async def test_null_publish_time(self): - """Test handling of null publish_time.""" - mock_video = MagicMock() - mock_video.item_id = "video_123" - mock_video.title = "测试视频" - mock_video.video_url = None - mock_video.star_id = "star_001" - mock_video.star_unique_id = "unique_001" - mock_video.star_nickname = "测试达人" - mock_video.publish_time = None # NULL - mock_video.industry_name = None - mock_video.industry_id = None - mock_video.total_play_cnt = 0 - mock_video.natural_play_cnt = 0 - mock_video.heated_play_cnt = 0 - mock_video.after_view_search_uv = 0 - mock_video.return_search_cnt = 0 - mock_video.estimated_video_cost = 0 + mock_video = _make_mock_video( + publish_time=None, + create_date=None, + total_new_a3_cnt=0, + total_cost=0.0, + total_play_cnt=0, + natural_play_cnt=0, + heated_play_cnt=0, + after_view_search_uv=0, + ) mock_session = AsyncMock() mock_result = MagicMock() mock_result.scalar_one_or_none.return_value = mock_video mock_session.execute.return_value = mock_result - with patch( - "app.services.video_analysis.fetch_yuntu_analysis" - ) as mock_api: - mock_api.return_value = {"code": 0, "data": {}} + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {} + + with patch("app.services.video_analysis.fetch_yuntu_analysis") as mock_api: + mock_api.return_value = {"code": 0, "data": {}} + + result = await get_video_analysis_data(mock_session, "video_123") + + assert result["base_info"]["create_date"] is None + + @pytest.mark.asyncio + async def test_response_structure(self): + """验证返回数据包含所有 6 大类""" + mock_video = _make_mock_video(total_new_a3_cnt=500, total_cost=10000.0) + + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_video + mock_session.execute.return_value = mock_result + + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {} result = await get_video_analysis_data(mock_session, "video_123") - assert result["base_info"]["publish_time"] is None + assert "base_info" in result + assert "reach_metrics" in result + assert "a3_metrics" in result + assert "search_metrics" in result + assert "cost_metrics" in result + assert "calculated_metrics" in result + + # base_info 关键字段 + assert "star_nickname" in result["base_info"] + assert "vid" in result["base_info"] + assert "brand_name" in result["base_info"] + + # reach_metrics 关键字段 + assert "total_play_cnt" in result["reach_metrics"] + assert "natural_play_cnt" in result["reach_metrics"] class TestUpdateVideoA3Metrics: """Tests for update_video_a3_metrics function (T-025).""" + @pytest.mark.asyncio async def test_update_success(self): - """Test successful A3 metrics update.""" mock_session = AsyncMock() mock_result = MagicMock() mock_result.rowcount = 1 @@ -335,8 +455,29 @@ class TestUpdateVideoA3Metrics: assert result is True mock_session.commit.assert_called_once() + @pytest.mark.asyncio + async def test_update_with_heated_cost(self): + """验证 heated_cost 参数正常传递""" + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.rowcount = 1 + mock_session.execute.return_value = mock_result + + result = await update_video_a3_metrics( + session=mock_session, + item_id="video_123", + total_new_a3_cnt=500, + heated_new_a3_cnt=100, + natural_new_a3_cnt=400, + total_cost=15000.0, + heated_cost=8000.0, + ) + + assert result is True + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio async def test_update_video_not_found(self): - """Test update when video not found.""" mock_session = AsyncMock() mock_result = MagicMock() mock_result.rowcount = 0 @@ -353,8 +494,8 @@ class TestUpdateVideoA3Metrics: assert result is False + @pytest.mark.asyncio async def test_update_database_error(self): - """Test update with database error.""" mock_session = AsyncMock() mock_session.execute.side_effect = Exception("Database error") @@ -371,64 +512,249 @@ class TestUpdateVideoA3Metrics: mock_session.rollback.assert_called_once() -class TestGetAndUpdateVideoAnalysis: - """Tests for get_and_update_video_analysis function (T-024 + T-025).""" +class TestBuildVideoListItem: + """Tests for _build_video_list_item helper.""" - async def test_get_and_update_success(self): - """Test successful get and update.""" - # Mock database video - mock_video = MagicMock() - mock_video.item_id = "video_123" - mock_video.title = "测试视频" - mock_video.video_url = None - mock_video.star_id = "star_001" - mock_video.star_unique_id = "unique_001" - mock_video.star_nickname = "测试达人" - mock_video.publish_time = datetime(2025, 1, 15) - mock_video.industry_name = "母婴" - mock_video.industry_id = "20" - mock_video.total_play_cnt = 50000 - mock_video.natural_play_cnt = 40000 - mock_video.heated_play_cnt = 10000 - mock_video.after_view_search_uv = 1000 - mock_video.return_search_cnt = 50 - mock_video.estimated_video_cost = 10000 + def test_build_item_with_full_data(self): + video = _make_mock_video( + total_play_cnt=50000, + natural_play_cnt=40000, + after_view_search_uv=1000, + estimated_video_cost=10000.0, + ) + + result = _build_video_list_item( + video=video, + a3_increase_cnt=500, + ad_a3_increase_cnt=100, + natural_a3_increase_cnt=400, + api_cost=15000.0, + brand_name="品牌A", + ) + + assert result["item_id"] == "video_123" + assert result["brand_name"] == "品牌A" + assert result["total_new_a3_cnt"] == 500 + assert result["estimated_natural_cpm"] is not None + assert result["estimated_cp_a3"] == 30.0 # 15000/500 + + def test_build_item_zero_division(self): + """分母为 0 时应返回 None""" + video = _make_mock_video( + total_play_cnt=0, + natural_play_cnt=0, + after_view_search_uv=0, + estimated_video_cost=0.0, + ) + + result = _build_video_list_item( + video=video, + a3_increase_cnt=0, + ad_a3_increase_cnt=0, + natural_a3_increase_cnt=0, + api_cost=0.0, + brand_name="", + ) + + assert result["estimated_natural_cpm"] is None + assert result["estimated_cp_a3"] is None + assert result["estimated_natural_cp_a3"] is None + assert result["estimated_cp_search"] is None + assert result["estimated_natural_cp_search"] is None + + +class TestGetVideoListWithA3: + """Tests for get_video_list_with_a3 function.""" + + @pytest.mark.asyncio + async def test_all_cached(self): + """所有视频都有缓存 → 不调 API""" + videos = [ + _make_mock_video( + item_id="v1", total_new_a3_cnt=500, total_cost=10000.0, brand_id="b1" + ), + _make_mock_video( + item_id="v2", total_new_a3_cnt=300, total_cost=8000.0, brand_id="b2" + ), + ] - # Mock session mock_session = AsyncMock() - mock_select_result = MagicMock() - mock_select_result.scalar_one_or_none.return_value = mock_video + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {"b1": "品牌1", "b2": "品牌2"} + + with patch("app.services.video_analysis.call_yuntu_api") as mock_api: + result = await get_video_list_with_a3(mock_session, videos) + + mock_api.assert_not_called() + assert len(result) == 2 + assert result[0]["item_id"] == "v1" + assert result[0]["total_new_a3_cnt"] == 500 + assert result[1]["item_id"] == "v2" + assert result[1]["total_new_a3_cnt"] == 300 + + @pytest.mark.asyncio + async def test_all_need_api(self): + """所有视频都需要 API → 并发调用 → 首次即返回正确数据 → gather 后顺序写 DB""" + videos = [ + _make_mock_video( + item_id="v1", total_new_a3_cnt=0, total_cost=0.0, brand_id="b1" + ), + _make_mock_video( + item_id="v2", total_new_a3_cnt=0, total_cost=0.0, brand_id="b2" + ), + ] + + mock_session = AsyncMock() mock_update_result = MagicMock() mock_update_result.rowcount = 1 + mock_session.execute.return_value = mock_update_result - # 根据不同的SQL语句返回不同的结果 - async def mock_execute(stmt): - # 简单判断:如果是 SELECT 返回视频,如果是 UPDATE 返回更新结果 - stmt_str = str(stmt) - if "SELECT" in stmt_str.upper(): - return mock_select_result - return mock_update_result - - mock_session.execute.side_effect = mock_execute - - with patch( - "app.services.video_analysis.fetch_yuntu_analysis" - ) as mock_api: - mock_api.return_value = { - "code": 0, - "data": { - "a3_increase_cnt": 500, - "ad_a3_increase_cnt": 100, - "natural_a3_increase_cnt": 400, - "cost": 10000, - }, + api_response = { + "data": { + "a3_increase_cnt": "200", + "ad_a3_increase_cnt": "50", + "natural_a3_increase_cnt": "150", + "cost": 5000, + "ad_cost": 3000, } + } - result = await get_and_update_video_analysis(mock_session, "video_123") + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {} - # 验证返回数据 - assert result["a3_metrics"]["a3_increase_cnt"] == 500 + with patch("app.services.video_analysis.call_yuntu_api", new_callable=AsyncMock) as mock_api: + mock_api.return_value = api_response - # 验证数据库更新被调用 - mock_session.commit.assert_called() + with patch("app.services.video_analysis.get_distinct_configs", new_callable=AsyncMock) as mock_configs: + mock_configs.return_value = [ + {"aadvid": "aad1", "auth_token": "tok1"}, + {"aadvid": "aad2", "auth_token": "tok2"}, + ] + + with patch("app.services.video_analysis.update_video_a3_metrics", new_callable=AsyncMock) as mock_update: + mock_update.return_value = True + + result = await get_video_list_with_a3(mock_session, videos) + + assert len(result) == 2 + assert mock_api.call_count == 2 + + # 首次查询即返回正确 API 数据(核心:不依赖 DB 写入成功) + assert result[0]["total_new_a3_cnt"] == 200 + assert result[1]["total_new_a3_cnt"] == 200 + + # 验证两个视频用了不同 config + api_calls = mock_api.call_args_list + tokens = {c.kwargs["auth_token"] for c in api_calls} + assert len(tokens) == 2 + + # DB 写入在 gather 之后顺序执行 + assert mock_update.call_count == 2 + update_item_ids = [c.kwargs["item_id"] for c in mock_update.call_args_list] + assert "v1" in update_item_ids + assert "v2" in update_item_ids + + @pytest.mark.asyncio + async def test_mixed_cached_and_api(self): + """混合场景:部分缓存,部分需 API → 只对 API 成功的写 DB""" + videos = [ + _make_mock_video( + item_id="v1", total_new_a3_cnt=500, total_cost=10000.0, brand_id="b1" + ), + _make_mock_video( + item_id="v2", total_new_a3_cnt=0, total_cost=0.0, brand_id="b2" + ), + _make_mock_video( + item_id="v3", total_new_a3_cnt=300, total_cost=5000.0, brand_id="b3" + ), + ] + + mock_session = AsyncMock() + + api_response = { + "data": { + "a3_increase_cnt": "200", + "ad_a3_increase_cnt": "50", + "natural_a3_increase_cnt": "150", + "cost": 5000, + "ad_cost": 3000, + } + } + + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {} + + with patch("app.services.video_analysis.call_yuntu_api", new_callable=AsyncMock) as mock_api: + mock_api.return_value = api_response + + with patch("app.services.video_analysis.get_distinct_configs", new_callable=AsyncMock) as mock_configs: + mock_configs.return_value = [ + {"aadvid": "aad1", "auth_token": "tok1"}, + ] + + with patch("app.services.video_analysis.update_video_a3_metrics", new_callable=AsyncMock) as mock_update: + mock_update.return_value = True + + result = await get_video_list_with_a3(mock_session, videos) + + # 保持原始排序 + assert len(result) == 3 + assert result[0]["item_id"] == "v1" + assert result[0]["total_new_a3_cnt"] == 500 # from DB + assert result[1]["item_id"] == "v2" + assert result[1]["total_new_a3_cnt"] == 200 # from API + assert result[2]["item_id"] == "v3" + assert result[2]["total_new_a3_cnt"] == 300 # from DB + + # 只有 v2 调了 API + assert mock_api.call_count == 1 + + # 只对 v2 写回 DB + assert mock_update.call_count == 1 + assert mock_update.call_args.kwargs["item_id"] == "v2" + assert mock_update.call_args.kwargs["total_new_a3_cnt"] == 200 + + @pytest.mark.asyncio + async def test_empty_list(self): + """空列表 → 返回空""" + mock_session = AsyncMock() + + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {} + + result = await get_video_list_with_a3(mock_session, []) + + assert result == [] + + @pytest.mark.asyncio + async def test_api_failure_fallback(self): + """API 调用失败 → 降级使用 DB 数据 → 不写回 DB""" + videos = [ + _make_mock_video( + item_id="v1", total_new_a3_cnt=0, total_cost=0.0, brand_id="b1" + ), + ] + + mock_session = AsyncMock() + + with patch("app.services.brand_api.get_brand_names", new_callable=AsyncMock) as mock_brands: + mock_brands.return_value = {} + + with patch("app.services.video_analysis.call_yuntu_api", new_callable=AsyncMock) as mock_api: + mock_api.side_effect = YuntuAPIError("API Error") + + with patch("app.services.video_analysis.get_distinct_configs", new_callable=AsyncMock) as mock_configs: + mock_configs.return_value = [ + {"aadvid": "aad1", "auth_token": "tok1"}, + ] + + with patch("app.services.video_analysis.update_video_a3_metrics", new_callable=AsyncMock) as mock_update: + result = await get_video_list_with_a3(mock_session, videos) + + # 降级到 DB 数据 + assert len(result) == 1 + assert result[0]["total_new_a3_cnt"] == 0 + + # API 失败不应写回 DB + mock_update.assert_not_called()