- 新增 test_auth_api.py (48 tests): 注册/登录/刷新/退出全流程覆盖
- 新增 test_tasks_api.py (38 tests): 任务 CRUD/审核/申诉/权限控制
- 新增 AuditLog 模型 + log_action 审计服务
- 新增 logging_config.py 结构化日志配置
- 修复 task_service.py 缺少 Project.brand 嵌套加载导致的 MissingGreenlet 错误
- 修复 conftest.py: 添加限流清理 fixture, 防止测试间干扰
- 修复 TDD 红色阶段测试文件的 import 错误 (skip)
- auth.py 集成审计日志 (注册/登录/退出)
- 全部 211 tests passed, 2 skipped

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
854 lines · 32 KiB · Python
"""
|
||
认证 API 测试
|
||
测试覆盖: /api/v1/auth/register, /api/v1/auth/login, /api/v1/auth/refresh, /api/v1/auth/logout
|
||
使用 SQLite 内存数据库,通过 conftest.py 的 client fixture 注入测试数据库会话
|
||
"""
|
||
from datetime import timedelta
from typing import Optional

import pytest
from httpx import AsyncClient, Response

from app.middleware.rate_limit import RateLimitMiddleware
from app.services.auth import create_access_token, create_refresh_token
|
||
|
||
|
||
def _find_rate_limiter(app_or_middleware):
|
||
"""递归遍历中间件栈,找到 RateLimitMiddleware 实例"""
|
||
if isinstance(app_or_middleware, RateLimitMiddleware):
|
||
return app_or_middleware
|
||
inner = getattr(app_or_middleware, "app", None)
|
||
if inner is not None and inner is not app_or_middleware:
|
||
return _find_rate_limiter(inner)
|
||
return None
|
||
|
||
|
||
@pytest.fixture(autouse=True)
async def _reset_rate_limiter(client: AsyncClient):
    """Clear the rate limiter's in-memory counters before every test in this module.

    Without this, requests made by earlier tests accumulate in the limiter and
    later tests can start receiving 429 responses.

    The ``client`` fixture is requested so the test application and its
    dependency overrides are set up first. ``middleware_stack`` may still be
    ``None`` at this point (e.g. if no request has been served yet); then
    there is nothing to clear.

    NOTE(review): conftest.py may define an equivalent cleanup fixture; if so,
    this module-level copy is redundant — confirm and deduplicate.
    """
    from app.main import app as fastapi_app

    stack = fastapi_app.middleware_stack
    limiter = None if stack is None else _find_rate_limiter(stack)
    if limiter is not None:
        # Assumes ``requests`` is a clearable collection of per-client counters
        # (TODO confirm against RateLimitMiddleware).
        limiter.requests.clear()
    yield
|
||
|
||
|
||
# ==================== 辅助函数 ====================
|
||
|
||
|
||
async def register_user(
    client: AsyncClient,
    email: Optional[str] = "test@example.com",
    phone: Optional[str] = None,
    password: str = "Test1234!",
    name: str = "测试用户",
    role: str = "brand",
) -> Response:
    """Register a user via ``POST /api/v1/auth/register`` and return the raw response.

    ``email`` and ``phone`` are added to the payload only when they are not
    ``None``. Callers can therefore exercise email-only, phone-only, both, or
    neither.

    Args:
        client: HTTP client bound to the application under test.
        email: Email to register with; ``None`` omits the field.
        phone: Phone number to register with; ``None`` omits the field.
        password: Plain-text password.
        name: Display name.
        role: Role string sent verbatim (e.g. ``"brand"``, ``"agency"``, ``"creator"``).

    Returns:
        The unparsed HTTP response; callers assert on status code and body.
    """
    payload = {
        "password": password,
        "name": name,
        "role": role,
    }
    if email is not None:
        payload["email"] = email
    if phone is not None:
        payload["phone"] = phone
    return await client.post("/api/v1/auth/register", json=payload)
