"""Authentication service.

Password hashing, JWT access/refresh token helpers, user lookup and
creation, and resolution of the organization entity linked to a user.
"""
from datetime import datetime, timedelta
from typing import Optional
import secrets

from jose import jwt, JWTError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.config import settings
from app.models.user import User, UserRole
from app.models.organization import Brand, Agency, Creator

# Password hashing context (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# How many random candidates _generate_unique_id tries before giving up.
_ID_MAX_ATTEMPTS = 10


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if ``plain_password`` matches ``hashed_password``."""
    return pwd_context.verify(plain_password, hashed_password)


def hash_password(password: str) -> str:
    """Return the bcrypt hash of ``password``."""
    return pwd_context.hash(password)


def generate_id(prefix: str, digits: int = 6) -> str:
    """Generate a semantic ID: ``prefix`` followed by ``digits`` random digits.

    With the default ``digits=6`` the format is e.g. ``BR123456``,
    ``AG123456`` or ``CR123456``. The numeric part never starts with 0, so
    every ID for a given prefix and digit count has the same length.

    This function does not check for collisions; see ``_generate_unique_id``.

    Raises:
        ValueError: if ``digits`` is less than 1.
    """
    if digits < 1:
        raise ValueError("digits must be >= 1")
    low = 10 ** (digits - 1)
    # secrets (not random) so IDs are not predictable.
    random_part = secrets.randbelow(9 * low) + low  # low .. 10**digits - 1
    return f"{prefix}{random_part}"


async def _generate_unique_id(db: AsyncSession, model: type, prefix: str) -> str:
    """Return a ``generate_id(prefix)`` value not yet used in ``model.id``.

    With 900,000 possible values per prefix, a collision between two random
    IDs becomes likely once a table holds on the order of a thousand rows
    (birthday bound). Without this check such a collision would only show
    up as a constraint violation on flush.

    The check is not race-free: a concurrent insert can still take the same
    ID, so the database constraint on ``id`` (presumably the primary key)
    remains the final guard.

    Raises:
        RuntimeError: if no free ID is found within ``_ID_MAX_ATTEMPTS`` tries.
    """
    for _ in range(_ID_MAX_ATTEMPTS):
        candidate = generate_id(prefix)
        result = await db.execute(select(model.id).where(model.id == candidate))
        if result.scalar_one_or_none() is None:
            return candidate
    raise RuntimeError(
        f"could not generate a unique {prefix} ID after {_ID_MAX_ATTEMPTS} attempts"
    )


def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        user_id: Stored in the ``sub`` claim.
        expires_delta: Token lifetime; defaults to
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded token, whose ``type`` claim is ``"access"``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # Naive datetime in UTC.
    expire = datetime.utcnow() + expires_delta
    to_encode = {
        "sub": user_id,
        "exp": expire,
        "type": "access",
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def create_refresh_token(user_id: str, expires_days: int = 7) -> tuple[str, datetime]:
    """Create a signed JWT refresh token.

    Args:
        user_id: Stored in the ``sub`` claim.
        expires_days: Token lifetime in days.

    Returns:
        ``(token, expire)`` where ``token`` has ``type`` claim ``"refresh"``
        and ``expire`` is the naive-UTC expiry time that was encoded.
    """
    expire = datetime.utcnow() + timedelta(days=expires_days)
    to_encode = {
        "sub": user_id,
        "exp": expire,
        "type": "refresh",
    }
    token = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    return token, expire


def decode_token(token: str) -> Optional[dict]:
    """Decode and verify ``token`` with the configured key and algorithm.

    Returns:
        The payload dict, or None if python-jose raises ``JWTError``
        (e.g. bad signature or expired token).

    Note: the ``type`` claim is not checked here; callers must distinguish
    access tokens from refresh tokens themselves.
    """
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        return payload
    except JWTError:
        return None


async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
    """Return the user with this email, or None."""
    result = await db.execute(
        select(User).where(User.email == email)
    )
    return result.scalar_one_or_none()


async def get_user_by_phone(db: AsyncSession, phone: str) -> Optional[User]:
    """Return the user with this phone number, or None."""
    result = await db.execute(
        select(User).where(User.phone == phone)
    )
    return result.scalar_one_or_none()


async def get_user_by_id(db: AsyncSession, user_id: str) -> Optional[User]:
    """Return the user with this ID, or None."""
    result = await db.execute(
        select(User).where(User.id == user_id)
    )
    return result.scalar_one_or_none()


async def create_user(
    db: AsyncSession,
    email: Optional[str],
    phone: Optional[str],
    password: str,
    name: str,
    role: UserRole,
) -> User:
    """Create a user and, for brand/agency/creator roles, its organization row.

    The user is created active and unverified with a hashed password. For
    ``UserRole.BRAND``, ``AGENCY`` or ``CREATOR``, a matching ``Brand``,
    ``Agency`` or ``Creator`` row is created. That row is linked via
    ``user_id`` and reuses ``name``. Both are added to the session and
    flushed; nothing is committed.

    Returns:
        The new ``User`` instance.
    """
    # Decide which organization entity (if any) accompanies this role.
    if role == UserRole.BRAND:
        org_model, org_prefix = Brand, "BR"
    elif role == UserRole.AGENCY:
        org_model, org_prefix = Agency, "AG"
    elif role == UserRole.CREATOR:
        org_model, org_prefix = Creator, "CR"
    else:
        org_model, org_prefix = None, None

    # Allocate every ID before adding objects to the session: SQLAlchemy
    # sessions autoflush before queries by default, and the uniqueness
    # queries should not flush half-built rows.
    user_id = await _generate_unique_id(db, User, "U")
    org_id = None
    if org_model is not None:
        org_id = await _generate_unique_id(db, org_model, org_prefix)

    user = User(
        id=user_id,
        email=email,
        phone=phone,
        password_hash=hash_password(password),
        name=name,
        role=role,
        is_active=True,
        is_verified=False,
    )
    db.add(user)

    if org_model is not None:
        db.add(org_model(id=org_id, user_id=user_id, name=name))

    await db.flush()
    return user


async def authenticate_user(
    db: AsyncSession,
    email: Optional[str] = None,
    phone: Optional[str] = None,
    password: Optional[str] = None,
) -> Optional[User]:
    """Look up a user by email (preferred) or phone and verify the password.

    Returns:
        The user, or None if no identifier was given, no user matched, or
        the password did not match.

    Passing ``password=None`` skips the password check entirely and returns
    the looked-up user. This is presumably meant for flows that verify
    identity another way; confirm at call sites. Any other value, including
    the empty string, is verified against the stored hash.
    """
    if email:
        user = await get_user_by_email(db, email)
    elif phone:
        user = await get_user_by_phone(db, phone)
    else:
        user = None

    if user is None:
        return None

    # `is not None`, not truthiness: an empty password must be verified
    # (and normally rejected), never treated as "skip the check".
    if password is not None and not verify_password(password, user.password_hash):
        return None
    return user


async def update_refresh_token(db: AsyncSession, user: User, refresh_token: str, expires_at: datetime) -> None:
    """Store ``refresh_token`` and its expiry on ``user`` and stamp ``last_login_at``.

    The token is stored as given (not hashed). ``last_login_at`` is a naive
    UTC datetime. The session is flushed, not committed.
    """
    user.refresh_token = refresh_token
    user.refresh_token_expires_at = expires_at
    user.last_login_at = datetime.utcnow()
    await db.flush()


async def _get_org_by_user_id(db: AsyncSession, model: type, user_id: str):
    """Return the ``model`` row whose ``user_id`` equals ``user_id``, or None."""
    result = await db.execute(
        select(model).where(model.user_id == user_id)
    )
    return result.scalar_one_or_none()


async def get_user_organization_info(db: AsyncSession, user: User) -> dict:
    """Return the organization IDs linked to ``user``.

    The dict always has the keys ``brand_id``, ``agency_id``,
    ``creator_id``, ``tenant_id`` and ``tenant_name``. Any of them may be
    None. Only one of the three entity IDs is filled, according to
    ``user.role``. ``tenant_id``/``tenant_name`` are set only for brand users.
    """
    info = {
        "brand_id": None,
        "agency_id": None,
        "creator_id": None,
        "tenant_id": None,
        "tenant_name": None,
    }

    if user.role == UserRole.BRAND:
        brand = await _get_org_by_user_id(db, Brand, user.id)
        if brand:
            info["brand_id"] = brand.id
            info["tenant_id"] = brand.id
            info["tenant_name"] = brand.name
    elif user.role == UserRole.AGENCY:
        agency = await _get_org_by_user_id(db, Agency, user.id)
        if agency:
            info["agency_id"] = agency.id
            # An agency may serve several brands, so no tenant is set for now.
    elif user.role == UserRole.CREATOR:
        creator = await _get_org_by_user_id(db, Creator, user.id)
        if creator:
            info["creator_id"] = creator.id
            # A creator may work with several agencies, so no tenant is set for now.

    return info