- test_profile_api.py: 16 个测试 (GET/PUT /profile + PUT /profile/password) - 三种角色资料获取、更新品牌方/代理商/达人字段、密码修改、权限校验 - test_messages_api.py: 16 个测试 (消息列表/分页/过滤/未读数/标记已读) - 分页、按类型/已读状态过滤、用户隔离、全部已读 32 个新测试全部通过 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
316 lines
11 KiB
Python
316 lines
11 KiB
Python
"""
|
||
消息 API 测试
|
||
覆盖: GET /messages, GET /messages/unread-count, PUT /messages/{id}/read, PUT /messages/read-all
|
||
"""
|
||
import pytest
|
||
from httpx import AsyncClient
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.services.message_service import create_message
|
||
|
||
|
||
# Base path of the message endpoints exercised by every test class below.
MESSAGES_URL: str = "/api/v1/messages"
# Registration endpoint; _register posts here to create a fresh account per test.
REGISTER_URL: str = "/api/v1/auth/register"
|
||
|
||
|
||
def _auth(token: str) -> dict:
|
||
return {"Authorization": f"Bearer {token}"}
|
||
|
||
|
||
async def _register(client: AsyncClient, role: str, name: str, email: str) -> dict:
    """Register an account through the public API and return the decoded JSON body.

    All accounts share the same password and the fixed email code "000000"
    (presumably accepted by the test configuration -- TODO confirm).
    The calling test fails, showing the response text, unless the endpoint answers 201.
    """
    payload = {
        "email": email,
        "password": "Test1234!",
        "name": name,
        "role": role,
        "email_code": "000000",
    }
    response = await client.post(REGISTER_URL, json=payload)
    assert response.status_code == 201, response.text
    return response.json()
|
||
|
||
|
||
async def _seed_messages(
    db: AsyncSession,
    user_id: str,
    count: int = 5,
    msg_type: str = "system",
    is_read: bool = False,
) -> list:
    """Create ``count`` messages for ``user_id`` directly in the database, bypassing the API.

    Each message gets a numbered title/content, a ``TK1000xx`` related task id and
    the sender name "系统". When ``is_read`` is true, the flag is set on each returned
    object after creation. A single commit at the end persists everything.

    Returns the created message objects in creation order. Callers read attributes such
    as ``.id`` after the commit. NOTE(review): that presumably relies on the test session
    not expiring attributes on commit -- TODO confirm against the fixture.
    """
    created = []
    for seq in range(1, count + 1):
        message = await create_message(
            db=db,
            user_id=user_id,
            type=msg_type,
            title=f"测试消息 {seq}",
            content=f"消息内容 {seq}",
            related_task_id=f"TK{100000 + seq - 1}",
            sender_name="系统",
        )
        if is_read:
            # Flag is flipped on the returned object; the commit below writes it out.
            message.is_read = True
        created.append(message)
    await db.commit()
    return created
