- 新增 .gitlab-ci.yml (lint/test/build 三阶段) - 新增前端测试: taskStageMapper (109), api.ts (36), AuthContext (16) - 修复旧测试: Sidebar 导航文案、MobileLayout padding 值 - python-jose → PyJWT 消除 ecdsa CVE 漏洞 - 限流中间件增加 5 个敏感端点精细限流 + 标准限流头 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
218 lines
5.8 KiB
Python
218 lines
5.8 KiB
Python
"""
|
|
认证服务
|
|
"""
|
|
from datetime import datetime, timedelta, timezone
from typing import Optional
import secrets

import jwt
from jwt.exceptions import PyJWTError as JWTError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.config import settings
from app.models.user import User, UserRole
from app.models.organization import Brand, Agency, Creator
|
|
|
|
# Password-hashing context shared by hash_password() and verify_password();
# bcrypt is the only configured scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: Candidate password supplied by the caller.
        hashed_password: Stored hash (e.g. one produced by ``hash_password``).

    Returns:
        The result of ``pwd_context.verify`` - True when the password matches.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Hash ``password`` with the module's bcrypt ``pwd_context`` for storage."""
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def generate_id(prefix: str, digits: int = 6) -> str:
    """Build a semantic ID: ``prefix`` followed by a random ``digits``-digit number.

    With the default width the format is e.g. ``BR123456``, ``AG123456``,
    ``CR123456``. The numeric part never starts with a zero, so it is always
    exactly ``digits`` characters long.

    The number is drawn from ``secrets`` (a CSPRNG), but uniqueness is NOT
    checked here: with 6 digits there are only 900,000 values per prefix, so
    callers that persist these IDs as keys should handle collisions or pass a
    larger ``digits``.

    Args:
        prefix: String prepended verbatim to the number.
        digits: Length of the numeric part; must be >= 1. The default of 6
            reproduces the original 100000-999999 range.

    Returns:
        The generated ID string.

    Raises:
        ValueError: If ``digits`` is less than 1.
    """
    if digits < 1:
        raise ValueError(f"digits must be >= 1, got {digits}")
    low = 10 ** (digits - 1)  # smallest number with exactly `digits` digits
    high = 10 ** digits       # exclusive upper bound
    return f"{prefix}{secrets.randbelow(high - low) + low}"
|
|
|
|
|
|
def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token for ``user_id``.

    Args:
        user_id: Value stored in the ``sub`` claim.
        expires_delta: Token lifetime; when omitted, defaults to
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded token carrying ``sub``, ``exp`` and ``type="access"``,
        signed with ``settings.SECRET_KEY`` using ``settings.ALGORITHM``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    # Timezone-aware "now" replaces datetime.utcnow(), which is deprecated
    # since Python 3.12; PyJWT converts an aware `exp` to the same UTC epoch
    # timestamp the naive value produced.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {
        "sub": user_id,
        "exp": expire,
        "type": "access",
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(user_id: str, expires_days: int = 7) -> tuple[str, datetime]:
    """Create a signed JWT refresh token for ``user_id``.

    Args:
        user_id: Value stored in the ``sub`` claim.
        expires_days: Token lifetime in days (default 7).

    Returns:
        ``(token, expire)``. ``token`` carries ``sub``, ``exp`` and
        ``type="refresh"``, signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``. ``expire`` is the expiry as a *naive* UTC
        datetime - the same representation as before, presumably persisted
        via ``update_refresh_token``.
    """
    # Aware UTC "now" replaces the deprecated datetime.utcnow().
    expire = datetime.now(timezone.utc) + timedelta(days=expires_days)
    to_encode = {
        "sub": user_id,
        "exp": expire,
        "type": "refresh",
    }
    token = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    # Drop tzinfo so callers keep receiving a naive UTC datetime; the column it
    # is stored in may be a naive timestamp type - TODO confirm before changing.
    return token, expire.replace(tzinfo=None)