|
||
|
||
|
||
async def login_user(
    client: AsyncClient,
    email: Optional[str] = "test@example.com",
    phone: Optional[str] = None,
    password: str = "Test1234!",
) -> Response:
    """Log in via ``POST /api/v1/auth/login`` and return the raw response.

    ``email`` and ``phone`` are added to the payload only when they are not
    ``None``. This allows login by email, by phone, or with neither
    identifier (to test validation).

    Args:
        client: HTTP client bound to the application under test.
        email: Email identifier; ``None`` omits the field.
        phone: Phone identifier; ``None`` omits the field.
        password: Plain-text password.

    Returns:
        The unparsed HTTP response; callers assert on status code and body.
    """
    payload = {"password": password}
    if email is not None:
        payload["email"] = email
    if phone is not None:
        payload["phone"] = phone
    return await client.post("/api/v1/auth/login", json=payload)
|
||
|
||
|
||
async def register_and_get_tokens(
    client: AsyncClient,
    email: Optional[str] = "test@example.com",
    phone: Optional[str] = None,
    password: str = "Test1234!",
    name: str = "测试用户",
    role: str = "brand",
) -> dict:
    """Register a user and return the decoded JSON body (tokens plus user info).

    Arguments mirror ``register_user``. Registration is a precondition for the
    caller's test, so a non-201 response fails immediately. The assertion
    message includes the status code and body, so the real cause is visible.

    Returns:
        The parsed JSON response body of a successful registration.
    """
    resp = await register_user(
        client, email=email, phone=phone, password=password, name=name, role=role
    )
    assert resp.status_code == 201, f"registration failed: {resp.status_code} {resp.text}"
    return resp.json()