|
||
|
||
|
||
# ==================== GET /messages ====================
|
||
|
||
|
||
class TestGetMessages:
    """Tests for the message list endpoint (GET /messages)."""

    @pytest.mark.asyncio
    async def test_empty_messages(self, client: AsyncClient):
        """A freshly registered user gets an empty first page."""
        data = await _register(client, "brand", "空消息", "empty-msg@test.com")
        token = data["access_token"]

        resp = await client.get(MESSAGES_URL, headers=_auth(token))
        assert resp.status_code == 200

        body = resp.json()
        assert body["items"] == []
        assert body["total"] == 0
        assert body["page"] == 1

    @pytest.mark.asyncio
    async def test_list_messages(self, client: AsyncClient, test_db_session: AsyncSession):
        """Seeded messages are listed, and each item exposes the core fields."""
        data = await _register(client, "brand", "消息列表", "list-msg@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]

        await _seed_messages(test_db_session, user_id, count=3)

        resp = await client.get(MESSAGES_URL, headers=_auth(token))
        assert resp.status_code == 200

        body = resp.json()
        assert body["total"] == 3
        assert len(body["items"]) == 3

        # Verify the shape of a single message item.
        msg = body["items"][0]
        for field in ("id", "type", "title", "content", "is_read"):
            assert field in msg, field

    @pytest.mark.asyncio
    async def test_pagination(self, client: AsyncClient, test_db_session: AsyncSession):
        """15 messages split 10 + 5 across two pages with no overlap."""
        data = await _register(client, "brand", "分页测试", "page-msg@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]

        await _seed_messages(test_db_session, user_id, count=15)

        # Page 1
        resp1 = await client.get(
            MESSAGES_URL, params={"page": 1, "page_size": 10}, headers=_auth(token),
        )
        assert resp1.status_code == 200
        body1 = resp1.json()
        assert len(body1["items"]) == 10
        assert body1["total"] == 15
        assert body1["page"] == 1

        # Page 2 holds the remaining 5 messages.
        resp2 = await client.get(
            MESSAGES_URL, params={"page": 2, "page_size": 10}, headers=_auth(token),
        )
        assert resp2.status_code == 200
        body2 = resp2.json()
        assert len(body2["items"]) == 5
        assert body2["total"] == 15
        assert body2["page"] == 2

        # The two pages together must cover all 15 messages exactly once.
        ids1 = {m["id"] for m in body1["items"]}
        ids2 = {m["id"] for m in body2["items"]}
        assert ids1.isdisjoint(ids2)
        assert len(ids1 | ids2) == 15

    @pytest.mark.asyncio
    async def test_filter_by_read_status(self, client: AsyncClient, test_db_session: AsyncSession):
        """The ``is_read`` query parameter returns only matching messages, in both directions."""
        data = await _register(client, "brand", "已读过滤", "read-filter@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]

        await _seed_messages(test_db_session, user_id, count=3, is_read=False)
        await _seed_messages(test_db_session, user_id, count=2, is_read=True)

        # Unread only
        resp = await client.get(
            MESSAGES_URL, params={"is_read": False}, headers=_auth(token),
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["total"] == 3
        for m in body["items"]:
            assert m["is_read"] is False

        # Read only -- verify the items themselves, not just the total.
        resp2 = await client.get(
            MESSAGES_URL, params={"is_read": True}, headers=_auth(token),
        )
        assert resp2.status_code == 200
        body2 = resp2.json()
        assert body2["total"] == 2
        for m in body2["items"]:
            assert m["is_read"] is True

    @pytest.mark.asyncio
    async def test_filter_by_type(self, client: AsyncClient, test_db_session: AsyncSession):
        """The ``type`` query parameter returns only messages of that type."""
        data = await _register(client, "brand", "类型过滤", "type-filter@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]

        await _seed_messages(test_db_session, user_id, count=3, msg_type="new_task")
        await _seed_messages(test_db_session, user_id, count=2, msg_type="system")

        resp = await client.get(
            MESSAGES_URL, params={"type": "new_task"}, headers=_auth(token),
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["total"] == 3
        for m in body["items"]:
            assert m["type"] == "new_task"

    @pytest.mark.asyncio
    async def test_messages_isolation(self, client: AsyncClient, test_db_session: AsyncSession):
        """Users only see their own messages."""
        data_a = await _register(client, "brand", "用户A", "user-a@test.com")
        data_b = await _register(client, "agency", "用户B", "user-b@test.com")

        await _seed_messages(test_db_session, data_a["user"]["id"], count=3)
        await _seed_messages(test_db_session, data_b["user"]["id"], count=5)

        resp = await client.get(MESSAGES_URL, headers=_auth(data_a["access_token"]))
        assert resp.status_code == 200
        assert resp.json()["total"] == 3

        resp2 = await client.get(MESSAGES_URL, headers=_auth(data_b["access_token"]))
        assert resp2.status_code == 200
        assert resp2.json()["total"] == 5

    @pytest.mark.asyncio
    async def test_messages_unauthenticated(self, client: AsyncClient):
        """Requests without a token are rejected."""
        resp = await client.get(MESSAGES_URL)
        assert resp.status_code in (401, 403)
|
||
|
||
|
||
# ==================== GET /messages/unread-count ====================
|
||
|
||
|
||
class TestUnreadCount:
    """Tests for the unread-message counter (GET /messages/unread-count)."""

    @pytest.mark.asyncio
    async def test_unread_count_zero(self, client: AsyncClient):
        """A user with no messages reports zero unread."""
        registered = await _register(client, "brand", "零未读", "zero-unread@test.com")
        headers = _auth(registered["access_token"])

        response = await client.get(f"{MESSAGES_URL}/unread-count", headers=headers)

        assert response.status_code == 200
        assert response.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_unread_count(self, client: AsyncClient, test_db_session: AsyncSession):
        """Only unread messages are counted: 5 unread + 3 read yields 5."""
        registered = await _register(client, "brand", "未读计数", "unread-count@test.com")
        owner_id = registered["user"]["id"]
        headers = _auth(registered["access_token"])

        await _seed_messages(test_db_session, owner_id, count=5, is_read=False)
        await _seed_messages(test_db_session, owner_id, count=3, is_read=True)

        response = await client.get(f"{MESSAGES_URL}/unread-count", headers=headers)

        assert response.status_code == 200
        assert response.json()["count"] == 5
|
||
|
||
|
||
# ==================== PUT /messages/{id}/read ====================
|
||
|
||
|
||
class TestMarkAsRead:
    """Tests for marking a single message as read (PUT /messages/{id}/read)."""

    @pytest.mark.asyncio
    async def test_mark_as_read(self, client: AsyncClient, test_db_session: AsyncSession):
        """Marking the only unread message drops the unread count from 1 to 0."""
        data = await _register(client, "brand", "标记已读", "mark-read@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]

        msgs = await _seed_messages(test_db_session, user_id, count=1)
        msg_id = msgs[0].id

        # Establish the baseline so the "decrease" below is meaningful.
        before = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(token))
        assert before.json()["count"] == 1

        resp = await client.put(f"{MESSAGES_URL}/{msg_id}/read", headers=_auth(token))
        assert resp.status_code == 200

        # The unread count must have decreased.
        count_resp = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(token))
        assert count_resp.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_mark_nonexistent_message(self, client: AsyncClient):
        """An unknown message id yields 404."""
        data = await _register(client, "brand", "不存在", "nonexist-msg@test.com")
        token = data["access_token"]

        resp = await client.put(f"{MESSAGES_URL}/MSG999999/read", headers=_auth(token))
        assert resp.status_code == 404

    @pytest.mark.asyncio
    async def test_mark_other_users_message(self, client: AsyncClient, test_db_session: AsyncSession):
        """A user cannot mark someone else's message, and the attempt has no side effect."""
        data_a = await _register(client, "brand", "用户A标记", "mark-a@test.com")
        data_b = await _register(client, "agency", "用户B标记", "mark-b@test.com")

        msgs = await _seed_messages(test_db_session, data_a["user"]["id"], count=1)
        msg_id = msgs[0].id

        # User B tries to mark user A's message.
        resp = await client.put(
            f"{MESSAGES_URL}/{msg_id}/read",
            headers=_auth(data_b["access_token"]),
        )
        assert resp.status_code == 404  # other users' messages are invisible, hence 404

        # The rejected request must not have marked A's message as read.
        count_resp = await client.get(
            f"{MESSAGES_URL}/unread-count", headers=_auth(data_a["access_token"]),
        )
        assert count_resp.status_code == 200
        assert count_resp.json()["count"] == 1
|
||
|
||
|
||
# ==================== PUT /messages/read-all ====================
|
||
|
||
|
||
class TestMarkAllAsRead:
    """Tests for marking all of a user's messages as read (PUT /messages/read-all)."""

    @pytest.mark.asyncio
    async def test_mark_all_as_read(self, client: AsyncClient, test_db_session: AsyncSession):
        """read-all reports how many messages it updated and clears the unread count."""
        data = await _register(client, "brand", "全部已读", "all-read@test.com")
        token = data["access_token"]
        user_id = data["user"]["id"]

        await _seed_messages(test_db_session, user_id, count=5)

        resp = await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(token))
        assert resp.status_code == 200
        body = resp.json()
        assert body["count"] == 5

        # The unread count is now zero.
        count_resp = await client.get(f"{MESSAGES_URL}/unread-count", headers=_auth(token))
        assert count_resp.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_mark_all_no_messages(self, client: AsyncClient):
        """read-all with nothing to update succeeds and reports 0."""
        data = await _register(client, "brand", "无消息全读", "no-msg-all@test.com")
        token = data["access_token"]

        resp = await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(token))
        assert resp.status_code == 200
        assert resp.json()["count"] == 0

    @pytest.mark.asyncio
    async def test_mark_all_only_affects_own(self, client: AsyncClient, test_db_session: AsyncSession):
        """read-all only touches the caller's own messages."""
        data_a = await _register(client, "brand", "全读A", "all-own-a@test.com")
        data_b = await _register(client, "agency", "全读B", "all-own-b@test.com")

        await _seed_messages(test_db_session, data_a["user"]["id"], count=3)
        await _seed_messages(test_db_session, data_b["user"]["id"], count=4)

        # A marks everything as read. The call must actually succeed and cover exactly
        # A's 3 messages; otherwise the check on B below would pass vacuously.
        resp = await client.put(f"{MESSAGES_URL}/read-all", headers=_auth(data_a["access_token"]))
        assert resp.status_code == 200
        assert resp.json()["count"] == 3

        count_a = await client.get(
            f"{MESSAGES_URL}/unread-count", headers=_auth(data_a["access_token"]),
        )
        assert count_a.status_code == 200
        assert count_a.json()["count"] == 0

        # B's unread count is unaffected.
        count_resp = await client.get(
            f"{MESSAGES_URL}/unread-count", headers=_auth(data_b["access_token"]),
        )
        assert count_resp.status_code == 200
        assert count_resp.json()["count"] == 4
|