|
|
|
|
|
|
def decode_token(token: str) -> Optional[dict]:
    """Verify and decode a JWT, returning its claims or ``None``.

    The signature is checked against ``settings.SECRET_KEY`` and only
    ``settings.ALGORITHM`` is accepted. Any PyJWT error (bad signature,
    expired token, malformed input, ...) yields ``None`` instead of raising.

    Note: the ``type`` claim (access vs. refresh) is not inspected here;
    callers must check it themselves.
    """
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None
|
|
|
|
|
|
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
    """Fetch the user whose email equals ``email`` exactly.

    Returns the ``User`` or ``None`` if no row matches (``scalar_one_or_none``
    semantics, which raise if several rows match).
    """
    query = select(User).where(User.email == email)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_user_by_phone(db: AsyncSession, phone: str) -> Optional[User]:
    """Fetch the user whose phone number equals ``phone`` exactly.

    Returns the ``User`` or ``None`` if no row matches (``scalar_one_or_none``
    semantics, which raise if several rows match).
    """
    query = select(User).where(User.phone == phone)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_user_by_id(db: AsyncSession, user_id: str) -> Optional[User]:
    """Fetch the user with primary key-style ``id`` equal to ``user_id``.

    Returns the ``User`` or ``None`` if no row matches (``scalar_one_or_none``
    semantics, which raise if several rows match).
    """
    query = select(User).where(User.id == user_id)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def create_user(
    db: AsyncSession,
    email: Optional[str],
    phone: Optional[str],
    password: str,
    name: str,
    role: UserRole,
    is_verified: bool = False,
) -> User:
    """Create an active user and, for BRAND/AGENCY/CREATOR roles, its organization.

    The user (and organization row, if any) are added to ``db`` and flushed;
    no commit is issued here.

    Args:
        db: Async session the new objects are added to.
        email: Optional email address.
        phone: Optional phone number.
        password: Plaintext password; only its hash is stored.
        name: Display name, reused as the organization name.
        role: User role; decides which organization model (if any) is created.
        is_verified: Initial verification flag (default False).

    Returns:
        The newly created, flushed ``User``.
    """
    # NOTE(review): generate_id does not check uniqueness; an ID collision
    # would presumably surface as an integrity error at flush time.
    new_id = generate_id("U")

    user = User(
        id=new_id,
        email=email,
        phone=phone,
        password_hash=hash_password(password),
        name=name,
        role=role,
        is_active=True,
        is_verified=is_verified,
    )
    db.add(user)

    # (role, organization model, ID prefix). Matched in order with `==`,
    # exactly like an if/elif chain; other roles get no organization row.
    org_specs = (
        (UserRole.BRAND, Brand, "BR"),
        (UserRole.AGENCY, Agency, "AG"),
        (UserRole.CREATOR, Creator, "CR"),
    )
    for org_role, org_model, id_prefix in org_specs:
        if role == org_role:
            db.add(org_model(id=generate_id(id_prefix), user_id=new_id, name=name))
            break

    await db.flush()
    return user
|
|
|
|
|
|
async def authenticate_user(
    db: AsyncSession,
    email: Optional[str] = None,
    phone: Optional[str] = None,
    password: Optional[str] = None,
) -> Optional[User]:
    """Look up a user by email (preferred) or phone and check the password.

    Only one lookup is performed: ``email`` wins when both are given.

    Args:
        db: Async session used for the lookup.
        email: Email to look up; takes precedence over ``phone``.
        phone: Phone number to look up when ``email`` is empty.
        password: Plaintext password to verify. ``None`` skips the password
            check (kept for backward compatibility - presumably used by flows
            that authenticate by other means; confirm with callers). An empty
            string is verified like any other value, so it can no longer
            bypass the check.

    Returns:
        The ``User`` on success, otherwise ``None``.
    """
    user = None
    if email:
        user = await get_user_by_email(db, email)
    elif phone:
        user = await get_user_by_phone(db, phone)

    if user is None:
        if password is not None:
            # Spend roughly one hash verification so an unknown account takes
            # about as long as a wrong password, limiting account enumeration
            # through response timing.
            pwd_context.dummy_verify()
        return None

    # `is not None` rather than truthiness: previously password="" skipped
    # verification entirely and returned the user.
    if password is not None and not verify_password(password, user.password_hash):
        return None

    return user
|
|
|
|
|
|
async def update_refresh_token(db: AsyncSession, user: User, refresh_token: str, expires_at: datetime) -> None:
    """Store a newly issued refresh token on ``user`` and stamp the login time.

    Sets ``refresh_token``, ``refresh_token_expires_at`` and ``last_login_at``
    on ``user``, then flushes the session (no commit here).

    Args:
        db: Async session that ``user`` belongs to.
        user: User to update.
        refresh_token: Encoded refresh token to store.
        expires_at: Expiry of ``refresh_token``.
    """
    user.refresh_token = refresh_token
    user.refresh_token_expires_at = expires_at
    # Naive UTC timestamp - the same value and representation datetime.utcnow()
    # produced - without the Python 3.12 deprecation warning.
    user.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
    await db.flush()
|
|
|
|
|
|
async def get_user_organization_info(db: AsyncSession, user: User) -> dict:
    """Collect the organization IDs linked to ``user`` according to its role.

    Returns a dict with keys ``brand_id``, ``agency_id``, ``creator_id``,
    ``tenant_id`` and ``tenant_name``, all ``None`` unless filled in:

    - BRAND: ``brand_id`` and the tenant fields come from the user's brand.
    - AGENCY: only ``agency_id``; an agency may serve several brands, so no
      tenant is set for now.
    - CREATOR: only ``creator_id``; a creator may serve several agencies, so
      no tenant is set for now.
    """
    info = dict.fromkeys(
        ("brand_id", "agency_id", "creator_id", "tenant_id", "tenant_name")
    )

    role = user.role
    if role == UserRole.BRAND:
        brand = (
            await db.execute(select(Brand).where(Brand.user_id == user.id))
        ).scalar_one_or_none()
        if brand:
            info.update(brand_id=brand.id, tenant_id=brand.id, tenant_name=brand.name)
    elif role == UserRole.AGENCY:
        agency = (
            await db.execute(select(Agency).where(Agency.user_id == user.id))
        ).scalar_one_or_none()
        if agency:
            info["agency_id"] = agency.id
    elif role == UserRole.CREATOR:
        creator = (
            await db.execute(select(Creator).where(Creator.user_id == user.id))
        ).scalar_one_or_none()
        if creator:
            info["creator_id"] = creator.id

    return info
|