|
||
|
||
|
||
# ==================== 注册测试 ====================
|
||
|
||
|
||
@pytest.mark.asyncio
class TestRegister:
    """Tests for POST /api/v1/auth/register."""

    @staticmethod
    async def _post_raw(client: AsyncClient, body: dict):
        """Post ``body`` to the register endpoint verbatim, bypassing helper defaults."""
        return await client.post("/api/v1/auth/register", json=body)

    async def test_register_with_email_success(self, client: AsyncClient):
        """Email registration returns 201, a bearer token pair and the new user."""
        response = await register_user(client, email="user@example.com", role="brand")
        assert response.status_code == 201

        body = response.json()
        for field in ("access_token", "refresh_token"):
            assert field in body
        assert body["token_type"] == "bearer"
        assert "user" in body

        profile = body["user"]
        assert profile["email"] == "user@example.com"
        assert profile["name"] == "测试用户"
        assert profile["role"] == "brand"
        assert profile["is_verified"] is False

    async def test_register_with_phone_success(self, client: AsyncClient):
        """Phone-only registration returns 201 and leaves the email empty."""
        response = await register_user(client, email=None, phone="13800138000", role="creator")
        assert response.status_code == 201

        body = response.json()
        for field in ("access_token", "refresh_token"):
            assert field in body

        profile = body["user"]
        assert profile["phone"] == "13800138000"
        assert profile["email"] is None
        assert profile["role"] == "creator"

    async def test_register_with_both_email_and_phone(self, client: AsyncClient):
        """Supplying both email and phone stores both identifiers."""
        response = await register_user(
            client, email="both@example.com", phone="13900139000", role="agency"
        )
        assert response.status_code == 201

        profile = response.json()["user"]
        expected = {"email": "both@example.com", "phone": "13900139000", "role": "agency"}
        for key, value in expected.items():
            assert profile[key] == value

    async def test_register_missing_email_and_phone_returns_400(self, client: AsyncClient):
        """Omitting both email and phone yields 400 with a detail naming them."""
        response = await register_user(client, email=None, phone=None)
        assert response.status_code == 400

        body = response.json()
        assert "detail" in body
        assert any(word in body["detail"] for word in ("邮箱", "手机号"))

    async def test_register_duplicate_email_returns_400(self, client: AsyncClient):
        """Registering an already-used email a second time yields 400."""
        first = await register_user(client, email="dup@example.com")
        assert first.status_code == 201

        second = await register_user(client, email="dup@example.com", name="另一个用户")
        assert second.status_code == 400
        assert "已被注册" in second.json()["detail"]

    async def test_register_duplicate_phone_returns_400(self, client: AsyncClient):
        """Registering an already-used phone number a second time yields 400."""
        first = await register_user(client, email=None, phone="13800000001")
        assert first.status_code == 201

        second = await register_user(client, email=None, phone="13800000001", name="另一个用户")
        assert second.status_code == 400
        assert "已被注册" in second.json()["detail"]

    async def test_register_password_too_short_returns_422(self, client: AsyncClient):
        """A too-short password (< 6 characters) fails request validation (422)."""
        response = await register_user(client, email="short@example.com", password="123")
        assert response.status_code == 422

    async def test_register_missing_password_returns_422(self, client: AsyncClient):
        """A payload without a password field fails validation (422)."""
        response = await self._post_raw(
            client, {"email": "nopwd@example.com", "name": "测试", "role": "brand"}
        )
        assert response.status_code == 422

    async def test_register_missing_name_returns_422(self, client: AsyncClient):
        """A payload without a name field fails validation (422)."""
        response = await self._post_raw(
            client, {"email": "noname@example.com", "password": "Test1234!", "role": "brand"}
        )
        assert response.status_code == 422

    async def test_register_missing_role_returns_422(self, client: AsyncClient):
        """A payload without a role field fails validation (422)."""
        response = await self._post_raw(
            client, {"email": "norole@example.com", "password": "Test1234!", "name": "测试用户"}
        )
        assert response.status_code == 422

    async def test_register_invalid_role_returns_422(self, client: AsyncClient):
        """A role outside the accepted set fails validation (422)."""
        response = await register_user(client, email="badrole@example.com", role="admin")
        assert response.status_code == 422

    async def test_register_invalid_email_format_returns_422(self, client: AsyncClient):
        """A malformed email fails validation (422)."""
        response = await register_user(client, email="not-an-email")
        assert response.status_code == 422

    async def test_register_invalid_phone_format_returns_422(self, client: AsyncClient):
        """A phone number not matching ``^1[3-9]\\d{9}$`` fails validation (422)."""
        response = await register_user(client, email=None, phone="12345")
        assert response.status_code == 422

    async def test_register_response_contains_user_id(self, client: AsyncClient):
        """The registered user's id is returned and starts with "U"."""
        response = await register_user(client, email="uid@example.com")
        assert response.status_code == 201
        assert response.json()["user"]["id"].startswith("U")

    async def test_register_brand_creates_brand_entity(self, client: AsyncClient):
        """Brand registration creates a Brand and makes it the user's tenant."""
        response = await register_user(client, email="brand@example.com", role="brand")
        assert response.status_code == 201

        profile = response.json()["user"]
        brand_id = profile["brand_id"]
        assert brand_id is not None
        assert brand_id.startswith("BR")
        # A brand user is its own tenant.
        assert profile["tenant_id"] == brand_id
        assert profile["tenant_name"] == "测试用户"

    async def test_register_agency_creates_agency_entity(self, client: AsyncClient):
        """Agency registration creates an Agency and returns its id."""
        response = await register_user(client, email="agency@example.com", role="agency")
        assert response.status_code == 201

        agency_id = response.json()["user"]["agency_id"]
        assert agency_id is not None
        assert agency_id.startswith("AG")

    async def test_register_creator_creates_creator_entity(self, client: AsyncClient):
        """Creator registration creates a Creator and returns its id."""
        response = await register_user(client, email="creator@example.com", role="creator")
        assert response.status_code == 201

        creator_id = response.json()["user"]["creator_id"]
        assert creator_id is not None
        assert creator_id.startswith("CR")

    async def test_register_tokens_are_valid_jwt(self, client: AsyncClient):
        """Both returned tokens decode and carry the expected type and claims."""
        from app.services.auth import decode_token

        response = await register_user(client, email="jwt@example.com")
        assert response.status_code == 201
        body = response.json()

        access_claims = decode_token(body["access_token"])
        assert access_claims is not None
        assert access_claims["type"] == "access"
        for claim in ("sub", "exp"):
            assert claim in access_claims

        refresh_claims = decode_token(body["refresh_token"])
        assert refresh_claims is not None
        assert refresh_claims["type"] == "refresh"
        assert "sub" in refresh_claims

    async def test_register_expires_in_field(self, client: AsyncClient):
        """The response carries a positive integer ``expires_in`` (seconds)."""
        response = await register_user(client, email="expiry@example.com")
        assert response.status_code == 201

        body = response.json()
        assert "expires_in" in body
        ttl = body["expires_in"]
        assert isinstance(ttl, int)
        assert ttl > 0
