From c17c64cd11552c631dde5abdaf439d3e4e8722b2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 10 Feb 2026 13:51:40 +0800 Subject: [PATCH] =?UTF-8?q?test:=20=E8=A1=A5=E5=85=A8=E5=93=81=E7=89=8C?= =?UTF-8?q?=E6=96=B9=E5=B9=B3=E5=8F=B0=E8=A7=84=E5=88=99=20API=20=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E8=A6=86=E7=9B=96=20(22=20=E4=B8=AA=E6=96=B0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TestBrandPlatformRuleParse: 文档解析 6 个用例 (201/400/降级) - TestBrandPlatformRuleConfirm: 确认生效 5 个用例 (active/编辑/停旧/404/隔离) - TestBrandPlatformRuleList: 列表查询 5 个用例 (空/筛选/隔离) - TestBrandPlatformRuleDelete: 删除 4 个用例 (204/实际删除/404/隔离) - TestBrandPlatformRuleLifecycle: 完整生命周期 1 个用例 - 修复 confirm 端点 flush 后缺少 refresh 导致的 MissingGreenlet 错误 Co-Authored-By: Claude Opus 4.6 --- backend/app/api/rules.py | 1 + backend/tests/test_rules_api.py | 488 +++++++++++++++++++++++++++++++- 2 files changed, 487 insertions(+), 2 deletions(-) diff --git a/backend/app/api/rules.py b/backend/app/api/rules.py index 745ac9d..82b47dc 100644 --- a/backend/app/api/rules.py +++ b/backend/app/api/rules.py @@ -611,6 +611,7 @@ async def confirm_platform_rule( rule.parsed_rules = request.parsed_rules.model_dump() rule.status = RuleStatus.ACTIVE.value await db.flush() + await db.refresh(rule) return _format_platform_rule(rule) diff --git a/backend/tests/test_rules_api.py b/backend/tests/test_rules_api.py index ed10433..cce77f8 100644 --- a/backend/tests/test_rules_api.py +++ b/backend/tests/test_rules_api.py @@ -1,8 +1,10 @@ """ -规则管理 API 测试 (TDD - 红色阶段) -测试覆盖: 违禁词库、白名单、竞品库、平台规则 +规则管理 API 测试 +测试覆盖: 违禁词库、白名单、竞品库、平台规则、品牌方平台规则 CRUD """ +import json import pytest +from unittest.mock import AsyncMock, MagicMock, patch from httpx import AsyncClient from app.schemas.review import ScriptReviewResponse, ViolationType @@ -383,3 +385,485 @@ class TestRuleConflictDetection: assert "brief_rule" in conflict assert "platform_rule" in conflict assert "suggestion" in conflict + + +# ==================== 品牌方平台规则(文档上传 + AI 解析) ==================== + +# Mock AI 解析返回的规则数据 +MOCK_PARSED_RULES = { + "forbidden_words": ["绝对有效", "最强", "第一"], + "restricted_words": [ + {"word": "推荐", "condition": "不能用于医疗产品", "suggestion": "建议改为'供参考'"} + ], + "duration": {"min_seconds": 7, "max_seconds": 60}, + "content_requirements": ["必须展示产品正面", "需口播品牌名"], + "other_rules": [ + {"rule": "字幕要求", "description": "视频必须添加中文字幕"} + ], +} + +MOCK_AI_JSON_RESPONSE = json.dumps(MOCK_PARSED_RULES, ensure_ascii=False) + + +def _mock_ai_client_for_parse(): + """创建用于文档解析的 mock AI 客户端""" + client = MagicMock() + client.chat_completion = AsyncMock(return_value=MagicMock( + content=MOCK_AI_JSON_RESPONSE, + )) + client.close = AsyncMock() + return client + + +async def _create_platform_rule( + client: AsyncClient, + tenant_id: str, + brand_id: str, + platform: str = "douyin", + document_name: str = "规则文档.pdf", +) -> dict: + """辅助函数:创建一条 draft 平台规则""" + with patch( + "app.api.rules.DocumentParser.download_and_parse", + new_callable=AsyncMock, + return_value="这是平台规则文档内容...", + ), patch( + "app.api.rules.AIServiceFactory.get_client", + new_callable=AsyncMock, + return_value=_mock_ai_client_for_parse(), + ), patch( + "app.api.rules.AIServiceFactory.get_config", + new_callable=AsyncMock, + return_value=MagicMock(models={"text": "gpt-4o"}), + ): + resp = await client.post( + "/api/v1/rules/platform-rules/parse", + headers={"X-Tenant-ID": tenant_id}, + json={ + "document_url": "https://tos.example.com/rules.pdf", + "document_name": document_name, + "platform": platform, + "brand_id": brand_id, + }, + ) + return resp + + +class TestBrandPlatformRuleParse: + """品牌方平台规则 — 上传文档 + AI 解析""" + + @pytest.mark.asyncio + async def test_parse_returns_201_draft(self, client: AsyncClient, tenant_id: str, brand_id: str): + """上传文档解析返回 201,状态为 draft""" + resp = await _create_platform_rule(client, tenant_id, brand_id) + assert resp.status_code == 201 + + data = resp.json() + assert data["status"] == "draft" + assert data["platform"] == "douyin" + assert data["brand_id"] == brand_id + assert data["id"].startswith("pr-") + assert data["document_name"] == "规则文档.pdf" + + @pytest.mark.asyncio + async def test_parse_returns_parsed_rules(self, client: AsyncClient, tenant_id: str, brand_id: str): + """解析后返回结构化规则""" + resp = await _create_platform_rule(client, tenant_id, brand_id) + data = resp.json() + + rules = data["parsed_rules"] + assert "forbidden_words" in rules + assert "restricted_words" in rules + assert "duration" in rules + assert "content_requirements" in rules + assert "other_rules" in rules + assert len(rules["forbidden_words"]) == 3 + assert "绝对有效" in rules["forbidden_words"] + + @pytest.mark.asyncio + async def test_parse_empty_document_returns_400(self, client: AsyncClient, tenant_id: str, brand_id: str): + """空文档返回 400""" + with patch( + "app.api.rules.DocumentParser.download_and_parse", + new_callable=AsyncMock, + return_value=" ", + ): + resp = await client.post( + "/api/v1/rules/platform-rules/parse", + headers={"X-Tenant-ID": tenant_id}, + json={ + "document_url": "https://tos.example.com/empty.pdf", + "document_name": "empty.pdf", + "platform": "douyin", + "brand_id": brand_id, + }, + ) + assert resp.status_code == 400 + assert "内容为空" in resp.json()["detail"] + + @pytest.mark.asyncio + async def test_parse_unsupported_format_returns_400(self, client: AsyncClient, tenant_id: str, brand_id: str): + """不支持的文件格式返回 400""" + with patch( + "app.api.rules.DocumentParser.download_and_parse", + new_callable=AsyncMock, + side_effect=ValueError("不支持的文件格式: zip"), + ): + resp = await client.post( + "/api/v1/rules/platform-rules/parse", + headers={"X-Tenant-ID": tenant_id}, + json={ + "document_url": "https://tos.example.com/file.zip", + "document_name": "file.zip", + "platform": "douyin", + "brand_id": brand_id, + }, + ) + assert resp.status_code == 400 + + @pytest.mark.asyncio + async def test_parse_ai_failure_returns_empty_rules(self, client: AsyncClient, tenant_id: str, brand_id: str): + """AI 解析失败时返回空规则结构(降级处理)""" + with patch( + "app.api.rules.DocumentParser.download_and_parse", + new_callable=AsyncMock, + return_value="文档内容...", + ), patch( + "app.api.rules.AIServiceFactory.get_client", + new_callable=AsyncMock, + return_value=None, + ): + resp = await client.post( + "/api/v1/rules/platform-rules/parse", + headers={"X-Tenant-ID": tenant_id}, + json={ + "document_url": "https://tos.example.com/rules.pdf", + "document_name": "rules.pdf", + "platform": "douyin", + "brand_id": brand_id, + }, + ) + assert resp.status_code == 201 + rules = resp.json()["parsed_rules"] + assert rules["forbidden_words"] == [] + assert rules["content_requirements"] == [] + assert rules["duration"] is None + + @pytest.mark.asyncio + async def test_parse_multiple_platforms(self, client: AsyncClient, tenant_id: str, brand_id: str): + """同一品牌方可以上传不同平台的规则""" + r1 = await _create_platform_rule(client, tenant_id, brand_id, platform="douyin") + r2 = await _create_platform_rule(client, tenant_id, brand_id, platform="xiaohongshu") + + assert r1.status_code == 201 + assert r2.status_code == 201 + assert r1.json()["platform"] == "douyin" + assert r2.json()["platform"] == "xiaohongshu" + + +class TestBrandPlatformRuleConfirm: + """品牌方平台规则 — 确认/生效""" + + @pytest.mark.asyncio + async def test_confirm_sets_active(self, client: AsyncClient, tenant_id: str, brand_id: str): + """确认规则后状态变为 active""" + # 先创建 draft + create_resp = await _create_platform_rule(client, tenant_id, brand_id) + rule_id = create_resp.json()["id"] + + # 确认 + confirm_resp = await client.put( + f"/api/v1/rules/platform-rules/{rule_id}/confirm", + headers={"X-Tenant-ID": tenant_id}, + json={ + "parsed_rules": MOCK_PARSED_RULES, + }, + ) + assert confirm_resp.status_code == 200 + data = confirm_resp.json() + assert data["status"] == "active" + assert data["id"] == rule_id + + @pytest.mark.asyncio + async def test_confirm_with_edited_rules(self, client: AsyncClient, tenant_id: str, brand_id: str): + """品牌方修改后确认""" + create_resp = await _create_platform_rule(client, tenant_id, brand_id) + rule_id = create_resp.json()["id"] + + edited_rules = { + "forbidden_words": ["绝对有效", "最强", "第一", "新增的违禁词"], + "restricted_words": [], + "duration": {"min_seconds": 10, "max_seconds": 120}, + "content_requirements": ["必须展示产品"], + "other_rules": [], + } + + confirm_resp = await client.put( + f"/api/v1/rules/platform-rules/{rule_id}/confirm", + headers={"X-Tenant-ID": tenant_id}, + json={"parsed_rules": edited_rules}, + ) + assert confirm_resp.status_code == 200 + data = confirm_resp.json() + assert "新增的违禁词" in data["parsed_rules"]["forbidden_words"] + assert data["parsed_rules"]["duration"]["min_seconds"] == 10 + + @pytest.mark.asyncio + async def test_confirm_deactivates_old_rule(self, client: AsyncClient, tenant_id: str, brand_id: str): + """确认新规则后旧的 active 规则变 inactive""" + # 创建并确认第一条规则 + r1 = await _create_platform_rule(client, tenant_id, brand_id, platform="douyin") + rule1_id = r1.json()["id"] + await client.put( + f"/api/v1/rules/platform-rules/{rule1_id}/confirm", + headers={"X-Tenant-ID": tenant_id}, + json={"parsed_rules": MOCK_PARSED_RULES}, + ) + + # 创建并确认第二条规则(同品牌同平台) + r2 = await _create_platform_rule(client, tenant_id, brand_id, platform="douyin") + rule2_id = r2.json()["id"] + await client.put( + f"/api/v1/rules/platform-rules/{rule2_id}/confirm", + headers={"X-Tenant-ID": tenant_id}, + json={"parsed_rules": MOCK_PARSED_RULES}, + ) + + # 查询所有规则 — rule1 应该变 inactive,rule2 应该 active + list_resp = await client.get( + f"/api/v1/rules/platform-rules?brand_id={brand_id}&platform=douyin", + headers={"X-Tenant-ID": tenant_id}, + ) + rules = list_resp.json()["items"] + rule1 = next(r for r in rules if r["id"] == rule1_id) + rule2 = next(r for r in rules if r["id"] == rule2_id) + + assert rule1["status"] == "inactive" + assert rule2["status"] == "active" + + @pytest.mark.asyncio + async def test_confirm_nonexistent_rule_returns_404(self, client: AsyncClient, tenant_id: str): + """确认不存在的规则返回 404""" + resp = await client.put( + "/api/v1/rules/platform-rules/pr-nonexist/confirm", + headers={"X-Tenant-ID": tenant_id}, + json={"parsed_rules": MOCK_PARSED_RULES}, + ) + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_confirm_cross_tenant_returns_404(self, client: AsyncClient, tenant_id: str, brand_id: str): + """不同租户确认规则返回 404(租户隔离)""" + create_resp = await _create_platform_rule(client, tenant_id, brand_id) + rule_id = create_resp.json()["id"] + + resp = await client.put( + f"/api/v1/rules/platform-rules/{rule_id}/confirm", + headers={"X-Tenant-ID": "other-tenant-xxx"}, + json={"parsed_rules": MOCK_PARSED_RULES}, + ) + assert resp.status_code == 404 + + +class TestBrandPlatformRuleList: + """品牌方平台规则 — 列表查询""" + + @pytest.mark.asyncio + async def test_list_empty_returns_200(self, client: AsyncClient, tenant_id: str): + """没有规则时返回空列表""" + resp = await client.get( + "/api/v1/rules/platform-rules", + headers={"X-Tenant-ID": tenant_id}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["items"] == [] + assert data["total"] == 0 + + @pytest.mark.asyncio + async def test_list_returns_created_rules(self, client: AsyncClient, tenant_id: str, brand_id: str): + """创建规则后列表包含该规则""" + await _create_platform_rule(client, tenant_id, brand_id, platform="douyin") + await _create_platform_rule(client, tenant_id, brand_id, platform="xiaohongshu") + + resp = await client.get( + f"/api/v1/rules/platform-rules?brand_id={brand_id}", + headers={"X-Tenant-ID": tenant_id}, + ) + data = resp.json() + assert data["total"] == 2 + platforms = {r["platform"] for r in data["items"]} + assert platforms == {"douyin", "xiaohongshu"} + + @pytest.mark.asyncio + async def test_list_filter_by_platform(self, client: AsyncClient, tenant_id: str, brand_id: str): + """按平台筛选""" + await _create_platform_rule(client, tenant_id, brand_id, platform="douyin") + await _create_platform_rule(client, tenant_id, brand_id, platform="xiaohongshu") + + resp = await client.get( + f"/api/v1/rules/platform-rules?brand_id={brand_id}&platform=douyin", + headers={"X-Tenant-ID": tenant_id}, + ) + data = resp.json() + assert data["total"] == 1 + assert data["items"][0]["platform"] == "douyin" + + @pytest.mark.asyncio + async def test_list_filter_by_status(self, client: AsyncClient, tenant_id: str, brand_id: str): + """按状态筛选""" + r = await _create_platform_rule(client, tenant_id, brand_id) + rule_id = r.json()["id"] + + # 确认一条 + await client.put( + f"/api/v1/rules/platform-rules/{rule_id}/confirm", + headers={"X-Tenant-ID": tenant_id}, + json={"parsed_rules": MOCK_PARSED_RULES}, + ) + + # 再创建一条 draft + await _create_platform_rule(client, tenant_id, brand_id, platform="douyin") + + # 只查 active + resp = await client.get( + f"/api/v1/rules/platform-rules?brand_id={brand_id}&status=active", + headers={"X-Tenant-ID": tenant_id}, + ) + active_rules = resp.json()["items"] + assert all(r["status"] == "active" for r in active_rules) + + # 只查 draft + resp2 = await client.get( + f"/api/v1/rules/platform-rules?brand_id={brand_id}&status=draft", + headers={"X-Tenant-ID": tenant_id}, + ) + draft_rules = resp2.json()["items"] + assert all(r["status"] == "draft" for r in draft_rules) + + @pytest.mark.asyncio + async def test_list_tenant_isolation(self, client: AsyncClient, tenant_id: str, brand_id: str): + """租户隔离:不同租户看不到彼此的规则""" + await _create_platform_rule(client, tenant_id, brand_id) + + resp = await client.get( + "/api/v1/rules/platform-rules", + headers={"X-Tenant-ID": "another-tenant-yyy"}, + ) + assert resp.json()["total"] == 0 + + +class TestBrandPlatformRuleDelete: + """品牌方平台规则 — 删除""" + + @pytest.mark.asyncio + async def test_delete_returns_204(self, client: AsyncClient, tenant_id: str, brand_id: str): + """删除规则返回 204""" + r = await _create_platform_rule(client, tenant_id, brand_id) + rule_id = r.json()["id"] + + resp = await client.delete( + f"/api/v1/rules/platform-rules/{rule_id}", + headers={"X-Tenant-ID": tenant_id}, + ) + assert resp.status_code == 204 + + @pytest.mark.asyncio + async def test_delete_actually_removes(self, client: AsyncClient, tenant_id: str, brand_id: str): + """删除后列表中不再包含该规则""" + r = await _create_platform_rule(client, tenant_id, brand_id) + rule_id = r.json()["id"] + + await client.delete( + f"/api/v1/rules/platform-rules/{rule_id}", + headers={"X-Tenant-ID": tenant_id}, + ) + + resp = await client.get( + f"/api/v1/rules/platform-rules?brand_id={brand_id}", + headers={"X-Tenant-ID": tenant_id}, + ) + ids = [r["id"] for r in resp.json()["items"]] + assert rule_id not in ids + + @pytest.mark.asyncio + async def test_delete_nonexistent_returns_404(self, client: AsyncClient, tenant_id: str): + """删除不存在的规则返回 404""" + resp = await client.delete( + "/api/v1/rules/platform-rules/pr-nonexist", + headers={"X-Tenant-ID": tenant_id}, + ) + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_delete_cross_tenant_returns_404(self, client: AsyncClient, tenant_id: str, brand_id: str): + """不同租户删除规则返回 404(租户隔离)""" + r = await _create_platform_rule(client, tenant_id, brand_id) + rule_id = r.json()["id"] + + resp = await client.delete( + f"/api/v1/rules/platform-rules/{rule_id}", + headers={"X-Tenant-ID": "other-tenant-zzz"}, + ) + assert resp.status_code == 404 + + +class TestBrandPlatformRuleLifecycle: + """品牌方平台规则 — 完整生命周期""" + + @pytest.mark.asyncio + async def test_full_lifecycle(self, client: AsyncClient, tenant_id: str, brand_id: str): + """完整流程: 上传解析 → 确认生效 → 重新上传 → 旧规则停用""" + headers = {"X-Tenant-ID": tenant_id} + + # 1. 上传并解析 + r1 = await _create_platform_rule(client, tenant_id, brand_id, platform="douyin") + assert r1.status_code == 201 + rule1_id = r1.json()["id"] + assert r1.json()["status"] == "draft" + + # 2. 确认生效 + confirm_resp = await client.put( + f"/api/v1/rules/platform-rules/{rule1_id}/confirm", + headers=headers, + json={"parsed_rules": MOCK_PARSED_RULES}, + ) + assert confirm_resp.json()["status"] == "active" + + # 3. 重新上传新规则 + r2 = await _create_platform_rule(client, tenant_id, brand_id, platform="douyin") + rule2_id = r2.json()["id"] + assert r2.json()["status"] == "draft" + + # 4. 确认新规则 + await client.put( + f"/api/v1/rules/platform-rules/{rule2_id}/confirm", + headers=headers, + json={"parsed_rules": MOCK_PARSED_RULES}, + ) + + # 5. 验证旧规则自动停用 + list_resp = await client.get( + f"/api/v1/rules/platform-rules?brand_id={brand_id}&platform=douyin", + headers=headers, + ) + rules = list_resp.json()["items"] + rule1 = next(r for r in rules if r["id"] == rule1_id) + rule2 = next(r for r in rules if r["id"] == rule2_id) + assert rule1["status"] == "inactive" + assert rule2["status"] == "active" + + # 6. 删除旧规则 + del_resp = await client.delete( + f"/api/v1/rules/platform-rules/{rule1_id}", + headers=headers, + ) + assert del_resp.status_code == 204 + + # 7. 验证只剩新规则 + final_resp = await client.get( + f"/api/v1/rules/platform-rules?brand_id={brand_id}&platform=douyin", + headers=headers, + ) + assert final_resp.json()["total"] == 1 + assert final_resp.json()["items"][0]["id"] == rule2_id