video-compliance-ai/backend/tests/test_messages_api.py
Your Name eba9ce8e60 test: 补全 Profile 和 Messages API 测试覆盖
- test_profile_api.py: 16 个测试 (GET/PUT /profile + PUT /profile/password)
  - 三种角色资料获取、更新品牌方/代理商/达人字段、密码修改、权限校验
- test_messages_api.py: 16 个测试 (消息列表/分页/过滤/未读数/标记已读)
  - 分页、按类型/已读状态过滤、用户隔离、全部已读

32 个新测试全部通过

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 10:45:33 +08:00

316 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
消息 API 测试
覆盖: GET /messages, GET /messages/unread-count, PUT /messages/{id}/read, PUT /messages/read-all
"""
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.message_service import create_message
# Base path of the messages API exercised by the tests in this module.
MESSAGES_URL = "/api/v1/messages"
# Registration endpoint; `_register` posts here to create fresh test users.
REGISTER_URL = "/api/v1/auth/register"
def _auth(token: str) -> dict:
return {"Authorization": f"Bearer {token}"}
async def _register(client: AsyncClient, role: str, name: str, email: str) -> dict:
    """Create a user through the public registration endpoint and return the JSON body.

    Every account shares the same password and the email code "000000";
    presumably the test backend accepts that fixed code — TODO confirm against
    the test settings. Callers read ``access_token`` and ``user.id`` from the
    returned body.

    Raises:
        AssertionError: if the endpoint does not answer with HTTP 201; the
            response text is attached to the failure for easier debugging.
    """
    payload = {
        "email": email,
        "password": "Test1234!",
        "name": name,
        "role": role,
        "email_code": "000000",
    }
    response = await client.post(REGISTER_URL, json=payload)
    assert response.status_code == 201, response.text
    return response.json()
async def _seed_messages(
    db: AsyncSession,
    user_id: str,
    count: int = 5,
    msg_type: str = "system",
    is_read: bool = False,
) -> list:
    """Create *count* messages for *user_id* directly via the service layer (bypassing the API).

    Titles, contents and related task ids are numbered from 1 / TK100000 on
    every call, so repeated calls for the same user produce duplicate titles
    and task ids. When *is_read* is true, each created object's ``is_read``
    attribute is set right after creation; the trailing ``db.commit()`` then
    persists the session state (create_message's own transaction handling is
    not visible from here).

    Returns:
        The created message objects, in creation order.
    """
    created = []
    for seq in range(1, count + 1):
        message = await create_message(
            db=db,
            user_id=user_id,
            type=msg_type,
            title=f"测试消息 {seq}",
            content=f"消息内容 {seq}",
            related_task_id=f"TK{100000 + seq - 1}",
            sender_name="系统",
        )
        if is_read:
            message.is_read = True
        created.append(message)
    await db.commit()
    return created
# ==================== GET /messages ====================
class TestGetMessages:
    """GET /messages: listing, pagination, read/type filters and per-user isolation."""

    @pytest.mark.asyncio
    async def test_empty_messages(self, client: AsyncClient):
        """A freshly registered user gets an empty first page."""
        data = await _register(client, "brand", "空消息", "empty-msg@test.com")
        token = data["access_token"]

        resp = await client.get(MESSAGES_URL, headers=_auth(token))
        assert resp.status_code == 200
        body = resp.json()
        assert body["items"] == []
        assert body["total"] == 0
        assert body["page"] == 1

    @pytest.mark.asyncio
    async def test_list_messages(self, client: AsyncClient, test_db_session: AsyncSession):
        """Seeded messages are listed and every item exposes the expected fields."""
        data = await _register(client, "brand", "消息列表", "list-msg@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]
        await _seed_messages(test_db_session, user_id, count=3)

        resp = await client.get(MESSAGES_URL, headers=_auth(token))
        assert resp.status_code == 200
        body = resp.json()
        assert body["total"] == 3
        assert len(body["items"]) == 3
        # Check the structure of every item, not only the first one. The
        # seeder defaults to unread messages of type "system".
        expected_fields = {"id", "type", "title", "content", "is_read"}
        for msg in body["items"]:
            assert expected_fields.issubset(msg)
            assert msg["type"] == "system"
            assert msg["is_read"] is False

    @pytest.mark.asyncio
    async def test_pagination(self, client: AsyncClient, test_db_session: AsyncSession):
        """15 messages split into a full page of 10 and a partial page of 5."""
        data = await _register(client, "brand", "分页测试", "page-msg@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]
        await _seed_messages(test_db_session, user_id, count=15)

        # Page 1
        resp1 = await client.get(
            MESSAGES_URL, params={"page": 1, "page_size": 10}, headers=_auth(token),
        )
        assert resp1.status_code == 200
        body1 = resp1.json()
        assert len(body1["items"]) == 10
        assert body1["total"] == 15
        assert body1["page"] == 1

        # Page 2: the remainder, with the same total and the requested page echoed back
        resp2 = await client.get(
            MESSAGES_URL, params={"page": 2, "page_size": 10}, headers=_auth(token),
        )
        assert resp2.status_code == 200
        body2 = resp2.json()
        assert len(body2["items"]) == 5
        assert body2["total"] == 15
        assert body2["page"] == 2
        # NOTE(review): page contents are not compared for overlap because the
        # seeded rows may share a creation timestamp, which could make the
        # cross-page order implementation-dependent. Confirm the endpoint's
        # sort key before adding such a check.

    @pytest.mark.asyncio
    async def test_filter_by_read_status(self, client: AsyncClient, test_db_session: AsyncSession):
        """The is_read query parameter returns only unread or only read messages."""
        data = await _register(client, "brand", "已读过滤", "read-filter@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]
        await _seed_messages(test_db_session, user_id, count=3, is_read=False)
        await _seed_messages(test_db_session, user_id, count=2, is_read=True)

        # Unread only
        resp = await client.get(
            MESSAGES_URL, params={"is_read": False}, headers=_auth(token),
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["total"] == 3
        for m in body["items"]:
            assert m["is_read"] is False

        # Read only: check the items too, not just the total
        resp2 = await client.get(
            MESSAGES_URL, params={"is_read": True}, headers=_auth(token),
        )
        assert resp2.status_code == 200
        body2 = resp2.json()
        assert body2["total"] == 2
        for m in body2["items"]:
            assert m["is_read"] is True

    @pytest.mark.asyncio
    async def test_filter_by_type(self, client: AsyncClient, test_db_session: AsyncSession):
        """The type query parameter returns only messages of that type (checked for both types)."""
        data = await _register(client, "brand", "类型过滤", "type-filter@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]
        await _seed_messages(test_db_session, user_id, count=3, msg_type="new_task")
        await _seed_messages(test_db_session, user_id, count=2, msg_type="system")

        resp = await client.get(
            MESSAGES_URL, params={"type": "new_task"}, headers=_auth(token),
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["total"] == 3
        for m in body["items"]:
            assert m["type"] == "new_task"

        # The complementary filter must return exactly the other group
        resp2 = await client.get(
            MESSAGES_URL, params={"type": "system"}, headers=_auth(token),
        )
        assert resp2.status_code == 200
        body2 = resp2.json()
        assert body2["total"] == 2
        for m in body2["items"]:
            assert m["type"] == "system"

    @pytest.mark.asyncio
    async def test_messages_isolation(self, client: AsyncClient, test_db_session: AsyncSession):
        """Each user only sees their own messages."""
        data_a = await _register(client, "brand", "用户A", "user-a@test.com")
        data_b = await _register(client, "agency", "用户B", "user-b@test.com")
        await _seed_messages(test_db_session, data_a["user"]["id"], count=3)
        await _seed_messages(test_db_session, data_b["user"]["id"], count=5)

        resp = await client.get(MESSAGES_URL, headers=_auth(data_a["access_token"]))
        assert resp.status_code == 200
        assert resp.json()["total"] == 3

        resp2 = await client.get(MESSAGES_URL, headers=_auth(data_b["access_token"]))
        assert resp2.status_code == 200
        assert resp2.json()["total"] == 5

    @pytest.mark.asyncio
    async def test_messages_unauthenticated(self, client: AsyncClient):
        """Listing without a bearer token is rejected."""
        resp = await client.get(MESSAGES_URL)
        assert resp.status_code in (401, 403)
# ==================== GET /messages/unread-count ====================
class TestUnreadCount:
    """GET /messages/unread-count: number of messages the user has not read yet."""

    @pytest.mark.asyncio
    async def test_unread_count_zero(self, client: AsyncClient):
        """A user without any messages reports zero unread."""
        registered = await _register(client, "brand", "零未读", "zero-unread@test.com")

        response = await client.get(
            f"{MESSAGES_URL}/unread-count",
            headers=_auth(registered["access_token"]),
        )

        assert response.status_code == 200
        assert response.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_unread_count(self, client: AsyncClient, test_db_session: AsyncSession):
        """Only unread messages are counted; read ones are ignored."""
        registered = await _register(client, "brand", "未读计数", "unread-count@test.com")
        bearer = registered["access_token"]
        owner_id = registered["user"]["id"]

        # Seed 5 unread messages first, then 3 already-read ones
        for read_flag, how_many in ((False, 5), (True, 3)):
            await _seed_messages(test_db_session, owner_id, count=how_many, is_read=read_flag)

        response = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(bearer))

        assert response.status_code == 200
        assert response.json()["count"] == 5
# ==================== PUT /messages/{id}/read ====================
class TestMarkAsRead:
    """PUT /messages/{id}/read: marking a single message as read."""

    @pytest.mark.asyncio
    async def test_mark_as_read(self, client: AsyncClient, test_db_session: AsyncSession):
        """Marking the user's only message drops the unread count from 1 to 0."""
        data = await _register(client, "brand", "标记已读", "mark-read@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]
        msgs = await _seed_messages(test_db_session, user_id, count=1)
        msg_id = msgs[0].id
        unread_url = f"{MESSAGES_URL}/unread-count"

        # Establish the baseline so the test really verifies a decrease
        before = await client.get(unread_url, headers=_auth(token))
        assert before.status_code == 200
        assert before.json()["count"] == 1

        resp = await client.put(f"{MESSAGES_URL}/{msg_id}/read", headers=_auth(token))
        assert resp.status_code == 200

        # The unread count drops to zero...
        after = await client.get(unread_url, headers=_auth(token))
        assert after.status_code == 200
        assert after.json()["count"] == 0

        # ...and the listing shows the message as read
        list_resp = await client.get(MESSAGES_URL, headers=_auth(token))
        assert list_resp.status_code == 200
        items = list_resp.json()["items"]
        assert len(items) == 1
        assert items[0]["is_read"] is True

    @pytest.mark.asyncio
    async def test_mark_nonexistent_message(self, client: AsyncClient):
        """Marking an unknown message id yields 404."""
        data = await _register(client, "brand", "不存在", "nonexist-msg@test.com")
        token = data["access_token"]

        resp = await client.put(f"{MESSAGES_URL}/MSG999999/read", headers=_auth(token))
        assert resp.status_code == 404

    @pytest.mark.asyncio
    async def test_mark_other_users_message(self, client: AsyncClient, test_db_session: AsyncSession):
        """A user cannot mark someone else's message, and the owner's state is untouched."""
        data_a = await _register(client, "brand", "用户A标记", "mark-a@test.com")
        data_b = await _register(client, "agency", "用户B标记", "mark-b@test.com")
        msgs = await _seed_messages(test_db_session, data_a["user"]["id"], count=1)
        msg_id = msgs[0].id

        # User B tries to mark user A's message
        resp = await client.put(
            f"{MESSAGES_URL}/{msg_id}/read",
            headers=_auth(data_b["access_token"]),
        )
        assert resp.status_code == 404  # other users' messages are invisible, hence 404

        # The rejected attempt must not have changed A's message
        count_resp = await client.get(
            f"{MESSAGES_URL}/unread-count", headers=_auth(data_a["access_token"]),
        )
        assert count_resp.status_code == 200
        assert count_resp.json()["count"] == 1
# ==================== PUT /messages/read-all ====================
class TestMarkAllAsRead:
    """PUT /messages/read-all: marking every message of the current user as read."""

    @pytest.mark.asyncio
    async def test_mark_all_as_read(self, client: AsyncClient, test_db_session: AsyncSession):
        """All unread messages are marked and the response reports how many."""
        data = await _register(client, "brand", "全部已读", "all-read@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]
        await _seed_messages(test_db_session, user_id, count=5)

        resp = await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(token))
        assert resp.status_code == 200
        body = resp.json()
        assert body["count"] == 5

        # The unread count is now zero
        count_resp = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(token))
        assert count_resp.status_code == 200
        assert count_resp.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_mark_all_no_messages(self, client: AsyncClient):
        """With no messages, read-all succeeds and reports zero."""
        data = await _register(client, "brand", "无消息全读", "no-msg-all@test.com")
        token = data["access_token"]

        resp = await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(token))
        assert resp.status_code == 200
        assert resp.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_mark_all_only_affects_own(self, client: AsyncClient, test_db_session: AsyncSession):
        """read-all only touches the caller's own messages."""
        data_a = await _register(client, "brand", "全读A", "all-own-a@test.com")
        data_b = await _register(client, "agency", "全读B", "all-own-b@test.com")
        await _seed_messages(test_db_session, data_a["user"]["id"], count=3)
        await _seed_messages(test_db_session, data_b["user"]["id"], count=4)

        # A marks everything as read; only A's 3 messages may be counted
        resp = await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(data_a["access_token"]))
        assert resp.status_code == 200
        assert resp.json()["count"] == 3

        # A's own unread count is cleared...
        count_a = await client.get(
            f"{MESSAGES_URL}/unread-count", headers=_auth(data_a["access_token"]),
        )
        assert count_a.status_code == 200
        assert count_a.json()["count"] == 0

        # ...while B's unread count is unaffected
        count_resp = await client.get(
            f"{MESSAGES_URL}/unread-count", headers=_auth(data_b["access_token"]),
        )
        assert count_resp.status_code == 200
        assert count_resp.json()["count"] == 4