|
||
|
||
|
||
# ==================== 登录测试 ====================
|
||
|
||
|
||
class TestLogin:
    """Tests for POST /api/v1/auth/login."""

    @pytest.mark.asyncio
    async def test_login_with_email_success(self, client: AsyncClient):
        """Email + password login succeeds and returns tokens plus the user."""
        # Precondition: the account must exist, otherwise a setup failure would
        # be reported as a login failure.
        reg_resp = await register_user(client, email="login@example.com", password="Test1234!")
        assert reg_resp.status_code == 201

        resp = await login_user(client, email="login@example.com", password="Test1234!")
        assert resp.status_code == 200

        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"
        assert "user" in data
        assert data["user"]["email"] == "login@example.com"

    @pytest.mark.asyncio
    async def test_login_with_phone_success(self, client: AsyncClient):
        """Phone + password login succeeds."""
        reg_resp = await register_user(
            client, email=None, phone="13800138001", password="Test1234!"
        )
        assert reg_resp.status_code == 201

        resp = await login_user(client, email=None, phone="13800138001", password="Test1234!")
        assert resp.status_code == 200

        data = resp.json()
        assert "access_token" in data
        assert data["user"]["phone"] == "13800138001"

    @pytest.mark.asyncio
    async def test_login_wrong_password_returns_401(self, client: AsyncClient):
        """A wrong password for an existing user returns 401."""
        # Without this precondition the test would also pass if registration
        # failed, because logging in as an unknown user returns 401 as well.
        reg_resp = await register_user(
            client, email="wrongpwd@example.com", password="CorrectPwd!"
        )
        assert reg_resp.status_code == 201

        resp = await login_user(client, email="wrongpwd@example.com", password="WrongPassword!")
        assert resp.status_code == 401

        data = resp.json()
        assert "detail" in data

    @pytest.mark.asyncio
    async def test_login_nonexistent_user_returns_401(self, client: AsyncClient):
        """Logging in as an unknown user returns 401."""
        resp = await login_user(client, email="nobody@example.com", password="Test1234!")
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_login_missing_email_and_phone_returns_400(self, client: AsyncClient):
        """Logging in without email or phone returns 400."""
        payload = {"password": "Test1234!"}
        resp = await client.post("/api/v1/auth/login", json=payload)
        assert resp.status_code == 400

        data = resp.json()
        assert "邮箱" in data["detail"] or "手机号" in data["detail"]

    @pytest.mark.asyncio
    async def test_login_missing_password_returns_400(self, client: AsyncClient):
        """Logging in without a password returns 400."""
        payload = {"email": "test@example.com"}
        resp = await client.post("/api/v1/auth/login", json=payload)
        assert resp.status_code == 400

        data = resp.json()
        assert "密码" in data["detail"]

    @pytest.mark.asyncio
    async def test_login_disabled_user_returns_403(self, client: AsyncClient):
        """A disabled (is_active=False) user gets 403 on login."""
        reg_resp = await register_user(client, email="disabled@example.com")
        assert reg_resp.status_code == 201

        # Disable the user directly in the test database.
        from app.models.user import User
        from sqlalchemy import update

        from app.database import get_db
        from app.main import app as fastapi_app

        # Reuse the session factory the client fixture installed as the get_db
        # override (presumably an async generator yielding one session).
        override_func = fastapi_app.dependency_overrides[get_db]
        async for db_session in override_func():
            stmt = update(User).where(User.email == "disabled@example.com").values(is_active=False)
            await db_session.execute(stmt)
            await db_session.commit()

        resp = await login_user(client, email="disabled@example.com", password="Test1234!")
        assert resp.status_code == 403

        data = resp.json()
        assert "禁用" in data["detail"]

    @pytest.mark.asyncio
    async def test_login_response_contains_user_info(self, client: AsyncClient):
        """The login response includes the user's full profile."""
        reg_resp = await register_user(
            client,
            email="fullinfo@example.com",
            password="Test1234!",
            name="完整信息",
            role="brand",
        )
        assert reg_resp.status_code == 201

        resp = await login_user(client, email="fullinfo@example.com", password="Test1234!")
        assert resp.status_code == 200

        user = resp.json()["user"]
        assert "id" in user
        assert user["email"] == "fullinfo@example.com"
        assert user["name"] == "完整信息"
        assert user["role"] == "brand"
        assert "is_verified" in user
        assert "brand_id" in user

    @pytest.mark.asyncio
    async def test_login_returns_valid_tokens_each_time(self, client: AsyncClient):
        """Each login returns a valid access token, and the latest refresh token works."""
        from app.services.auth import decode_token

        reg_resp = await register_user(client, email="fresh@example.com", password="Test1234!")
        assert reg_resp.status_code == 201

        resp1 = await login_user(client, email="fresh@example.com", password="Test1234!")
        resp2 = await login_user(client, email="fresh@example.com", password="Test1234!")

        assert resp1.status_code == 200
        assert resp2.status_code == 200

        data1 = resp1.json()
        data2 = resp2.json()

        # Both logins return a decodable access token for the same user.
        payload1 = decode_token(data1["access_token"])
        payload2 = decode_token(data2["access_token"])
        assert payload1 is not None
        assert payload2 is not None
        assert payload1["type"] == "access"
        assert payload2["type"] == "access"
        assert payload1["sub"] == payload2["sub"]  # same user

        # The refresh token from the most recent login must be accepted.
        # Whether the first login's token was invalidated is NOT checked here:
        # two logins within the same second may yield identical tokens.
        refresh_resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": data2["refresh_token"]},
        )
        assert refresh_resp.status_code == 200

    @pytest.mark.asyncio
    async def test_login_empty_body_returns_400(self, client: AsyncClient):
        """An empty JSON body returns 400."""
        resp = await client.post("/api/v1/auth/login", json={})
        assert resp.status_code == 400
