diff --git a/backend/alembic/versions/003_user_org_project_task.py b/backend/alembic/versions/003_user_org_project_task.py new file mode 100644 index 0000000..655bc73 --- /dev/null +++ b/backend/alembic/versions/003_user_org_project_task.py @@ -0,0 +1,240 @@ +"""添加用户、组织、项目、任务表 + +Revision ID: 003 +Revises: 002 +Create Date: 2026-02-09 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '003' +down_revision: Union[str, None] = '002' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # 创建枚举类型 + user_role_enum = postgresql.ENUM( + 'brand', 'agency', 'creator', + name='user_role_enum' + ) + user_role_enum.create(op.get_bind(), checkfirst=True) + + task_stage_enum = postgresql.ENUM( + 'script_upload', 'script_ai_review', 'script_agency_review', 'script_brand_review', + 'video_upload', 'video_ai_review', 'video_agency_review', 'video_brand_review', + 'completed', 'rejected', + name='task_stage_enum' + ) + task_stage_enum.create(op.get_bind(), checkfirst=True) + + # 用户表 + op.create_table( + 'users', + sa.Column('id', sa.String(64), primary_key=True), + sa.Column('email', sa.String(255), unique=True, nullable=True, index=True), + sa.Column('phone', sa.String(20), unique=True, nullable=True, index=True), + sa.Column('password_hash', sa.String(255), nullable=False), + sa.Column('name', sa.String(100), nullable=False), + sa.Column('avatar', sa.String(2048), nullable=True), + sa.Column('role', postgresql.ENUM('brand', 'agency', 'creator', name='user_role_enum', create_type=False), nullable=False, index=True), + sa.Column('is_active', sa.Boolean(), default=True, nullable=False), + sa.Column('is_verified', sa.Boolean(), default=False, nullable=False), + sa.Column('last_login_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('refresh_token', sa.String(512), nullable=True), + sa.Column('refresh_token_expires_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False), + ) + + # 品牌方表 + op.create_table( + 'brands', + sa.Column('id', sa.String(64), primary_key=True), + sa.Column('user_id', sa.String(64), sa.ForeignKey('users.id', ondelete='CASCADE'), unique=True, nullable=False), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('logo', sa.String(2048), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('contact_name', sa.String(100), nullable=True), + sa.Column('contact_phone', sa.String(20), nullable=True), + sa.Column('contact_email', sa.String(255), nullable=True), + sa.Column('final_review_enabled', sa.Boolean(), default=True, nullable=False), + sa.Column('is_active', sa.Boolean(), default=True, nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False), + ) + + # 代理商表 + op.create_table( + 'agencies', + sa.Column('id', sa.String(64), primary_key=True), + sa.Column('user_id', sa.String(64), sa.ForeignKey('users.id', ondelete='CASCADE'), unique=True, nullable=False), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('logo', sa.String(2048), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('contact_name', sa.String(100), nullable=True), + sa.Column('contact_phone', sa.String(20), nullable=True), + sa.Column('contact_email', sa.String(255), nullable=True), + sa.Column('force_pass_enabled', sa.Boolean(), default=True, nullable=False), + sa.Column('is_active', sa.Boolean(), default=True, nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False), + ) + + # 达人表 + op.create_table( + 'creators', + sa.Column('id', sa.String(64), primary_key=True), + sa.Column('user_id', sa.String(64), sa.ForeignKey('users.id', ondelete='CASCADE'), unique=True, nullable=False), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('avatar', sa.String(2048), nullable=True), + sa.Column('bio', sa.Text(), nullable=True), + sa.Column('douyin_account', sa.String(100), nullable=True), + sa.Column('xiaohongshu_account', sa.String(100), nullable=True), + sa.Column('bilibili_account', sa.String(100), nullable=True), + sa.Column('is_active', sa.Boolean(), default=True, nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False), + ) + + # 品牌方-代理商关联表 + op.create_table( + 'brand_agency', + sa.Column('brand_id', sa.String(64), sa.ForeignKey('brands.id', ondelete='CASCADE'), primary_key=True), + sa.Column('agency_id', sa.String(64), sa.ForeignKey('agencies.id', ondelete='CASCADE'), primary_key=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('is_active', sa.Boolean(), default=True), + ) + + # 代理商-达人关联表 + op.create_table( + 'agency_creator', + sa.Column('agency_id', sa.String(64), sa.ForeignKey('agencies.id', ondelete='CASCADE'), primary_key=True), + sa.Column('creator_id', sa.String(64), sa.ForeignKey('creators.id', ondelete='CASCADE'), primary_key=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('is_active', sa.Boolean(), default=True), + ) + + # 项目表 + op.create_table( + 'projects', + sa.Column('id', sa.String(64), primary_key=True), + sa.Column('brand_id', sa.String(64), sa.ForeignKey('brands.id', ondelete='CASCADE'), nullable=False, index=True), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('start_date', sa.DateTime(timezone=True), nullable=True), + sa.Column('deadline', sa.DateTime(timezone=True), nullable=True), + sa.Column('status', sa.String(20), default='active', nullable=False, index=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False), + ) + + # 项目-代理商关联表 + op.create_table( + 'project_agency', + sa.Column('project_id', sa.String(64), sa.ForeignKey('projects.id', ondelete='CASCADE'), primary_key=True), + sa.Column('agency_id', sa.String(64), sa.ForeignKey('agencies.id', ondelete='CASCADE'), primary_key=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('is_active', sa.Boolean(), default=True), + ) + + # Brief 表 + op.create_table( + 'briefs', + sa.Column('id', sa.String(64), primary_key=True), + sa.Column('project_id', sa.String(64), sa.ForeignKey('projects.id', ondelete='CASCADE'), unique=True, nullable=False, index=True), + sa.Column('file_url', sa.String(2048), nullable=True), + sa.Column('file_name', sa.String(255), nullable=True), + sa.Column('selling_points', postgresql.JSON(), nullable=True), + sa.Column('blacklist_words', postgresql.JSON(), nullable=True), + sa.Column('competitors', postgresql.JSON(), nullable=True), + sa.Column('brand_tone', sa.Text(), nullable=True), + sa.Column('min_duration', sa.Integer(), nullable=True), + sa.Column('max_duration', sa.Integer(), nullable=True), + sa.Column('other_requirements', sa.Text(), nullable=True), + sa.Column('attachments', postgresql.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False), + ) + + # 任务表 + op.create_table( + 'tasks', + sa.Column('id', sa.String(64), primary_key=True), + sa.Column('project_id', sa.String(64), sa.ForeignKey('projects.id', ondelete='CASCADE'), nullable=False, index=True), + sa.Column('agency_id', sa.String(64), sa.ForeignKey('agencies.id', ondelete='CASCADE'), nullable=False, index=True), + sa.Column('creator_id', sa.String(64), sa.ForeignKey('creators.id', ondelete='CASCADE'), nullable=False, index=True), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('sequence', sa.Integer(), default=1, nullable=False), + sa.Column('stage', postgresql.ENUM( + 'script_upload', 'script_ai_review', 'script_agency_review', 'script_brand_review', + 'video_upload', 'video_ai_review', 'video_agency_review', 'video_brand_review', + 'completed', 'rejected', + name='task_stage_enum', create_type=False + ), default='script_upload', nullable=False, index=True), + + # 脚本相关 + sa.Column('script_file_url', sa.String(2048), nullable=True), + sa.Column('script_file_name', sa.String(255), nullable=True), + sa.Column('script_uploaded_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('script_ai_score', sa.Integer(), nullable=True), + sa.Column('script_ai_result', postgresql.JSON(), nullable=True), + sa.Column('script_ai_reviewed_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('script_agency_status', postgresql.ENUM('pending', 'processing', 'passed', 'rejected', 'force_passed', name='task_status_enum', create_type=False), nullable=True), + sa.Column('script_agency_comment', sa.Text(), nullable=True), + sa.Column('script_agency_reviewer_id', sa.String(64), nullable=True), + sa.Column('script_agency_reviewed_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('script_brand_status', postgresql.ENUM('pending', 'processing', 'passed', 'rejected', 'force_passed', name='task_status_enum', create_type=False), nullable=True), + sa.Column('script_brand_comment', sa.Text(), nullable=True), + sa.Column('script_brand_reviewer_id', sa.String(64), nullable=True), + sa.Column('script_brand_reviewed_at', sa.DateTime(timezone=True), nullable=True), + + # 视频相关 + sa.Column('video_file_url', sa.String(2048), nullable=True), + sa.Column('video_file_name', sa.String(255), nullable=True), + sa.Column('video_duration', sa.Integer(), nullable=True), + sa.Column('video_thumbnail_url', sa.String(2048), nullable=True), + sa.Column('video_uploaded_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('video_ai_score', sa.Integer(), nullable=True), + sa.Column('video_ai_result', postgresql.JSON(), nullable=True), + sa.Column('video_ai_reviewed_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('video_agency_status', postgresql.ENUM('pending', 'processing', 'passed', 'rejected', 'force_passed', name='task_status_enum', create_type=False), nullable=True), + sa.Column('video_agency_comment', sa.Text(), nullable=True), + sa.Column('video_agency_reviewer_id', sa.String(64), nullable=True), + sa.Column('video_agency_reviewed_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('video_brand_status', postgresql.ENUM('pending', 'processing', 'passed', 'rejected', 'force_passed', name='task_status_enum', create_type=False), nullable=True), + sa.Column('video_brand_comment', sa.Text(), nullable=True), + sa.Column('video_brand_reviewer_id', sa.String(64), nullable=True), + sa.Column('video_brand_reviewed_at', sa.DateTime(timezone=True), nullable=True), + + # 申诉相关 + sa.Column('appeal_count', sa.Integer(), default=1, nullable=False), + sa.Column('is_appeal', sa.Boolean(), default=False, nullable=False), + sa.Column('appeal_reason', sa.Text(), nullable=True), + + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False), + ) + + +def downgrade() -> None: + op.drop_table('tasks') + op.drop_table('briefs') + op.drop_table('project_agency') + op.drop_table('projects') + op.drop_table('agency_creator') + op.drop_table('brand_agency') + op.drop_table('creators') + op.drop_table('agencies') + op.drop_table('brands') + op.drop_table('users') + + # 删除枚举类型 + op.execute("DROP TYPE IF EXISTS task_stage_enum") + op.execute("DROP TYPE IF EXISTS user_role_enum") diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py new file mode 100644 index 0000000..4bf1b83 --- /dev/null +++ b/backend/app/api/auth.py @@ -0,0 +1,246 @@ +""" +认证 API +""" +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.schemas.auth import ( + RegisterRequest, + LoginRequest, + LoginResponse, + RefreshTokenRequest, + RefreshTokenResponse, + UserResponse, +) +from app.services.auth import ( + get_user_by_email, + get_user_by_phone, + get_user_by_id, + create_user, + authenticate_user, + create_access_token, + create_refresh_token, + update_refresh_token, + decode_token, + get_user_organization_info, +) + +router = APIRouter(prefix="/auth", tags=["认证"]) + + +@router.post("/register", response_model=LoginResponse, status_code=status.HTTP_201_CREATED) +async def register( + request: RegisterRequest, + db: AsyncSession = Depends(get_db), +): + """ + 用户注册 + + - 支持邮箱或手机号注册(至少提供一个) + - 注册后自动登录,返回 Token + """ + # 验证至少提供邮箱或手机号 + if not request.email and not request.phone: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="请提供邮箱或手机号", + ) + + # 检查邮箱是否已存在 + if request.email: + existing = await get_user_by_email(db, request.email) + if existing: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="该邮箱已被注册", + ) + + # 检查手机号是否已存在 + if request.phone: + existing = await get_user_by_phone(db, request.phone) + if existing: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="该手机号已被注册", + ) + + # 创建用户 + user = await create_user( + db=db, + email=request.email, + phone=request.phone, + password=request.password, + name=request.name, + role=request.role, + ) + + # 生成 Token + access_token = create_access_token(user.id) + refresh_token, refresh_expires_at = create_refresh_token(user.id) + + # 保存 refresh token + await update_refresh_token(db, user, refresh_token, refresh_expires_at) + await db.commit() + + # 获取组织信息 + org_info = await get_user_organization_info(db, user) + + return LoginResponse( + access_token=access_token, + refresh_token=refresh_token, + user=UserResponse( + id=user.id, + email=user.email, + phone=user.phone, + name=user.name, + avatar=user.avatar, + role=user.role, + is_verified=user.is_verified, + **org_info, + ), + ) + + +@router.post("/login", response_model=LoginResponse) +async def login( + request: LoginRequest, + db: AsyncSession = Depends(get_db), +): + """ + 用户登录 + + - 支持邮箱+密码 或 手机号+密码 登录 + - 返回 accessToken 和 refreshToken + """ + # 验证请求参数 + if not request.email and not request.phone: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="请提供邮箱或手机号", + ) + + if not request.password: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="请提供密码", + ) + + # 验证用户 + user = await authenticate_user( + db=db, + email=request.email, + phone=request.phone, + password=request.password, + ) + + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="邮箱/手机号或密码错误", + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="账号已被禁用", + ) + + # 生成 Token + access_token = create_access_token(user.id) + refresh_token, refresh_expires_at = create_refresh_token(user.id) + + # 保存 refresh token + await update_refresh_token(db, user, refresh_token, refresh_expires_at) + await db.commit() + + # 获取组织信息 + org_info = await get_user_organization_info(db, user) + + return LoginResponse( + access_token=access_token, + refresh_token=refresh_token, + user=UserResponse( + id=user.id, + email=user.email, + phone=user.phone, + name=user.name, + avatar=user.avatar, + role=user.role, + is_verified=user.is_verified, + **org_info, + ), + ) + + +@router.post("/refresh", response_model=RefreshTokenResponse) +async def refresh_token( + request: RefreshTokenRequest, + db: AsyncSession = Depends(get_db), +): + """ + 刷新 Access Token + + - 使用 refreshToken 获取新的 accessToken + - refreshToken 有效期 7 天 + """ + # 解码 refresh token + payload = decode_token(request.refresh_token) + if not payload: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="无效的 refresh token", + ) + + if payload.get("type") != "refresh": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="无效的 token 类型", + ) + + user_id = payload.get("sub") + if not user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="无效的 token", + ) + + # 获取用户 + user = await get_user_by_id(db, user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="用户不存在", + ) + + # 验证 refresh token 是否匹配 + if user.refresh_token != request.refresh_token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="refresh token 已失效", + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="账号已被禁用", + ) + + # 生成新的 access token + access_token = create_access_token(user.id) + + return RefreshTokenResponse(access_token=access_token) + + +@router.post("/logout") +async def logout( + db: AsyncSession = Depends(get_db), + # TODO: 添加认证依赖 +): + """ + 退出登录 + + - 清除 refresh token + """ + # TODO: 实现退出登录 + return {"message": "已退出登录"} diff --git a/backend/app/api/upload.py b/backend/app/api/upload.py new file mode 100644 index 0000000..d78ab6a --- /dev/null +++ b/backend/app/api/upload.py @@ -0,0 +1,117 @@ +""" +文件上传 API +""" +from fastapi import APIRouter, HTTPException, status +from pydantic import BaseModel +from typing import Optional +from datetime import datetime + +from app.services.oss import generate_upload_policy, get_file_url +from app.config import settings + +router = APIRouter(prefix="/upload", tags=["文件上传"]) + + +class UploadPolicyRequest(BaseModel): + """获取上传凭证请求""" + file_type: str = "general" # script, video, image, general + file_name: Optional[str] = None + + +class UploadPolicyResponse(BaseModel): + """上传凭证响应""" + access_key_id: str + policy: str + signature: str + host: str + dir: str + expire: int + max_size_mb: int + + +class FileUploadedRequest(BaseModel): + """文件上传完成回调""" + file_key: str + file_name: str + file_size: int + file_type: str + + +class FileUploadedResponse(BaseModel): + """文件上传完成响应""" + url: str + file_key: str + file_name: str + file_size: int + file_type: str + + +@router.post("/policy", response_model=UploadPolicyResponse) +async def get_upload_policy( + request: UploadPolicyRequest, +): + """ + 获取 OSS 直传凭证 + + 前端使用此凭证直接上传文件到阿里云 OSS,无需经过后端。 + + 文件类型说明: + - script: 脚本文档 (docx, pdf, xlsx, txt, pptx) + - video: 视频文件 (mp4, mov, webm) + - image: 图片文件 (jpg, png, gif) + - general: 通用文件 + """ + # 根据文件类型设置上传目录 + now = datetime.now() + base_dir = f"uploads/{now.year}/{now.month:02d}" + + if request.file_type == "script": + upload_dir = f"{base_dir}/scripts/" + elif request.file_type == "video": + upload_dir = f"{base_dir}/videos/" + elif request.file_type == "image": + upload_dir = f"{base_dir}/images/" + else: + upload_dir = f"{base_dir}/files/" + + try: + policy = generate_upload_policy( + max_size_mb=settings.MAX_FILE_SIZE_MB, + expire_seconds=3600, + upload_dir=upload_dir, + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + return UploadPolicyResponse( + access_key_id=policy["accessKeyId"], + policy=policy["policy"], + signature=policy["signature"], + host=policy["host"], + dir=policy["dir"], + expire=policy["expire"], + max_size_mb=settings.MAX_FILE_SIZE_MB, + ) + + +@router.post("/complete", response_model=FileUploadedResponse) +async def file_uploaded( + request: FileUploadedRequest, +): + """ + 文件上传完成回调 + + 前端上传完成后调用此接口,获取文件的完整 URL。 + """ + url = get_file_url(request.file_key) + + return FileUploadedResponse( + url=url, + file_key=request.file_key, + file_name=request.file_name, + file_size=request.file_size, + file_type=request.file_type, + ) diff --git a/backend/app/config.py b/backend/app/config.py index 6641122..18e34e9 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -27,6 +27,16 @@ class Settings(BaseSettings): AI_API_KEY: str = "" # 中转服务商的 API Key AI_API_BASE_URL: str = "" # 中转服务商的 Base URL,如 https://api.oneinall.ai/v1 + # 阿里云 OSS 配置 + OSS_ACCESS_KEY_ID: str = "" + OSS_ACCESS_KEY_SECRET: str = "" + OSS_ENDPOINT: str = "oss-cn-hangzhou.aliyuncs.com" + OSS_BUCKET_NAME: str = "miaosi-files" + OSS_BUCKET_DOMAIN: str = "" # 公开访问域名,如 https://miaosi-files.oss-cn-hangzhou.aliyuncs.com + + # 文件上传限制 + MAX_FILE_SIZE_MB: int = 500 # 最大文件大小 500MB + class Config: env_file = ".env" case_sensitive = True diff --git a/backend/app/database.py b/backend/app/database.py index cffbb16..1696732 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -7,14 +7,30 @@ from app.config import settings # 导入所有模型,确保在创建表时被注册 from app.models.base import Base from app.models import ( - Tenant, + # 用户与组织 + User, + UserRole, + Brand, + Agency, + Creator, + # 项目与任务 + Project, + Task, + TaskStage, + TaskStatus, + Brief, + # AI 配置 AIConfig, + # 审核(旧模型) ReviewTask, ManualTask, + # 规则 ForbiddenWord, WhitelistItem, Competitor, RiskException, + # 兼容 + Tenant, ) # 创建异步引擎 @@ -65,12 +81,28 @@ __all__ = [ "get_db", "init_db", "drop_db", - "Tenant", + # 用户与组织 + "User", + "UserRole", + "Brand", + "Agency", + "Creator", + # 项目与任务 + "Project", + "Task", + "TaskStage", + "TaskStatus", + "Brief", + # AI 配置 "AIConfig", + # 审核 "ReviewTask", "ManualTask", + # 规则 "ForbiddenWord", "WhitelistItem", "Competitor", "RiskException", + # 兼容 + "Tenant", ] diff --git a/backend/app/main.py b/backend/app/main.py index c76e66c..88013f9 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -2,7 +2,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from app.config import settings -from app.api import health, scripts, videos, tasks, rules, ai_config, risk_exceptions, metrics +from app.api import health, auth, upload, scripts, videos, tasks, rules, ai_config, risk_exceptions, metrics # 创建应用 app = FastAPI( @@ -24,6 +24,8 @@ app.add_middleware( # 注册路由 app.include_router(health.router, prefix="/api/v1") +app.include_router(auth.router, prefix="/api/v1") +app.include_router(upload.router, prefix="/api/v1") app.include_router(scripts.router, prefix="/api/v1") app.include_router(videos.router, prefix="/api/v1") app.include_router(tasks.router, prefix="/api/v1") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 6a9afba..1e17452 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -3,21 +3,49 @@ 导出所有 ORM 模型 """ from app.models.base import Base, TimestampMixin -from app.models.tenant import Tenant +from app.models.user import User, UserRole +from app.models.organization import Brand, Agency, Creator, brand_agency_association, agency_creator_association +from app.models.project import Project, project_agency_association +from app.models.task import Task, TaskStage, TaskStatus +from app.models.brief import Brief from app.models.ai_config import AIConfig -from app.models.review import ReviewTask, ManualTask +from app.models.review import ReviewTask, ManualTask, Platform from app.models.rule import ForbiddenWord, WhitelistItem, Competitor from app.models.risk_exception import RiskException +# 保留 Tenant 兼容旧代码,但新代码应使用 Brand +from app.models.tenant import Tenant + __all__ = [ + # Base "Base", "TimestampMixin", - "Tenant", + # 用户与组织 + "User", + "UserRole", + "Brand", + "Agency", + "Creator", + "brand_agency_association", + "agency_creator_association", + # 项目与任务 + "Project", + "project_agency_association", + "Task", + "TaskStage", + "TaskStatus", + "Brief", + # AI 配置 "AIConfig", + # 审核(旧模型,保留兼容) "ReviewTask", "ManualTask", + "Platform", + # 规则 "ForbiddenWord", "WhitelistItem", "Competitor", "RiskException", + # 兼容 + "Tenant", ] diff --git a/backend/app/models/brief.py b/backend/app/models/brief.py new file mode 100644 index 0000000..3f11ffd --- /dev/null +++ b/backend/app/models/brief.py @@ -0,0 +1,60 @@ +""" +Brief 模型 +""" +from typing import TYPE_CHECKING, Optional +from sqlalchemy import String, Text, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.models.base import Base, TimestampMixin +from app.models.types import JSONType + +if TYPE_CHECKING: + from app.models.project import Project + + +class Brief(Base, TimestampMixin): + """Brief 文档表""" + __tablename__ = "briefs" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + project_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("projects.id", ondelete="CASCADE"), + unique=True, + nullable=False, + index=True, + ) + + # 原始文件 + file_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + file_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # 解析后的结构化内容 + # 卖点要求: [{"content": "SPF50+", "required": true}, ...] + selling_points: Mapped[Optional[list]] = mapped_column(JSONType, nullable=True) + + # 违禁词: [{"word": "最好", "reason": "绝对化用语"}, ...] + blacklist_words: Mapped[Optional[list]] = mapped_column(JSONType, nullable=True) + + # 竞品: ["竞品A", "竞品B", ...] + competitors: Mapped[Optional[list]] = mapped_column(JSONType, nullable=True) + + # 品牌调性要求 + brand_tone: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # 时长要求(秒) + min_duration: Mapped[Optional[int]] = mapped_column(nullable=True) + max_duration: Mapped[Optional[int]] = mapped_column(nullable=True) + + # 其他要求(自由文本) + other_requirements: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # 附件文档(代理商上传的参考资料) + # [{"id": "af1", "name": "达人拍摄指南.pdf", "url": "...", "size": "1.5MB"}, ...] + attachments: Mapped[Optional[list]] = mapped_column(JSONType, nullable=True) + + # 关联 + project: Mapped["Project"] = relationship("Project", back_populates="brief") + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py new file mode 100644 index 0000000..0079f4c --- /dev/null +++ b/backend/app/models/organization.py @@ -0,0 +1,160 @@ +""" +组织模型:品牌方、代理商、达人 +""" +from typing import TYPE_CHECKING, Optional +from datetime import datetime +from sqlalchemy import String, Boolean, Text, ForeignKey, DateTime, Table, Column +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.models.base import Base, TimestampMixin + +if TYPE_CHECKING: + from app.models.user import User + from app.models.project import Project + + +# 品牌方-代理商 关联表(多对多) +brand_agency_association = Table( + "brand_agency", + Base.metadata, + Column("brand_id", String(64), ForeignKey("brands.id", ondelete="CASCADE"), primary_key=True), + Column("agency_id", String(64), ForeignKey("agencies.id", ondelete="CASCADE"), primary_key=True), + Column("created_at", DateTime(timezone=True), default=datetime.utcnow), + Column("is_active", Boolean, default=True), +) + +# 代理商-达人 关联表(多对多) +agency_creator_association = Table( + "agency_creator", + Base.metadata, + Column("agency_id", String(64), ForeignKey("agencies.id", ondelete="CASCADE"), primary_key=True), + Column("creator_id", String(64), ForeignKey("creators.id", ondelete="CASCADE"), primary_key=True), + Column("created_at", DateTime(timezone=True), default=datetime.utcnow), + Column("is_active", Boolean, default=True), +) + + +class Brand(Base, TimestampMixin): + """品牌方表(即租户)""" + __tablename__ = "brands" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) # 格式: BR123456 + user_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("users.id", ondelete="CASCADE"), + unique=True, + nullable=False, + ) + + # 品牌信息 + name: Mapped[str] = mapped_column(String(255), nullable=False) + logo: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # 联系信息 + contact_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + contact_phone: Mapped[Optional[str]] = mapped_column(String(20), nullable=True) + contact_email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # 设置 + final_review_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # 终审开关 + + # 状态 + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + + # 关联 + user: Mapped["User"] = relationship("User", back_populates="brand") + agencies: Mapped[list["Agency"]] = relationship( + "Agency", + secondary=brand_agency_association, + back_populates="brands", + ) + projects: Mapped[list["Project"]] = relationship( + "Project", + back_populates="brand", + ) + + def __repr__(self) -> str: + return f"" + + +class Agency(Base, TimestampMixin): + """代理商表""" + __tablename__ = "agencies" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) # 格式: AG123456 + user_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("users.id", ondelete="CASCADE"), + unique=True, + nullable=False, + ) + + # 代理商信息 + name: Mapped[str] = mapped_column(String(255), nullable=False) + logo: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # 联系信息 + contact_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + contact_phone: Mapped[Optional[str]] = mapped_column(String(20), nullable=True) + contact_email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # 权限设置(可被品牌方覆盖) + force_pass_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # 强制通过权 + + # 状态 + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + + # 关联 + user: Mapped["User"] = relationship("User", back_populates="agency") + brands: Mapped[list["Brand"]] = relationship( + "Brand", + secondary=brand_agency_association, + back_populates="agencies", + ) + creators: Mapped[list["Creator"]] = relationship( + "Creator", + secondary=agency_creator_association, + back_populates="agencies", + ) + + def __repr__(self) -> str: + return f"" + + +class Creator(Base, TimestampMixin): + """达人表""" + __tablename__ = "creators" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) # 格式: CR123456 + user_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("users.id", ondelete="CASCADE"), + unique=True, + nullable=False, + ) + + # 达人信息 + name: Mapped[str] = mapped_column(String(255), nullable=False) + avatar: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + bio: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # 社交账号 + douyin_account: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + xiaohongshu_account: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + bilibili_account: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + + # 状态 + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + + # 关联 + user: Mapped["User"] = relationship("User", back_populates="creator") + agencies: Mapped[list["Agency"]] = relationship( + "Agency", + secondary=agency_creator_association, + back_populates="creators", + ) + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/models/project.py b/backend/app/models/project.py new file mode 100644 index 0000000..0e8ff83 --- /dev/null +++ b/backend/app/models/project.py @@ -0,0 +1,74 @@ +""" +项目模型 +""" +from typing import TYPE_CHECKING, Optional +from datetime import datetime +from sqlalchemy import String, Text, ForeignKey, DateTime, Table, Column, Boolean +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.models.base import Base, TimestampMixin + +if TYPE_CHECKING: + from app.models.organization import Brand, Agency + from app.models.task import Task + from app.models.brief import Brief + + +# 项目-代理商 关联表(一个项目可以分配给多个代理商) +project_agency_association = Table( + "project_agency", + Base.metadata, + Column("project_id", String(64), ForeignKey("projects.id", ondelete="CASCADE"), primary_key=True), + Column("agency_id", String(64), ForeignKey("agencies.id", ondelete="CASCADE"), primary_key=True), + Column("created_at", DateTime(timezone=True), default=datetime.utcnow), + Column("is_active", Boolean, default=True), +) + + +class Project(Base, TimestampMixin): + """项目表""" + __tablename__ = "projects" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + brand_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + # 项目信息 + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + # 时间 + start_date: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + deadline: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # 状态 + status: Mapped[str] = mapped_column( + String(20), + default="active", # active, completed, archived + nullable=False, + index=True, + ) + + # 关联 + brand: Mapped["Brand"] = relationship("Brand", back_populates="projects") + agencies: Mapped[list["Agency"]] = relationship( + "Agency", + secondary=project_agency_association, + backref="projects", + ) + tasks: Mapped[list["Task"]] = relationship( + "Task", + back_populates="project", + ) + brief: Mapped[Optional["Brief"]] = relationship( + "Brief", + back_populates="project", + uselist=False, + ) + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/models/task.py b/backend/app/models/task.py new file mode 100644 index 0000000..4473466 --- /dev/null +++ b/backend/app/models/task.py @@ -0,0 +1,148 @@ +""" +任务模型 +""" +from typing import TYPE_CHECKING, Optional +from datetime import datetime +from sqlalchemy import String, Integer, Text, ForeignKey, DateTime, Enum as SQLEnum, Boolean +from sqlalchemy.orm import Mapped, mapped_column, relationship +import enum + +from app.models.base import Base, TimestampMixin +from app.models.types import JSONType + +if TYPE_CHECKING: + from app.models.project import Project + from app.models.organization import Agency, Creator + + +class TaskStage(str, enum.Enum): + """任务阶段""" + SCRIPT_UPLOAD = "script_upload" # 待上传脚本 + SCRIPT_AI_REVIEW = "script_ai_review" # 脚本 AI 审核中 + SCRIPT_AGENCY_REVIEW = "script_agency_review" # 脚本代理商审核中 + SCRIPT_BRAND_REVIEW = "script_brand_review" # 脚本品牌方终审中 + VIDEO_UPLOAD = "video_upload" # 待上传视频 + VIDEO_AI_REVIEW = "video_ai_review" # 视频 AI 审核中 + VIDEO_AGENCY_REVIEW = "video_agency_review" # 视频代理商审核中 + VIDEO_BRAND_REVIEW = "video_brand_review" # 视频品牌方终审中 + COMPLETED = "completed" # 已完成 + REJECTED = "rejected" # 已驳回 + + +class TaskStatus(str, enum.Enum): + """任务状态""" + PENDING = "pending" # 待处理 + PROCESSING = "processing" # 处理中 + PASSED = "passed" # 通过 + REJECTED = "rejected" # 驳回 + FORCE_PASSED = "force_passed" # 强制通过 + + +class Task(Base, TimestampMixin): + """任务表""" + __tablename__ = "tasks" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + + # 关联 + project_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("projects.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + agency_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("agencies.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + creator_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("creators.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + # 任务信息 + name: Mapped[str] = mapped_column(String(255), nullable=False) # 如 "宣传任务(1)" + sequence: Mapped[int] = mapped_column(Integer, default=1, nullable=False) # 序号 + + # 当前阶段 + stage: Mapped[TaskStage] = mapped_column( + SQLEnum(TaskStage, name="task_stage_enum"), + default=TaskStage.SCRIPT_UPLOAD, + nullable=False, + index=True, + ) + + # ===== 脚本相关 ===== + script_file_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + script_file_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + script_uploaded_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # 脚本 AI 审核结果 + script_ai_score: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + script_ai_result: Mapped[Optional[dict]] = mapped_column(JSONType, nullable=True) + script_ai_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # 脚本代理商审核 + script_agency_status: Mapped[Optional[TaskStatus]] = mapped_column( + SQLEnum(TaskStatus, name="task_status_enum"), + nullable=True, + ) + script_agency_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + script_agency_reviewer_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + script_agency_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # 脚本品牌方终审 + script_brand_status: Mapped[Optional[TaskStatus]] = mapped_column( + SQLEnum(TaskStatus, name="task_status_enum", create_type=False), + nullable=True, + ) + script_brand_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + script_brand_reviewer_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + script_brand_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # ===== 视频相关 ===== + video_file_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + video_file_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + video_duration: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) # 秒 + video_thumbnail_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + video_uploaded_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # 视频 AI 审核结果 + video_ai_score: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + video_ai_result: Mapped[Optional[dict]] = mapped_column(JSONType, nullable=True) + video_ai_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # 视频代理商审核 + video_agency_status: Mapped[Optional[TaskStatus]] = mapped_column( + SQLEnum(TaskStatus, name="task_status_enum", create_type=False), + nullable=True, + ) + video_agency_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + video_agency_reviewer_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + video_agency_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # 视频品牌方终审 + video_brand_status: Mapped[Optional[TaskStatus]] = mapped_column( + SQLEnum(TaskStatus, name="task_status_enum", create_type=False), + nullable=True, + ) + video_brand_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + video_brand_reviewer_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + video_brand_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # ===== 申诉相关 ===== + appeal_count: Mapped[int] = mapped_column(Integer, default=1, nullable=False) # 剩余申诉次数 + is_appeal: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # 是否为申诉 + appeal_reason: Mapped[Optional[str]] = mapped_column(Text, nullable=True) # 申诉理由 + + # 关联 + project: Mapped["Project"] = relationship("Project", back_populates="tasks") + agency: Mapped["Agency"] = relationship("Agency", foreign_keys=[agency_id]) + creator: Mapped["Creator"] = relationship("Creator", foreign_keys=[creator_id]) + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/models/user.py b/backend/app/models/user.py new file mode 100644 index 0000000..c2e39d7 --- /dev/null +++ b/backend/app/models/user.py @@ -0,0 +1,80 @@ +""" +用户模型 +""" +from typing import TYPE_CHECKING, Optional +from datetime import datetime +from sqlalchemy import String, Boolean, DateTime, Enum as SQLEnum +from sqlalchemy.orm import Mapped, mapped_column, relationship +import enum + +from app.models.base import Base, TimestampMixin + +if TYPE_CHECKING: + from app.models.organization import Brand, Agency, Creator + + +class UserRole(str, enum.Enum): + """用户角色""" + BRAND = "brand" # 品牌方 + AGENCY = "agency" # 代理商 + CREATOR = "creator" # 达人 + + +class User(Base, TimestampMixin): + """用户表""" + __tablename__ = "users" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + + # 登录凭证(邮箱和手机号都可以登录) + email: Mapped[Optional[str]] = mapped_column(String(255), unique=True, nullable=True, index=True) + phone: Mapped[Optional[str]] = mapped_column(String(20), unique=True, nullable=True, index=True) + password_hash: Mapped[str] = mapped_column(String(255), nullable=False) + + # 用户信息 + name: Mapped[str] = mapped_column(String(100), nullable=False) + avatar: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + + # 角色 + role: Mapped[UserRole] = mapped_column( + SQLEnum(UserRole, name="user_role_enum"), + nullable=False, + index=True, + ) + + # 状态 + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + is_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + + # 最后登录 + last_login_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + + # Refresh Token(用于 JWT 刷新) + refresh_token: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) + refresh_token_expires_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + + # 关联的组织(根据角色不同,关联到不同的组织) + brand: Mapped[Optional["Brand"]] = relationship( + "Brand", + back_populates="user", + uselist=False, + ) + agency: Mapped[Optional["Agency"]] = relationship( + "Agency", + back_populates="user", + uselist=False, + ) + creator: Mapped[Optional["Creator"]] = relationship( + "Creator", + back_populates="user", + uselist=False, + ) + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py new file mode 100644 index 0000000..56abbca --- /dev/null +++ b/backend/app/schemas/auth.py @@ -0,0 +1,120 @@ +""" +认证相关 Schema +""" +from typing import Optional +from pydantic import BaseModel, EmailStr, Field +from app.models.user import UserRole + + +# ===== 请求 ===== + +class RegisterRequest(BaseModel): + """注册请求""" + email: Optional[EmailStr] = None + phone: Optional[str] = Field(None, pattern=r"^1[3-9]\d{9}$") + password: str = Field(..., min_length=6, max_length=128) + name: str = Field(..., min_length=1, max_length=100) + role: UserRole + + class Config: + json_schema_extra = { + "example": { + "email": "user@example.com", + "password": "password123", + "name": "张三", + "role": "creator" + } + } + + +class LoginRequest(BaseModel): + """登录请求(邮箱或手机号)""" + email: Optional[EmailStr] = None + phone: Optional[str] = None + password: Optional[str] = None + sms_code: Optional[str] = None # 短信验证码 + + class Config: + json_schema_extra = { + "example": { + "email": "user@example.com", + "password": "password123" + } + } + + +class RefreshTokenRequest(BaseModel): + """刷新 Token 请求""" + refresh_token: str + + +class SendSmsCodeRequest(BaseModel): + """发送短信验证码请求""" + phone: str = Field(..., pattern=r"^1[3-9]\d{9}$") + + +class BindPhoneRequest(BaseModel): + """绑定手机号请求""" + phone: str = Field(..., pattern=r"^1[3-9]\d{9}$") + sms_code: str + + +class BindEmailRequest(BaseModel): + """绑定邮箱请求""" + email: EmailStr + password: str = Field(..., min_length=6, max_length=128) + + +class ChangePasswordRequest(BaseModel): + """修改密码请求""" + old_password: str + new_password: str = Field(..., min_length=6, max_length=128) + + +# ===== 响应 ===== + +class UserResponse(BaseModel): + """用户信息响应""" + id: str + email: Optional[str] = None + phone: Optional[str] = None + name: str + avatar: Optional[str] = None + role: UserRole + is_verified: bool + + # 根据角色返回对应的组织 ID + brand_id: Optional[str] = None + agency_id: Optional[str] = None + creator_id: Optional[str] = None + + # 当前所属租户(品牌方)- 用于数据隔离 + tenant_id: Optional[str] = None + tenant_name: Optional[str] = None + + class Config: + from_attributes = True + + +class TokenResponse(BaseModel): + """Token 响应""" + access_token: str + refresh_token: str + token_type: str = "bearer" + expires_in: int = 900 # 15 分钟 = 900 秒 + + +class LoginResponse(BaseModel): + """登录响应""" + access_token: str + refresh_token: str + token_type: str = "bearer" + expires_in: int = 900 + user: UserResponse + + +class RefreshTokenResponse(BaseModel): + """刷新 Token 响应""" + access_token: str + token_type: str = "bearer" + expires_in: int = 900 diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py new file mode 100644 index 0000000..2a7ecd5 --- /dev/null +++ b/backend/app/services/auth.py @@ -0,0 +1,215 @@ +""" +认证服务 +""" +from datetime import datetime, timedelta +from typing import Optional +import secrets +from jose import jwt, JWTError +from passlib.context import CryptContext +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from app.config import settings +from app.models.user import User, UserRole +from app.models.organization import Brand, Agency, Creator + +# 密码加密上下文 +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """验证密码""" + return pwd_context.verify(plain_password, hashed_password) + + +def hash_password(password: str) -> str: + """哈希密码""" + return pwd_context.hash(password) + + +def generate_id(prefix: str) -> str: + """生成语义化 ID""" + # 格式: BR123456, AG123456, CR123456 + random_part = secrets.randbelow(900000) + 100000 # 100000-999999 + return f"{prefix}{random_part}" + + +def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str: + """创建访问 Token""" + if expires_delta is None: + expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + + expire = datetime.utcnow() + expires_delta + to_encode = { + "sub": user_id, + "exp": expire, + "type": "access", + } + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + + +def create_refresh_token(user_id: str, expires_days: int = 7) -> tuple[str, datetime]: + """创建刷新 Token""" + expire = datetime.utcnow() + timedelta(days=expires_days) + to_encode = { + "sub": user_id, + "exp": expire, + "type": "refresh", + } + token = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + return token, expire + + +def decode_token(token: str) -> Optional[dict]: + """解码 Token""" + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + return payload + except JWTError: + return None + + +async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]: + """通过邮箱获取用户""" + result = await db.execute( + select(User).where(User.email == email) + ) + return result.scalar_one_or_none() + + +async def get_user_by_phone(db: AsyncSession, phone: str) -> Optional[User]: + """通过手机号获取用户""" + result = await db.execute( + select(User).where(User.phone == phone) + ) + return result.scalar_one_or_none() + + +async def get_user_by_id(db: AsyncSession, user_id: str) -> Optional[User]: + """通过 ID 获取用户""" + result = await db.execute( + select(User).where(User.id == user_id) + ) + return result.scalar_one_or_none() + + +async def create_user( + db: AsyncSession, + email: Optional[str], + phone: Optional[str], + password: str, + name: str, + role: UserRole, +) -> User: + """创建用户""" + user_id = generate_id("U") + + user = User( + id=user_id, + email=email, + phone=phone, + password_hash=hash_password(password), + name=name, + role=role, + is_active=True, + is_verified=False, + ) + db.add(user) + + # 根据角色创建对应的组织实体 + if role == UserRole.BRAND: + brand = Brand( + id=generate_id("BR"), + user_id=user_id, + name=name, + ) + db.add(brand) + elif role == UserRole.AGENCY: + agency = Agency( + id=generate_id("AG"), + user_id=user_id, + name=name, + ) + db.add(agency) + elif role == UserRole.CREATOR: + creator = Creator( + id=generate_id("CR"), + user_id=user_id, + name=name, + ) + db.add(creator) + + await db.flush() + return user + + +async def authenticate_user( + db: AsyncSession, + email: Optional[str] = None, + phone: Optional[str] = None, + password: Optional[str] = None, +) -> Optional[User]: + """验证用户登录""" + user = None + + if email: + user = await get_user_by_email(db, email) + elif phone: + user = await get_user_by_phone(db, phone) + + if not user: + return None + + if password and not verify_password(password, user.password_hash): + return None + + return user + + +async def update_refresh_token(db: AsyncSession, user: User, refresh_token: str, expires_at: datetime) -> None: + """更新用户的刷新 Token""" + user.refresh_token = refresh_token + user.refresh_token_expires_at = expires_at + user.last_login_at = datetime.utcnow() + await db.flush() + + +async def get_user_organization_info(db: AsyncSession, user: User) -> dict: + """获取用户的组织信息""" + info = { + "brand_id": None, + "agency_id": None, + "creator_id": None, + "tenant_id": None, + "tenant_name": None, + } + + if user.role == UserRole.BRAND: + result = await db.execute( + select(Brand).where(Brand.user_id == user.id) + ) + brand = result.scalar_one_or_none() + if brand: + info["brand_id"] = brand.id + info["tenant_id"] = brand.id + info["tenant_name"] = brand.name + + elif user.role == UserRole.AGENCY: + result = await db.execute( + select(Agency).where(Agency.user_id == user.id) + ) + agency = result.scalar_one_or_none() + if agency: + info["agency_id"] = agency.id + # 代理商可能服务多个品牌,这里暂时不设置 tenant + + elif user.role == UserRole.CREATOR: + result = await db.execute( + select(Creator).where(Creator.user_id == user.id) + ) + creator = result.scalar_one_or_none() + if creator: + info["creator_id"] = creator.id + # 达人可能服务多个代理商,这里暂时不设置 tenant + + return info diff --git a/backend/app/services/oss.py b/backend/app/services/oss.py new file mode 100644 index 0000000..e07a285 --- /dev/null +++ b/backend/app/services/oss.py @@ -0,0 +1,152 @@ +""" +阿里云 OSS 服务 +""" +import time +import hmac +import base64 +import hashlib +import json +from typing import Optional +from datetime import datetime +from app.config import settings + + +def generate_upload_policy( + max_size_mb: int = 500, + expire_seconds: int = 3600, + upload_dir: Optional[str] = None, +) -> dict: + """ + 生成前端直传 OSS 所需的 Policy 和签名 + + Returns: + { + "accessKeyId": "...", + "policy": "base64 encoded policy", + "signature": "...", + "host": "https://bucket.oss-cn-hangzhou.aliyuncs.com", + "dir": "uploads/2026/02/", + "expire": 1234567890 + } + """ + if not settings.OSS_ACCESS_KEY_ID or not settings.OSS_ACCESS_KEY_SECRET: + raise ValueError("OSS 配置未设置") + + # 计算过期时间 + expire_time = int(time.time()) + expire_seconds + expire_date = datetime.utcfromtimestamp(expire_time).strftime("%Y-%m-%dT%H:%M:%SZ") + + # 默认上传目录:uploads/年/月/ + if upload_dir is None: + now = datetime.now() + upload_dir = f"uploads/{now.year}/{now.month:02d}/" + + # 构建 Policy + policy_dict = { + "expiration": expire_date, + "conditions": [ + {"bucket": settings.OSS_BUCKET_NAME}, + ["starts-with", "$key", upload_dir], + ["content-length-range", 0, max_size_mb * 1024 * 1024], + ] + } + + # Base64 编码 Policy + policy_json = json.dumps(policy_dict) + policy_base64 = base64.b64encode(policy_json.encode()).decode() + + # 计算签名 + signature = base64.b64encode( + hmac.new( + settings.OSS_ACCESS_KEY_SECRET.encode(), + policy_base64.encode(), + hashlib.sha1 + ).digest() + ).decode() + + # 构建 Host + host = settings.OSS_BUCKET_DOMAIN + if not host: + host = f"https://{settings.OSS_BUCKET_NAME}.{settings.OSS_ENDPOINT}" + + return { + "accessKeyId": settings.OSS_ACCESS_KEY_ID, + "policy": policy_base64, + "signature": signature, + "host": host, + "dir": upload_dir, + "expire": expire_time, + } + + +def generate_sts_token( + role_arn: str, + session_name: str = "miaosi-upload", + duration_seconds: int = 3600, +) -> dict: + """ + 生成 STS 临时凭证(需要配置 RAM 角色) + + 注意:此方法需要安装 aliyun-python-sdk-sts + 如果不使用 STS,可以使用上面的 generate_upload_policy 方法 + """ + # TODO: 实现 STS 临时凭证生成 + # 需要安装 aliyun-python-sdk-core 和 aliyun-python-sdk-sts + raise NotImplementedError("STS 临时凭证生成暂未实现,请使用 generate_upload_policy") + + +def get_file_url(file_key: str) -> str: + """ + 获取文件的公开访问 URL + + Args: + file_key: 文件在 OSS 中的 key,如 "uploads/2026/02/video.mp4" + + Returns: + 完整的访问 URL + """ + host = settings.OSS_BUCKET_DOMAIN + if not host: + host = f"https://{settings.OSS_BUCKET_NAME}.{settings.OSS_ENDPOINT}" + + # 确保 host 以 https:// 开头 + if not host.startswith("http"): + host = f"https://{host}" + + # 确保 host 不以 / 结尾 + host = host.rstrip("/") + + # 确保 file_key 不以 / 开头 + file_key = file_key.lstrip("/") + + return f"{host}/{file_key}" + + +def parse_file_key_from_url(url: str) -> str: + """ + 从完整 URL 解析出文件 key + + Args: + url: 完整的 OSS URL + + Returns: + 文件 key + """ + host = settings.OSS_BUCKET_DOMAIN + if not host: + host = f"https://{settings.OSS_BUCKET_NAME}.{settings.OSS_ENDPOINT}" + + # 移除 host 前缀 + if url.startswith(host): + return url[len(host):].lstrip("/") + + # 尝试其他格式 + if settings.OSS_BUCKET_NAME in url: + # 格式: https://bucket.endpoint/key + parts = url.split(settings.OSS_BUCKET_NAME + ".") + if len(parts) > 1: + key_part = parts[1].split("/", 1) + if len(key_part) > 1: + return key_part[1] + + return url diff --git a/backend/pyproject.toml b/backend/pyproject.toml index e13b027..e69440d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "asyncpg>=0.29.0", "greenlet>=3.0.0", "httpx>=0.26.0", - "pydantic>=2.5.0", + "pydantic[email]>=2.5.0", "pydantic-settings>=2.0.0", "python-jose>=3.3.0", "passlib>=1.7.4",