From eba9ce8e60f15f4789eb80e63a09ef9d86eabaa3 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 10 Feb 2026 10:45:33 +0800 Subject: [PATCH] =?UTF-8?q?test:=20=E8=A1=A5=E5=85=A8=20Profile=20?= =?UTF-8?q?=E5=92=8C=20Messages=20API=20=E6=B5=8B=E8=AF=95=E8=A6=86?= =?UTF-8?q?=E7=9B=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_profile_api.py: 16 个测试 (GET/PUT /profile + PUT /profile/password) - 三种角色资料获取、更新品牌方/代理商/达人字段、密码修改、权限校验 - test_messages_api.py: 16 个测试 (消息列表/分页/过滤/未读数/标记已读) - 分页、按类型/已读状态过滤、用户隔离、全部已读 32 个新测试全部通过 Co-Authored-By: Claude Opus 4.6 --- backend/tests/test_messages_api.py | 315 +++++++++++++++++++++++++++++ backend/tests/test_profile_api.py | 281 +++++++++++++++++++++++++ 2 files changed, 596 insertions(+) create mode 100644 backend/tests/test_messages_api.py create mode 100644 backend/tests/test_profile_api.py diff --git a/backend/tests/test_messages_api.py b/backend/tests/test_messages_api.py new file mode 100644 index 0000000..ac4b564 --- /dev/null +++ b/backend/tests/test_messages_api.py @@ -0,0 +1,315 @@ +""" +消息 API 测试 +覆盖: GET /messages, GET /messages/unread-count, PUT /messages/{id}/read, PUT /messages/read-all +""" +import pytest +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.message_service import create_message + + +MESSAGES_URL = "/api/v1/messages" +REGISTER_URL = "/api/v1/auth/register" + + +def _auth(token: str) -> dict: + return {"Authorization": f"Bearer {token}"} + + +async def _register(client: AsyncClient, role: str, name: str, email: str) -> dict: + resp = await client.post(REGISTER_URL, json={ + "email": email, + "password": "Test1234!", + "name": name, + "role": role, + "email_code": "000000", + }) + assert resp.status_code == 201, resp.text + return resp.json() + + +async def _seed_messages( + db: AsyncSession, + user_id: str, + count: int = 5, + msg_type: str = "system", + is_read: bool = False, +) -> list: + """在数据库中直接创建消息(绕过 API)""" + msgs = [] + for i in range(count): + m = await create_message( + db=db, + user_id=user_id, + type=msg_type, + title=f"测试消息 {i+1}", + content=f"消息内容 {i+1}", + related_task_id=f"TK{100000+i}", + sender_name="系统", + ) + if is_read: + m.is_read = True + msgs.append(m) + await db.commit() + return msgs + + +# ==================== GET /messages ==================== + + +class TestGetMessages: + """消息列表""" + + @pytest.mark.asyncio + async def test_empty_messages(self, client: AsyncClient): + data = await _register(client, "brand", "空消息", "empty-msg@test.com") + token = data["access_token"] + + resp = await client.get(MESSAGES_URL, headers=_auth(token)) + assert resp.status_code == 200 + + body = resp.json() + assert body["items"] == [] + assert body["total"] == 0 + assert body["page"] == 1 + + @pytest.mark.asyncio + async def test_list_messages(self, client: AsyncClient, test_db_session: AsyncSession): + data = await _register(client, "brand", "消息列表", "list-msg@test.com") + token = data["access_token"] + user_id = data["user"]["id"] + + await _seed_messages(test_db_session, user_id, count=3) + + resp = await client.get(MESSAGES_URL, headers=_auth(token)) + assert resp.status_code == 200 + + body = resp.json() + assert body["total"] == 3 + assert len(body["items"]) == 3 + + # 验证消息结构 + msg = body["items"][0] + assert "id" in msg + assert "type" in msg + assert "title" in msg + assert "content" in msg + assert "is_read" in msg + + @pytest.mark.asyncio + async def test_pagination(self, client: AsyncClient, test_db_session: AsyncSession): + data = await _register(client, "brand", "分页测试", "page-msg@test.com") + token = data["access_token"] + user_id = data["user"]["id"] + + await _seed_messages(test_db_session, user_id, count=15) + + # 第 1 页 + resp1 = await client.get( + MESSAGES_URL, params={"page": 1, "page_size": 10}, headers=_auth(token), + ) + assert resp1.status_code == 200 + body1 = resp1.json() + assert len(body1["items"]) == 10 + assert body1["total"] == 15 + assert body1["page"] == 1 + + # 第 2 页 + resp2 = await client.get( + MESSAGES_URL, params={"page": 2, "page_size": 10}, headers=_auth(token), + ) + assert resp2.status_code == 200 + body2 = resp2.json() + assert len(body2["items"]) == 5 + + @pytest.mark.asyncio + async def test_filter_by_read_status(self, client: AsyncClient, test_db_session: AsyncSession): + data = await _register(client, "brand", "已读过滤", "read-filter@test.com") + token = data["access_token"] + user_id = data["user"]["id"] + + await _seed_messages(test_db_session, user_id, count=3, is_read=False) + await _seed_messages(test_db_session, user_id, count=2, is_read=True) + + # 只看未读 + resp = await client.get( + MESSAGES_URL, params={"is_read": False}, headers=_auth(token), + ) + assert resp.status_code == 200 + body = resp.json() + assert body["total"] == 3 + for m in body["items"]: + assert m["is_read"] is False + + # 只看已读 + resp2 = await client.get( + MESSAGES_URL, params={"is_read": True}, headers=_auth(token), + ) + assert resp2.status_code == 200 + assert resp2.json()["total"] == 2 + + @pytest.mark.asyncio + async def test_filter_by_type(self, client: AsyncClient, test_db_session: AsyncSession): + data = await _register(client, "brand", "类型过滤", "type-filter@test.com") + token = data["access_token"] + user_id = data["user"]["id"] + + await _seed_messages(test_db_session, user_id, count=3, msg_type="new_task") + await _seed_messages(test_db_session, user_id, count=2, msg_type="system") + + resp = await client.get( + MESSAGES_URL, params={"type": "new_task"}, headers=_auth(token), + ) + assert resp.status_code == 200 + body = resp.json() + assert body["total"] == 3 + for m in body["items"]: + assert m["type"] == "new_task" + + @pytest.mark.asyncio + async def test_messages_isolation(self, client: AsyncClient, test_db_session: AsyncSession): + """用户只能看到自己的消息""" + data_a = await _register(client, "brand", "用户A", "user-a@test.com") + data_b = await _register(client, "agency", "用户B", "user-b@test.com") + + await _seed_messages(test_db_session, data_a["user"]["id"], count=3) + await _seed_messages(test_db_session, data_b["user"]["id"], count=5) + + resp = await client.get(MESSAGES_URL, headers=_auth(data_a["access_token"])) + assert resp.json()["total"] == 3 + + resp2 = await client.get(MESSAGES_URL, headers=_auth(data_b["access_token"])) + assert resp2.json()["total"] == 5 + + @pytest.mark.asyncio + async def test_messages_unauthenticated(self, client: AsyncClient): + resp = await client.get(MESSAGES_URL) + assert resp.status_code in (401, 403) + + +# ==================== GET /messages/unread-count ==================== + + +class TestUnreadCount: + """未读消息数""" + + @pytest.mark.asyncio + async def test_unread_count_zero(self, client: AsyncClient): + data = await _register(client, "brand", "零未读", "zero-unread@test.com") + token = data["access_token"] + + resp = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(token)) + assert resp.status_code == 200 + assert resp.json()["count"] == 0 + + @pytest.mark.asyncio + async def test_unread_count(self, client: AsyncClient, test_db_session: AsyncSession): + data = await _register(client, "brand", "未读计数", "unread-count@test.com") + token = data["access_token"] + user_id = data["user"]["id"] + + await _seed_messages(test_db_session, user_id, count=5, is_read=False) + await _seed_messages(test_db_session, user_id, count=3, is_read=True) + + resp = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(token)) + assert resp.status_code == 200 + assert resp.json()["count"] == 5 + + +# ==================== PUT /messages/{id}/read ==================== + + +class TestMarkAsRead: + """标记单条消息已读""" + + @pytest.mark.asyncio + async def test_mark_as_read(self, client: AsyncClient, test_db_session: AsyncSession): + data = await _register(client, "brand", "标记已读", "mark-read@test.com") + token = data["access_token"] + user_id = data["user"]["id"] + + msgs = await _seed_messages(test_db_session, user_id, count=1) + msg_id = msgs[0].id + + resp = await client.put(f"{MESSAGES_URL}/{msg_id}/read", headers=_auth(token)) + assert resp.status_code == 200 + + # 验证未读数减少 + count_resp = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(token)) + assert count_resp.json()["count"] == 0 + + @pytest.mark.asyncio + async def test_mark_nonexistent_message(self, client: AsyncClient): + data = await _register(client, "brand", "不存在", "nonexist-msg@test.com") + token = data["access_token"] + + resp = await client.put(f"{MESSAGES_URL}/MSG999999/read", headers=_auth(token)) + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_mark_other_users_message(self, client: AsyncClient, test_db_session: AsyncSession): + """不能标记别人的消息""" + data_a = await _register(client, "brand", "用户A标记", "mark-a@test.com") + data_b = await _register(client, "agency", "用户B标记", "mark-b@test.com") + + msgs = await _seed_messages(test_db_session, data_a["user"]["id"], count=1) + msg_id = msgs[0].id + + # 用户B尝试标记用户A的消息 + resp = await client.put( + f"{MESSAGES_URL}/{msg_id}/read", + headers=_auth(data_b["access_token"]), + ) + assert resp.status_code == 404 # 看不到别人的消息,返回 404 + + +# ==================== PUT /messages/read-all ==================== + + +class TestMarkAllAsRead: + """标记所有消息已读""" + + @pytest.mark.asyncio + async def test_mark_all_as_read(self, client: AsyncClient, test_db_session: AsyncSession): + data = await _register(client, "brand", "全部已读", "all-read@test.com") + token = data["access_token"] + user_id = data["user"]["id"] + + await _seed_messages(test_db_session, user_id, count=5) + + resp = await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(token)) + assert resp.status_code == 200 + body = resp.json() + assert body["count"] == 5 + + # 验证未读数为 0 + count_resp = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(token)) + assert count_resp.json()["count"] == 0 + + @pytest.mark.asyncio + async def test_mark_all_no_messages(self, client: AsyncClient): + data = await _register(client, "brand", "无消息全读", "no-msg-all@test.com") + token = data["access_token"] + + resp = await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(token)) + assert resp.status_code == 200 + assert resp.json()["count"] == 0 + + @pytest.mark.asyncio + async def test_mark_all_only_affects_own(self, client: AsyncClient, test_db_session: AsyncSession): + """全部已读只影响自己的消息""" + data_a = await _register(client, "brand", "全读A", "all-own-a@test.com") + data_b = await _register(client, "agency", "全读B", "all-own-b@test.com") + + await _seed_messages(test_db_session, data_a["user"]["id"], count=3) + await _seed_messages(test_db_session, data_b["user"]["id"], count=4) + + # A 全部已读 + await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(data_a["access_token"])) + + # B 的未读数不受影响 + count_resp = await client.get( + f"{MESSAGES_URL}/unread-count", headers=_auth(data_b["access_token"]), + ) + assert count_resp.json()["count"] == 4 diff --git a/backend/tests/test_profile_api.py b/backend/tests/test_profile_api.py new file mode 100644 index 0000000..dd92a1c --- /dev/null +++ b/backend/tests/test_profile_api.py @@ -0,0 +1,281 @@ +""" +用户资料 API 测试 +覆盖: GET /profile, PUT /profile, PUT /profile/password +""" +import pytest +from httpx import AsyncClient + + +PROFILE_URL = "/api/v1/profile" +REGISTER_URL = "/api/v1/auth/register" + + +def _auth(token: str) -> dict: + return {"Authorization": f"Bearer {token}"} + + +async def _register(client: AsyncClient, role: str, name: str, email: str) -> dict: + resp = await client.post(REGISTER_URL, json={ + "email": email, + "password": "Test1234!", + "name": name, + "role": role, + "email_code": "000000", + }) + assert resp.status_code == 201, resp.text + return resp.json() + + +# ==================== GET /profile ==================== + + +class TestGetProfile: + """获取用户资料""" + + @pytest.mark.asyncio + async def test_get_brand_profile(self, client: AsyncClient): + data = await _register(client, "brand", "测试品牌", "brand-profile@test.com") + token = data["access_token"] + + resp = await client.get(PROFILE_URL, headers=_auth(token)) + assert resp.status_code == 200 + + body = resp.json() + assert body["name"] == "测试品牌" + assert body["role"] == "brand" + assert body["email"] == "brand-profile@test.com" + assert body["brand"] is not None + assert body["brand"]["name"] == "测试品牌" + + @pytest.mark.asyncio + async def test_get_agency_profile(self, client: AsyncClient): + data = await _register(client, "agency", "测试代理商", "agency-profile@test.com") + token = data["access_token"] + + resp = await client.get(PROFILE_URL, headers=_auth(token)) + assert resp.status_code == 200 + + body = resp.json() + assert body["role"] == "agency" + assert body["agency"] is not None + assert body["agency"]["name"] == "测试代理商" + + @pytest.mark.asyncio + async def test_get_creator_profile(self, client: AsyncClient): + data = await _register(client, "creator", "测试达人", "creator-profile@test.com") + token = data["access_token"] + + resp = await client.get(PROFILE_URL, headers=_auth(token)) + assert resp.status_code == 200 + + body = resp.json() + assert body["role"] == "creator" + assert body["creator"] is not None + assert body["creator"]["name"] == "测试达人" + + @pytest.mark.asyncio + async def test_get_profile_unauthenticated(self, client: AsyncClient): + resp = await client.get(PROFILE_URL) + assert resp.status_code in (401, 403) + + @pytest.mark.asyncio + async def test_get_profile_invalid_token(self, client: AsyncClient): + resp = await client.get(PROFILE_URL, headers=_auth("invalid-token")) + assert resp.status_code in (401, 403) + + +# ==================== PUT /profile ==================== + + +class TestUpdateProfile: + """更新用户资料""" + + @pytest.mark.asyncio + async def test_update_brand_name(self, client: AsyncClient): + data = await _register(client, "brand", "原始品牌", "brand-update@test.com") + token = data["access_token"] + + resp = await client.put( + PROFILE_URL, + json={"name": "新品牌名称"}, + headers=_auth(token), + ) + assert resp.status_code == 200 + body = resp.json() + assert body["name"] == "新品牌名称" + assert body["brand"]["name"] == "新品牌名称" + + @pytest.mark.asyncio + async def test_update_brand_contact(self, client: AsyncClient): + data = await _register(client, "brand", "品牌联系人", "brand-contact@test.com") + token = data["access_token"] + + resp = await client.put( + PROFILE_URL, + json={ + "description": "品牌描述", + "contact_name": "张三", + "contact_phone": "13800000001", + "contact_email": "zhangsan@brand.com", + }, + headers=_auth(token), + ) + assert resp.status_code == 200 + body = resp.json() + assert body["brand"]["description"] == "品牌描述" + assert body["brand"]["contact_name"] == "张三" + assert body["brand"]["contact_phone"] == "13800000001" + assert body["brand"]["contact_email"] == "zhangsan@brand.com" + + @pytest.mark.asyncio + async def test_update_agency_profile(self, client: AsyncClient): + data = await _register(client, "agency", "代理商", "agency-update@test.com") + token = data["access_token"] + + resp = await client.put( + PROFILE_URL, + json={ + "name": "新代理商名", + "description": "专业MCN机构", + "contact_name": "李四", + }, + headers=_auth(token), + ) + assert resp.status_code == 200 + body = resp.json() + assert body["name"] == "新代理商名" + assert body["agency"]["name"] == "新代理商名" + assert body["agency"]["description"] == "专业MCN机构" + + @pytest.mark.asyncio + async def test_update_creator_profile(self, client: AsyncClient): + data = await _register(client, "creator", "达人", "creator-update@test.com") + token = data["access_token"] + + resp = await client.put( + PROFILE_URL, + json={ + "name": "新达人名", + "bio": "美食博主", + "douyin_account": "douyin123", + "xiaohongshu_account": "xhs456", + "bilibili_account": "bili789", + }, + headers=_auth(token), + ) + assert resp.status_code == 200 + body = resp.json() + assert body["name"] == "新达人名" + assert body["creator"]["name"] == "新达人名" + assert body["creator"]["bio"] == "美食博主" + assert body["creator"]["douyin_account"] == "douyin123" + assert body["creator"]["xiaohongshu_account"] == "xhs456" + assert body["creator"]["bilibili_account"] == "bili789" + + @pytest.mark.asyncio + async def test_update_phone_and_avatar(self, client: AsyncClient): + data = await _register(client, "brand", "头像测试", "avatar-test@test.com") + token = data["access_token"] + + resp = await client.put( + PROFILE_URL, + json={ + "phone": "13900000000", + "avatar": "https://example.com/avatar.png", + }, + headers=_auth(token), + ) + assert resp.status_code == 200 + body = resp.json() + assert body["phone"] == "13900000000" + assert body["avatar"] == "https://example.com/avatar.png" + + @pytest.mark.asyncio + async def test_update_empty_body(self, client: AsyncClient): + """空请求体不应报错""" + data = await _register(client, "brand", "空更新测试", "empty-update@test.com") + token = data["access_token"] + + resp = await client.put(PROFILE_URL, json={}, headers=_auth(token)) + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_update_profile_unauthenticated(self, client: AsyncClient): + resp = await client.put(PROFILE_URL, json={"name": "hack"}) + assert resp.status_code in (401, 403) + + @pytest.mark.asyncio + async def test_update_persists(self, client: AsyncClient): + """更新后重新 GET 应返回最新数据""" + data = await _register(client, "creator", "持久化测试", "persist@test.com") + token = data["access_token"] + + await client.put( + PROFILE_URL, + json={"bio": "更新后的简介"}, + headers=_auth(token), + ) + + resp = await client.get(PROFILE_URL, headers=_auth(token)) + assert resp.status_code == 200 + assert resp.json()["creator"]["bio"] == "更新后的简介" + + +# ==================== PUT /profile/password ==================== + + +class TestChangePassword: + """修改密码""" + + @pytest.mark.asyncio + async def test_change_password_success(self, client: AsyncClient): + data = await _register(client, "brand", "密码测试", "pwd-change@test.com") + token = data["access_token"] + + resp = await client.put( + f"{PROFILE_URL}/password", + json={"old_password": "Test1234!", "new_password": "NewPass5678!"}, + headers=_auth(token), + ) + assert resp.status_code == 200 + assert "密码修改成功" in resp.json()["message"] + + # 用新密码登录 + login_resp = await client.post("/api/v1/auth/login", json={ + "email": "pwd-change@test.com", + "password": "NewPass5678!", + }) + assert login_resp.status_code == 200 + + @pytest.mark.asyncio + async def test_change_password_wrong_old(self, client: AsyncClient): + data = await _register(client, "brand", "错误密码", "wrong-pwd@test.com") + token = data["access_token"] + + resp = await client.put( + f"{PROFILE_URL}/password", + json={"old_password": "WrongPassword!", "new_password": "NewPass!"}, + headers=_auth(token), + ) + assert resp.status_code == 400 + assert "原密码" in resp.json()["detail"] + + @pytest.mark.asyncio + async def test_change_password_too_short(self, client: AsyncClient): + data = await _register(client, "brand", "短密码", "short-pwd@test.com") + token = data["access_token"] + + resp = await client.put( + f"{PROFILE_URL}/password", + json={"old_password": "Test1234!", "new_password": "12345"}, + headers=_auth(token), + ) + assert resp.status_code == 422 # Pydantic validation + + @pytest.mark.asyncio + async def test_change_password_unauthenticated(self, client: AsyncClient): + resp = await client.put( + f"{PROFILE_URL}/password", + json={"old_password": "a", "new_password": "b"}, + ) + assert resp.status_code in (401, 403)