|
||
|
||
|
||
# ==================== Token 刷新测试 ====================
|
||
|
||
|
||
class TestRefreshToken:
    """Tests for POST /api/v1/auth/refresh."""

    @pytest.mark.asyncio
    async def test_refresh_with_valid_token_success(self, client: AsyncClient):
        """A valid refresh token yields a new bearer access token."""
        reg_data = await register_and_get_tokens(client, email="refresh@example.com")
        refresh_token = reg_data["refresh_token"]

        resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert resp.status_code == 200

        data = resp.json()
        assert "access_token" in data
        assert data["token_type"] == "bearer"
        assert data["expires_in"] > 0

    @pytest.mark.asyncio
    async def test_refresh_returns_valid_access_token(self, client: AsyncClient):
        """The refreshed access token decodes and belongs to the same user."""
        from app.services.auth import decode_token

        reg_data = await register_and_get_tokens(client, email="validaccess@example.com")
        refresh_token = reg_data["refresh_token"]

        resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert resp.status_code == 200

        new_access_token = resp.json()["access_token"]
        payload = decode_token(new_access_token)
        assert payload is not None
        assert payload["type"] == "access"
        assert payload["sub"] == reg_data["user"]["id"]

    @pytest.mark.asyncio
    async def test_refresh_with_invalid_token_returns_401(self, client: AsyncClient):
        """A string that is not a valid JWT returns 401."""
        resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": "this-is-not-a-valid-jwt-token"},
        )
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_refresh_with_access_token_returns_401(self, client: AsyncClient):
        """Passing an access token where a refresh token is expected returns 401."""
        reg_data = await register_and_get_tokens(client, email="wrongtype@example.com")
        access_token = reg_data["access_token"]

        resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": access_token},
        )
        assert resp.status_code == 401

        data = resp.json()
        assert "token" in data["detail"].lower() or "类型" in data["detail"]

    @pytest.mark.asyncio
    async def test_refresh_with_expired_token_returns_401(self, client: AsyncClient):
        """An expired refresh token returns 401."""
        # Register to obtain a real user id.
        reg_data = await register_and_get_tokens(client, email="expired@example.com")
        user_id = reg_data["user"]["id"]

        # Mint a refresh token that is already expired (negative lifetime).
        expired_token, _ = create_refresh_token(user_id, expires_days=-1)

        resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": expired_token},
        )
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_refresh_with_revoked_token_returns_401(self, client: AsyncClient):
        """A refresh token revoked by logout returns 401."""
        reg_data = await register_and_get_tokens(client, email="revoked@example.com")
        refresh_token = reg_data["refresh_token"]
        access_token = reg_data["access_token"]

        # Log out to revoke the refresh token.
        logout_resp = await client.post(
            "/api/v1/auth/logout",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        # Precondition: if logout itself failed, the assertion below would
        # misreport the problem as a refresh bug.
        assert logout_resp.status_code == 200

        resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_refresh_with_old_token_after_logout_and_relogin_returns_401(self, client: AsyncClient):
        """After logout and a fresh login, the pre-logout refresh token is rejected.

        JWT payloads here are deterministic (sub + exp). Two tokens minted in
        the same second can therefore be byte-identical. Logging out (clearing
        the server-side refresh token) and logging in again (issuing a new one)
        makes the old and new tokens differ in most runs. If they are still
        identical, only the new token is checked.
        """
        reg_data = await register_and_get_tokens(client, email="relogin@example.com")
        old_refresh_token = reg_data["refresh_token"]
        access_token = reg_data["access_token"]

        # Log out, clearing the server-side refresh token.
        logout_resp = await client.post(
            "/api/v1/auth/logout",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert logout_resp.status_code == 200

        # Log in again to obtain a new refresh token.
        login_resp = await login_user(client, email="relogin@example.com", password="Test1234!")
        assert login_resp.status_code == 200
        new_refresh_token = login_resp.json()["refresh_token"]

        # The new refresh token works.
        resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": new_refresh_token},
        )
        assert resp.status_code == 200

        # The old token can only be checked when it differs from the new one
        # (see the docstring about same-second token collisions).
        if old_refresh_token != new_refresh_token:
            resp2 = await client.post(
                "/api/v1/auth/refresh",
                json={"refresh_token": old_refresh_token},
            )
            assert resp2.status_code == 401

    @pytest.mark.asyncio
    async def test_refresh_missing_token_returns_422(self, client: AsyncClient):
        """A body without ``refresh_token`` fails validation (422)."""
        resp = await client.post("/api/v1/auth/refresh", json={})
        assert resp.status_code == 422

    @pytest.mark.asyncio
    async def test_refresh_disabled_user_returns_403(self, client: AsyncClient):
        """A disabled user's refresh token returns 403."""
        reg_data = await register_and_get_tokens(client, email="disabled_refresh@example.com")
        refresh_token = reg_data["refresh_token"]

        # Disable the user directly in the test database.
        from app.models.user import User
        from sqlalchemy import update
        from app.database import get_db
        from app.main import app as fastapi_app

        # Reuse the get_db override installed by the client fixture
        # (presumably an async generator yielding one session).
        override_func = fastapi_app.dependency_overrides[get_db]
        async for db_session in override_func():
            stmt = update(User).where(User.email == "disabled_refresh@example.com").values(is_active=False)
            await db_session.execute(stmt)
            await db_session.commit()

        resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert resp.status_code == 403
|
||
|
||
|
||
# ==================== 退出登录测试 ====================
|
||
|
||
|
||
@pytest.mark.asyncio
class TestLogout:
    """Tests for POST /api/v1/auth/logout."""

    @staticmethod
    async def _logout_with(client: AsyncClient, token: str):
        """POST to the logout endpoint using ``token`` as the Bearer credential."""
        return await client.post(
            "/api/v1/auth/logout",
            headers={"Authorization": f"Bearer {token}"},
        )

    async def test_logout_success(self, client: AsyncClient):
        """An authenticated user can log out; the response carries a message."""
        tokens = await register_and_get_tokens(client, email="logout@example.com")

        outcome = await self._logout_with(client, tokens["access_token"])
        assert outcome.status_code == 200
        assert "message" in outcome.json()

    async def test_logout_clears_refresh_token(self, client: AsyncClient):
        """After logout, the refresh token issued at registration is rejected."""
        tokens = await register_and_get_tokens(client, email="cleartoken@example.com")

        outcome = await self._logout_with(client, tokens["access_token"])
        assert outcome.status_code == 200

        retry = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": tokens["refresh_token"]},
        )
        assert retry.status_code == 401

    async def test_logout_without_auth_returns_401(self, client: AsyncClient):
        """Logging out without any credentials is rejected.

        Either 401 or 403 is accepted. Presumably the bearer-auth dependency's
        status for a missing header depends on the framework version (TODO
        confirm and pin one).
        """
        outcome = await client.post("/api/v1/auth/logout")
        assert outcome.status_code in (401, 403)

    async def test_logout_with_invalid_token_returns_401(self, client: AsyncClient):
        """A malformed bearer token returns 401."""
        outcome = await self._logout_with(client, "invalid-token-here")
        assert outcome.status_code == 401

    async def test_logout_with_refresh_token_as_bearer_returns_401(self, client: AsyncClient):
        """Using a refresh token (not an access token) as Bearer returns 401."""
        tokens = await register_and_get_tokens(client, email="wrongbearer@example.com")

        outcome = await self._logout_with(client, tokens["refresh_token"])
        assert outcome.status_code == 401

    async def test_logout_idempotent(self, client: AsyncClient):
        """Logging out twice with the same access token succeeds both times.

        The access token is a stateless JWT that stays valid until it expires.
        Logout only clears the server-side refresh token, so the second logout
        merely clears an already-empty value.
        """
        tokens = await register_and_get_tokens(client, email="idempotent@example.com")

        for _attempt in range(2):
            outcome = await self._logout_with(client, tokens["access_token"])
            assert outcome.status_code == 200
|
||
|
||
|
||
# ==================== 端到端流程测试 ====================
|
||
|
||
|
||
class TestAuthEndToEnd:
    """End-to-end authentication flow tests."""

    @pytest.mark.asyncio
    async def test_full_auth_flow_register_login_refresh_logout(self, client: AsyncClient):
        """Full flow: register -> login -> refresh -> logout."""
        # 1. Register
        reg_resp = await register_user(
            client, email="e2e@example.com", password="E2EPass1!", name="端到端测试", role="brand"
        )
        assert reg_resp.status_code == 201
        reg_data = reg_resp.json()
        assert reg_data["user"]["email"] == "e2e@example.com"

        # 2. Log in
        login_resp = await login_user(client, email="e2e@example.com", password="E2EPass1!")
        assert login_resp.status_code == 200
        login_data = login_resp.json()
        access_token = login_data["access_token"]
        refresh_token = login_data["refresh_token"]

        # 3. Refresh the access token
        refresh_resp = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert refresh_resp.status_code == 200
        new_access_token = refresh_resp.json()["access_token"]

        # The refreshed access token must decode as an access token.
        from app.services.auth import decode_token

        new_payload = decode_token(new_access_token)
        assert new_payload is not None
        assert new_payload["type"] == "access"

        # 4. Log out using the refreshed access token
        logout_resp = await client.post(
            "/api/v1/auth/logout",
            headers={"Authorization": f"Bearer {new_access_token}"},
        )
        assert logout_resp.status_code == 200

        # 5. After logout, the refresh token from step 2 is rejected
        refresh_after_logout = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert refresh_after_logout.status_code == 401

    @pytest.mark.asyncio
    async def test_multiple_users_isolated(self, client: AsyncClient):
        """Registering several users yields distinct ids and preserves each role."""
        resp1 = await register_user(client, email="user1@example.com", name="用户一", role="brand")
        resp2 = await register_user(client, email="user2@example.com", name="用户二", role="agency")
        resp3 = await register_user(client, email=None, phone="13700137001", name="用户三", role="creator")

        assert resp1.status_code == 201
        assert resp2.status_code == 201
        assert resp3.status_code == 201

        user1 = resp1.json()["user"]
        user2 = resp2.json()["user"]
        user3 = resp3.json()["user"]

        # A chained `a != b != c` never compares a with c. A set of the three
        # ids has size 3 only if they are pairwise distinct.
        assert len({user1["id"], user2["id"], user3["id"]}) == 3
        assert user1["role"] == "brand"
        assert user2["role"] == "agency"
        assert user3["role"] == "creator"

    @pytest.mark.asyncio
    async def test_access_token_works_for_authenticated_endpoint(self, client: AsyncClient):
        """The access token from registration is accepted by a protected endpoint."""
        reg_data = await register_and_get_tokens(client, email="protected@example.com")
        access_token = reg_data["access_token"]

        # logout requires authentication, so it serves as the protected endpoint.
        resp = await client.post(
            "/api/v1/auth/logout",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert resp.status_code == 200

    @pytest.mark.asyncio
    async def test_login_after_logout_succeeds(self, client: AsyncClient):
        """A user can log in again after logging out."""
        reg_data = await register_and_get_tokens(client, email="reauth@example.com")
        access_token = reg_data["access_token"]

        logout_resp = await client.post(
            "/api/v1/auth/logout",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert logout_resp.status_code == 200

        login_resp = await login_user(client, email="reauth@example.com", password="Test1234!")
        assert login_resp.status_code == 200
        assert "access_token" in login_resp.json()
        assert "refresh_token" in login_resp.json()
|