feat: 添加后端核心模块
用户认证: - User 模型(支持邮箱/手机号登录) - 双 Token JWT 认证(accessToken + refreshToken) - 注册/登录/刷新 Token API 组织模型: - Brand(品牌方)、Agency(代理商)、Creator(达人) - 多对多关系:品牌方↔代理商、代理商↔达人 项目与任务: - Project 模型(品牌方发布) - Task 模型(完整审核流程追踪) - Brief 模型(解析后的结构化内容) 文件上传: - 阿里云 OSS 直传签名服务 - 支持分片上传,最大 500MB 数据库迁移: - 003_user_org_project_task.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
4b8809efe2
commit
4caafdb50f
240
backend/alembic/versions/003_user_org_project_task.py
Normal file
240
backend/alembic/versions/003_user_org_project_task.py
Normal file
@ -0,0 +1,240 @@
|
||||
"""添加用户、组织、项目、任务表
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2026-02-09
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '003'
|
||||
down_revision: Union[str, None] = '002'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 创建枚举类型
|
||||
user_role_enum = postgresql.ENUM(
|
||||
'brand', 'agency', 'creator',
|
||||
name='user_role_enum'
|
||||
)
|
||||
user_role_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
task_stage_enum = postgresql.ENUM(
|
||||
'script_upload', 'script_ai_review', 'script_agency_review', 'script_brand_review',
|
||||
'video_upload', 'video_ai_review', 'video_agency_review', 'video_brand_review',
|
||||
'completed', 'rejected',
|
||||
name='task_stage_enum'
|
||||
)
|
||||
task_stage_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# 用户表
|
||||
op.create_table(
|
||||
'users',
|
||||
sa.Column('id', sa.String(64), primary_key=True),
|
||||
sa.Column('email', sa.String(255), unique=True, nullable=True, index=True),
|
||||
sa.Column('phone', sa.String(20), unique=True, nullable=True, index=True),
|
||||
sa.Column('password_hash', sa.String(255), nullable=False),
|
||||
sa.Column('name', sa.String(100), nullable=False),
|
||||
sa.Column('avatar', sa.String(2048), nullable=True),
|
||||
sa.Column('role', postgresql.ENUM('brand', 'agency', 'creator', name='user_role_enum', create_type=False), nullable=False, index=True),
|
||||
sa.Column('is_active', sa.Boolean(), default=True, nullable=False),
|
||||
sa.Column('is_verified', sa.Boolean(), default=False, nullable=False),
|
||||
sa.Column('last_login_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('refresh_token', sa.String(512), nullable=True),
|
||||
sa.Column('refresh_token_expires_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 品牌方表
|
||||
op.create_table(
|
||||
'brands',
|
||||
sa.Column('id', sa.String(64), primary_key=True),
|
||||
sa.Column('user_id', sa.String(64), sa.ForeignKey('users.id', ondelete='CASCADE'), unique=True, nullable=False),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('logo', sa.String(2048), nullable=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('contact_name', sa.String(100), nullable=True),
|
||||
sa.Column('contact_phone', sa.String(20), nullable=True),
|
||||
sa.Column('contact_email', sa.String(255), nullable=True),
|
||||
sa.Column('final_review_enabled', sa.Boolean(), default=True, nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), default=True, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 代理商表
|
||||
op.create_table(
|
||||
'agencies',
|
||||
sa.Column('id', sa.String(64), primary_key=True),
|
||||
sa.Column('user_id', sa.String(64), sa.ForeignKey('users.id', ondelete='CASCADE'), unique=True, nullable=False),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('logo', sa.String(2048), nullable=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('contact_name', sa.String(100), nullable=True),
|
||||
sa.Column('contact_phone', sa.String(20), nullable=True),
|
||||
sa.Column('contact_email', sa.String(255), nullable=True),
|
||||
sa.Column('force_pass_enabled', sa.Boolean(), default=True, nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), default=True, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 达人表
|
||||
op.create_table(
|
||||
'creators',
|
||||
sa.Column('id', sa.String(64), primary_key=True),
|
||||
sa.Column('user_id', sa.String(64), sa.ForeignKey('users.id', ondelete='CASCADE'), unique=True, nullable=False),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('avatar', sa.String(2048), nullable=True),
|
||||
sa.Column('bio', sa.Text(), nullable=True),
|
||||
sa.Column('douyin_account', sa.String(100), nullable=True),
|
||||
sa.Column('xiaohongshu_account', sa.String(100), nullable=True),
|
||||
sa.Column('bilibili_account', sa.String(100), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), default=True, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 品牌方-代理商关联表
|
||||
op.create_table(
|
||||
'brand_agency',
|
||||
sa.Column('brand_id', sa.String(64), sa.ForeignKey('brands.id', ondelete='CASCADE'), primary_key=True),
|
||||
sa.Column('agency_id', sa.String(64), sa.ForeignKey('agencies.id', ondelete='CASCADE'), primary_key=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column('is_active', sa.Boolean(), default=True),
|
||||
)
|
||||
|
||||
# 代理商-达人关联表
|
||||
op.create_table(
|
||||
'agency_creator',
|
||||
sa.Column('agency_id', sa.String(64), sa.ForeignKey('agencies.id', ondelete='CASCADE'), primary_key=True),
|
||||
sa.Column('creator_id', sa.String(64), sa.ForeignKey('creators.id', ondelete='CASCADE'), primary_key=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column('is_active', sa.Boolean(), default=True),
|
||||
)
|
||||
|
||||
# 项目表
|
||||
op.create_table(
|
||||
'projects',
|
||||
sa.Column('id', sa.String(64), primary_key=True),
|
||||
sa.Column('brand_id', sa.String(64), sa.ForeignKey('brands.id', ondelete='CASCADE'), nullable=False, index=True),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('start_date', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('deadline', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('status', sa.String(20), default='active', nullable=False, index=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 项目-代理商关联表
|
||||
op.create_table(
|
||||
'project_agency',
|
||||
sa.Column('project_id', sa.String(64), sa.ForeignKey('projects.id', ondelete='CASCADE'), primary_key=True),
|
||||
sa.Column('agency_id', sa.String(64), sa.ForeignKey('agencies.id', ondelete='CASCADE'), primary_key=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column('is_active', sa.Boolean(), default=True),
|
||||
)
|
||||
|
||||
# Brief 表
|
||||
op.create_table(
|
||||
'briefs',
|
||||
sa.Column('id', sa.String(64), primary_key=True),
|
||||
sa.Column('project_id', sa.String(64), sa.ForeignKey('projects.id', ondelete='CASCADE'), unique=True, nullable=False, index=True),
|
||||
sa.Column('file_url', sa.String(2048), nullable=True),
|
||||
sa.Column('file_name', sa.String(255), nullable=True),
|
||||
sa.Column('selling_points', postgresql.JSON(), nullable=True),
|
||||
sa.Column('blacklist_words', postgresql.JSON(), nullable=True),
|
||||
sa.Column('competitors', postgresql.JSON(), nullable=True),
|
||||
sa.Column('brand_tone', sa.Text(), nullable=True),
|
||||
sa.Column('min_duration', sa.Integer(), nullable=True),
|
||||
sa.Column('max_duration', sa.Integer(), nullable=True),
|
||||
sa.Column('other_requirements', sa.Text(), nullable=True),
|
||||
sa.Column('attachments', postgresql.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# 任务表
|
||||
op.create_table(
|
||||
'tasks',
|
||||
sa.Column('id', sa.String(64), primary_key=True),
|
||||
sa.Column('project_id', sa.String(64), sa.ForeignKey('projects.id', ondelete='CASCADE'), nullable=False, index=True),
|
||||
sa.Column('agency_id', sa.String(64), sa.ForeignKey('agencies.id', ondelete='CASCADE'), nullable=False, index=True),
|
||||
sa.Column('creator_id', sa.String(64), sa.ForeignKey('creators.id', ondelete='CASCADE'), nullable=False, index=True),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('sequence', sa.Integer(), default=1, nullable=False),
|
||||
sa.Column('stage', postgresql.ENUM(
|
||||
'script_upload', 'script_ai_review', 'script_agency_review', 'script_brand_review',
|
||||
'video_upload', 'video_ai_review', 'video_agency_review', 'video_brand_review',
|
||||
'completed', 'rejected',
|
||||
name='task_stage_enum', create_type=False
|
||||
), default='script_upload', nullable=False, index=True),
|
||||
|
||||
# 脚本相关
|
||||
sa.Column('script_file_url', sa.String(2048), nullable=True),
|
||||
sa.Column('script_file_name', sa.String(255), nullable=True),
|
||||
sa.Column('script_uploaded_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('script_ai_score', sa.Integer(), nullable=True),
|
||||
sa.Column('script_ai_result', postgresql.JSON(), nullable=True),
|
||||
sa.Column('script_ai_reviewed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('script_agency_status', postgresql.ENUM('pending', 'processing', 'passed', 'rejected', 'force_passed', name='task_status_enum', create_type=False), nullable=True),
|
||||
sa.Column('script_agency_comment', sa.Text(), nullable=True),
|
||||
sa.Column('script_agency_reviewer_id', sa.String(64), nullable=True),
|
||||
sa.Column('script_agency_reviewed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('script_brand_status', postgresql.ENUM('pending', 'processing', 'passed', 'rejected', 'force_passed', name='task_status_enum', create_type=False), nullable=True),
|
||||
sa.Column('script_brand_comment', sa.Text(), nullable=True),
|
||||
sa.Column('script_brand_reviewer_id', sa.String(64), nullable=True),
|
||||
sa.Column('script_brand_reviewed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
|
||||
# 视频相关
|
||||
sa.Column('video_file_url', sa.String(2048), nullable=True),
|
||||
sa.Column('video_file_name', sa.String(255), nullable=True),
|
||||
sa.Column('video_duration', sa.Integer(), nullable=True),
|
||||
sa.Column('video_thumbnail_url', sa.String(2048), nullable=True),
|
||||
sa.Column('video_uploaded_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('video_ai_score', sa.Integer(), nullable=True),
|
||||
sa.Column('video_ai_result', postgresql.JSON(), nullable=True),
|
||||
sa.Column('video_ai_reviewed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('video_agency_status', postgresql.ENUM('pending', 'processing', 'passed', 'rejected', 'force_passed', name='task_status_enum', create_type=False), nullable=True),
|
||||
sa.Column('video_agency_comment', sa.Text(), nullable=True),
|
||||
sa.Column('video_agency_reviewer_id', sa.String(64), nullable=True),
|
||||
sa.Column('video_agency_reviewed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('video_brand_status', postgresql.ENUM('pending', 'processing', 'passed', 'rejected', 'force_passed', name='task_status_enum', create_type=False), nullable=True),
|
||||
sa.Column('video_brand_comment', sa.Text(), nullable=True),
|
||||
sa.Column('video_brand_reviewer_id', sa.String(64), nullable=True),
|
||||
sa.Column('video_brand_reviewed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
|
||||
# 申诉相关
|
||||
sa.Column('appeal_count', sa.Integer(), default=1, nullable=False),
|
||||
sa.Column('is_appeal', sa.Boolean(), default=False, nullable=False),
|
||||
sa.Column('appeal_reason', sa.Text(), nullable=True),
|
||||
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('tasks')
|
||||
op.drop_table('briefs')
|
||||
op.drop_table('project_agency')
|
||||
op.drop_table('projects')
|
||||
op.drop_table('agency_creator')
|
||||
op.drop_table('brand_agency')
|
||||
op.drop_table('creators')
|
||||
op.drop_table('agencies')
|
||||
op.drop_table('brands')
|
||||
op.drop_table('users')
|
||||
|
||||
# 删除枚举类型
|
||||
op.execute("DROP TYPE IF EXISTS task_stage_enum")
|
||||
op.execute("DROP TYPE IF EXISTS user_role_enum")
|
||||
246
backend/app/api/auth.py
Normal file
246
backend/app/api/auth.py
Normal file
@ -0,0 +1,246 @@
|
||||
"""
|
||||
认证 API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.schemas.auth import (
|
||||
RegisterRequest,
|
||||
LoginRequest,
|
||||
LoginResponse,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
UserResponse,
|
||||
)
|
||||
from app.services.auth import (
|
||||
get_user_by_email,
|
||||
get_user_by_phone,
|
||||
get_user_by_id,
|
||||
create_user,
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
update_refresh_token,
|
||||
decode_token,
|
||||
get_user_organization_info,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=LoginResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
request: RegisterRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
用户注册
|
||||
|
||||
- 支持邮箱或手机号注册(至少提供一个)
|
||||
- 注册后自动登录,返回 Token
|
||||
"""
|
||||
# 验证至少提供邮箱或手机号
|
||||
if not request.email and not request.phone:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="请提供邮箱或手机号",
|
||||
)
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if request.email:
|
||||
existing = await get_user_by_email(db, request.email)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="该邮箱已被注册",
|
||||
)
|
||||
|
||||
# 检查手机号是否已存在
|
||||
if request.phone:
|
||||
existing = await get_user_by_phone(db, request.phone)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="该手机号已被注册",
|
||||
)
|
||||
|
||||
# 创建用户
|
||||
user = await create_user(
|
||||
db=db,
|
||||
email=request.email,
|
||||
phone=request.phone,
|
||||
password=request.password,
|
||||
name=request.name,
|
||||
role=request.role,
|
||||
)
|
||||
|
||||
# 生成 Token
|
||||
access_token = create_access_token(user.id)
|
||||
refresh_token, refresh_expires_at = create_refresh_token(user.id)
|
||||
|
||||
# 保存 refresh token
|
||||
await update_refresh_token(db, user, refresh_token, refresh_expires_at)
|
||||
await db.commit()
|
||||
|
||||
# 获取组织信息
|
||||
org_info = await get_user_organization_info(db, user)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
phone=user.phone,
|
||||
name=user.name,
|
||||
avatar=user.avatar,
|
||||
role=user.role,
|
||||
is_verified=user.is_verified,
|
||||
**org_info,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(
|
||||
request: LoginRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
用户登录
|
||||
|
||||
- 支持邮箱+密码 或 手机号+密码 登录
|
||||
- 返回 accessToken 和 refreshToken
|
||||
"""
|
||||
# 验证请求参数
|
||||
if not request.email and not request.phone:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="请提供邮箱或手机号",
|
||||
)
|
||||
|
||||
if not request.password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="请提供密码",
|
||||
)
|
||||
|
||||
# 验证用户
|
||||
user = await authenticate_user(
|
||||
db=db,
|
||||
email=request.email,
|
||||
phone=request.phone,
|
||||
password=request.password,
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="邮箱/手机号或密码错误",
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="账号已被禁用",
|
||||
)
|
||||
|
||||
# 生成 Token
|
||||
access_token = create_access_token(user.id)
|
||||
refresh_token, refresh_expires_at = create_refresh_token(user.id)
|
||||
|
||||
# 保存 refresh token
|
||||
await update_refresh_token(db, user, refresh_token, refresh_expires_at)
|
||||
await db.commit()
|
||||
|
||||
# 获取组织信息
|
||||
org_info = await get_user_organization_info(db, user)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
phone=user.phone,
|
||||
name=user.name,
|
||||
avatar=user.avatar,
|
||||
role=user.role,
|
||||
is_verified=user.is_verified,
|
||||
**org_info,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
async def refresh_token(
|
||||
request: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
刷新 Access Token
|
||||
|
||||
- 使用 refreshToken 获取新的 accessToken
|
||||
- refreshToken 有效期 7 天
|
||||
"""
|
||||
# 解码 refresh token
|
||||
payload = decode_token(request.refresh_token)
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的 refresh token",
|
||||
)
|
||||
|
||||
if payload.get("type") != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的 token 类型",
|
||||
)
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的 token",
|
||||
)
|
||||
|
||||
# 获取用户
|
||||
user = await get_user_by_id(db, user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在",
|
||||
)
|
||||
|
||||
# 验证 refresh token 是否匹配
|
||||
if user.refresh_token != request.refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="refresh token 已失效",
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="账号已被禁用",
|
||||
)
|
||||
|
||||
# 生成新的 access token
|
||||
access_token = create_access_token(user.id)
|
||||
|
||||
return RefreshTokenResponse(access_token=access_token)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
# TODO: 添加认证依赖
|
||||
):
|
||||
"""
|
||||
退出登录
|
||||
|
||||
- 清除 refresh token
|
||||
"""
|
||||
# TODO: 实现退出登录
|
||||
return {"message": "已退出登录"}
|
||||
117
backend/app/api/upload.py
Normal file
117
backend/app/api/upload.py
Normal file
@ -0,0 +1,117 @@
|
||||
"""
|
||||
文件上传 API
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.oss import generate_upload_policy, get_file_url
|
||||
from app.config import settings
|
||||
|
||||
router = APIRouter(prefix="/upload", tags=["文件上传"])
|
||||
|
||||
|
||||
class UploadPolicyRequest(BaseModel):
|
||||
"""获取上传凭证请求"""
|
||||
file_type: str = "general" # script, video, image, general
|
||||
file_name: Optional[str] = None
|
||||
|
||||
|
||||
class UploadPolicyResponse(BaseModel):
|
||||
"""上传凭证响应"""
|
||||
access_key_id: str
|
||||
policy: str
|
||||
signature: str
|
||||
host: str
|
||||
dir: str
|
||||
expire: int
|
||||
max_size_mb: int
|
||||
|
||||
|
||||
class FileUploadedRequest(BaseModel):
|
||||
"""文件上传完成回调"""
|
||||
file_key: str
|
||||
file_name: str
|
||||
file_size: int
|
||||
file_type: str
|
||||
|
||||
|
||||
class FileUploadedResponse(BaseModel):
|
||||
"""文件上传完成响应"""
|
||||
url: str
|
||||
file_key: str
|
||||
file_name: str
|
||||
file_size: int
|
||||
file_type: str
|
||||
|
||||
|
||||
@router.post("/policy", response_model=UploadPolicyResponse)
|
||||
async def get_upload_policy(
|
||||
request: UploadPolicyRequest,
|
||||
):
|
||||
"""
|
||||
获取 OSS 直传凭证
|
||||
|
||||
前端使用此凭证直接上传文件到阿里云 OSS,无需经过后端。
|
||||
|
||||
文件类型说明:
|
||||
- script: 脚本文档 (docx, pdf, xlsx, txt, pptx)
|
||||
- video: 视频文件 (mp4, mov, webm)
|
||||
- image: 图片文件 (jpg, png, gif)
|
||||
- general: 通用文件
|
||||
"""
|
||||
# 根据文件类型设置上传目录
|
||||
now = datetime.now()
|
||||
base_dir = f"uploads/{now.year}/{now.month:02d}"
|
||||
|
||||
if request.file_type == "script":
|
||||
upload_dir = f"{base_dir}/scripts/"
|
||||
elif request.file_type == "video":
|
||||
upload_dir = f"{base_dir}/videos/"
|
||||
elif request.file_type == "image":
|
||||
upload_dir = f"{base_dir}/images/"
|
||||
else:
|
||||
upload_dir = f"{base_dir}/files/"
|
||||
|
||||
try:
|
||||
policy = generate_upload_policy(
|
||||
max_size_mb=settings.MAX_FILE_SIZE_MB,
|
||||
expire_seconds=3600,
|
||||
upload_dir=upload_dir,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
return UploadPolicyResponse(
|
||||
access_key_id=policy["accessKeyId"],
|
||||
policy=policy["policy"],
|
||||
signature=policy["signature"],
|
||||
host=policy["host"],
|
||||
dir=policy["dir"],
|
||||
expire=policy["expire"],
|
||||
max_size_mb=settings.MAX_FILE_SIZE_MB,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/complete", response_model=FileUploadedResponse)
|
||||
async def file_uploaded(
|
||||
request: FileUploadedRequest,
|
||||
):
|
||||
"""
|
||||
文件上传完成回调
|
||||
|
||||
前端上传完成后调用此接口,获取文件的完整 URL。
|
||||
"""
|
||||
url = get_file_url(request.file_key)
|
||||
|
||||
return FileUploadedResponse(
|
||||
url=url,
|
||||
file_key=request.file_key,
|
||||
file_name=request.file_name,
|
||||
file_size=request.file_size,
|
||||
file_type=request.file_type,
|
||||
)
|
||||
@ -27,6 +27,16 @@ class Settings(BaseSettings):
|
||||
AI_API_KEY: str = "" # 中转服务商的 API Key
|
||||
AI_API_BASE_URL: str = "" # 中转服务商的 Base URL,如 https://api.oneinall.ai/v1
|
||||
|
||||
# 阿里云 OSS 配置
|
||||
OSS_ACCESS_KEY_ID: str = ""
|
||||
OSS_ACCESS_KEY_SECRET: str = ""
|
||||
OSS_ENDPOINT: str = "oss-cn-hangzhou.aliyuncs.com"
|
||||
OSS_BUCKET_NAME: str = "miaosi-files"
|
||||
OSS_BUCKET_DOMAIN: str = "" # 公开访问域名,如 https://miaosi-files.oss-cn-hangzhou.aliyuncs.com
|
||||
|
||||
# 文件上传限制
|
||||
MAX_FILE_SIZE_MB: int = 500 # 最大文件大小 500MB
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
@ -7,14 +7,30 @@ from app.config import settings
|
||||
# 导入所有模型,确保在创建表时被注册
|
||||
from app.models.base import Base
|
||||
from app.models import (
|
||||
Tenant,
|
||||
# 用户与组织
|
||||
User,
|
||||
UserRole,
|
||||
Brand,
|
||||
Agency,
|
||||
Creator,
|
||||
# 项目与任务
|
||||
Project,
|
||||
Task,
|
||||
TaskStage,
|
||||
TaskStatus,
|
||||
Brief,
|
||||
# AI 配置
|
||||
AIConfig,
|
||||
# 审核(旧模型)
|
||||
ReviewTask,
|
||||
ManualTask,
|
||||
# 规则
|
||||
ForbiddenWord,
|
||||
WhitelistItem,
|
||||
Competitor,
|
||||
RiskException,
|
||||
# 兼容
|
||||
Tenant,
|
||||
)
|
||||
|
||||
# 创建异步引擎
|
||||
@ -65,12 +81,28 @@ __all__ = [
|
||||
"get_db",
|
||||
"init_db",
|
||||
"drop_db",
|
||||
"Tenant",
|
||||
# 用户与组织
|
||||
"User",
|
||||
"UserRole",
|
||||
"Brand",
|
||||
"Agency",
|
||||
"Creator",
|
||||
# 项目与任务
|
||||
"Project",
|
||||
"Task",
|
||||
"TaskStage",
|
||||
"TaskStatus",
|
||||
"Brief",
|
||||
# AI 配置
|
||||
"AIConfig",
|
||||
# 审核
|
||||
"ReviewTask",
|
||||
"ManualTask",
|
||||
# 规则
|
||||
"ForbiddenWord",
|
||||
"WhitelistItem",
|
||||
"Competitor",
|
||||
"RiskException",
|
||||
# 兼容
|
||||
"Tenant",
|
||||
]
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.config import settings
|
||||
from app.api import health, scripts, videos, tasks, rules, ai_config, risk_exceptions, metrics
|
||||
from app.api import health, auth, upload, scripts, videos, tasks, rules, ai_config, risk_exceptions, metrics
|
||||
|
||||
# 创建应用
|
||||
app = FastAPI(
|
||||
@ -24,6 +24,8 @@ app.add_middleware(
|
||||
|
||||
# 注册路由
|
||||
app.include_router(health.router, prefix="/api/v1")
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(upload.router, prefix="/api/v1")
|
||||
app.include_router(scripts.router, prefix="/api/v1")
|
||||
app.include_router(videos.router, prefix="/api/v1")
|
||||
app.include_router(tasks.router, prefix="/api/v1")
|
||||
|
||||
@ -3,21 +3,49 @@
|
||||
导出所有 ORM 模型
|
||||
"""
|
||||
from app.models.base import Base, TimestampMixin
|
||||
from app.models.tenant import Tenant
|
||||
from app.models.user import User, UserRole
|
||||
from app.models.organization import Brand, Agency, Creator, brand_agency_association, agency_creator_association
|
||||
from app.models.project import Project, project_agency_association
|
||||
from app.models.task import Task, TaskStage, TaskStatus
|
||||
from app.models.brief import Brief
|
||||
from app.models.ai_config import AIConfig
|
||||
from app.models.review import ReviewTask, ManualTask
|
||||
from app.models.review import ReviewTask, ManualTask, Platform
|
||||
from app.models.rule import ForbiddenWord, WhitelistItem, Competitor
|
||||
from app.models.risk_exception import RiskException
|
||||
|
||||
# 保留 Tenant 兼容旧代码,但新代码应使用 Brand
|
||||
from app.models.tenant import Tenant
|
||||
|
||||
__all__ = [
|
||||
# Base
|
||||
"Base",
|
||||
"TimestampMixin",
|
||||
"Tenant",
|
||||
# 用户与组织
|
||||
"User",
|
||||
"UserRole",
|
||||
"Brand",
|
||||
"Agency",
|
||||
"Creator",
|
||||
"brand_agency_association",
|
||||
"agency_creator_association",
|
||||
# 项目与任务
|
||||
"Project",
|
||||
"project_agency_association",
|
||||
"Task",
|
||||
"TaskStage",
|
||||
"TaskStatus",
|
||||
"Brief",
|
||||
# AI 配置
|
||||
"AIConfig",
|
||||
# 审核(旧模型,保留兼容)
|
||||
"ReviewTask",
|
||||
"ManualTask",
|
||||
"Platform",
|
||||
# 规则
|
||||
"ForbiddenWord",
|
||||
"WhitelistItem",
|
||||
"Competitor",
|
||||
"RiskException",
|
||||
# 兼容
|
||||
"Tenant",
|
||||
]
|
||||
|
||||
60
backend/app/models/brief.py
Normal file
60
backend/app/models/brief.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""
|
||||
Brief 模型
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from sqlalchemy import String, Text, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin
|
||||
from app.models.types import JSONType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.project import Project
|
||||
|
||||
|
||||
class Brief(Base, TimestampMixin):
|
||||
"""Brief 文档表"""
|
||||
__tablename__ = "briefs"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# 原始文件
|
||||
file_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
file_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
# 解析后的结构化内容
|
||||
# 卖点要求: [{"content": "SPF50+", "required": true}, ...]
|
||||
selling_points: Mapped[Optional[list]] = mapped_column(JSONType, nullable=True)
|
||||
|
||||
# 违禁词: [{"word": "最好", "reason": "绝对化用语"}, ...]
|
||||
blacklist_words: Mapped[Optional[list]] = mapped_column(JSONType, nullable=True)
|
||||
|
||||
# 竞品: ["竞品A", "竞品B", ...]
|
||||
competitors: Mapped[Optional[list]] = mapped_column(JSONType, nullable=True)
|
||||
|
||||
# 品牌调性要求
|
||||
brand_tone: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
# 时长要求(秒)
|
||||
min_duration: Mapped[Optional[int]] = mapped_column(nullable=True)
|
||||
max_duration: Mapped[Optional[int]] = mapped_column(nullable=True)
|
||||
|
||||
# 其他要求(自由文本)
|
||||
other_requirements: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
# 附件文档(代理商上传的参考资料)
|
||||
# [{"id": "af1", "name": "达人拍摄指南.pdf", "url": "...", "size": "1.5MB"}, ...]
|
||||
attachments: Mapped[Optional[list]] = mapped_column(JSONType, nullable=True)
|
||||
|
||||
# 关联
|
||||
project: Mapped["Project"] = relationship("Project", back_populates="brief")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Brief(id={self.id}, project_id={self.project_id})>"
|
||||
160
backend/app/models/organization.py
Normal file
160
backend/app/models/organization.py
Normal file
@ -0,0 +1,160 @@
|
||||
"""
|
||||
组织模型:品牌方、代理商、达人
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Boolean, Text, ForeignKey, DateTime, Table, Column
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.project import Project
|
||||
|
||||
|
||||
# 品牌方-代理商 关联表(多对多)
|
||||
brand_agency_association = Table(
|
||||
"brand_agency",
|
||||
Base.metadata,
|
||||
Column("brand_id", String(64), ForeignKey("brands.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column("agency_id", String(64), ForeignKey("agencies.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column("created_at", DateTime(timezone=True), default=datetime.utcnow),
|
||||
Column("is_active", Boolean, default=True),
|
||||
)
|
||||
|
||||
# 代理商-达人 关联表(多对多)
|
||||
agency_creator_association = Table(
|
||||
"agency_creator",
|
||||
Base.metadata,
|
||||
Column("agency_id", String(64), ForeignKey("agencies.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column("creator_id", String(64), ForeignKey("creators.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column("created_at", DateTime(timezone=True), default=datetime.utcnow),
|
||||
Column("is_active", Boolean, default=True),
|
||||
)
|
||||
|
||||
|
||||
class Brand(Base, TimestampMixin):
|
||||
"""品牌方表(即租户)"""
|
||||
__tablename__ = "brands"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True) # 格式: BR123456
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# 品牌信息
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
logo: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
# 联系信息
|
||||
contact_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
contact_phone: Mapped[Optional[str]] = mapped_column(String(20), nullable=True)
|
||||
contact_email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
# 设置
|
||||
final_review_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # 终审开关
|
||||
|
||||
# 状态
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
# 关联
|
||||
user: Mapped["User"] = relationship("User", back_populates="brand")
|
||||
agencies: Mapped[list["Agency"]] = relationship(
|
||||
"Agency",
|
||||
secondary=brand_agency_association,
|
||||
back_populates="brands",
|
||||
)
|
||||
projects: Mapped[list["Project"]] = relationship(
|
||||
"Project",
|
||||
back_populates="brand",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Brand(id={self.id}, name={self.name})>"
|
||||
|
||||
|
||||
class Agency(Base, TimestampMixin):
|
||||
"""代理商表"""
|
||||
__tablename__ = "agencies"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True) # 格式: AG123456
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# 代理商信息
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
logo: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
# 联系信息
|
||||
contact_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
contact_phone: Mapped[Optional[str]] = mapped_column(String(20), nullable=True)
|
||||
contact_email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
# 权限设置(可被品牌方覆盖)
|
||||
force_pass_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # 强制通过权
|
||||
|
||||
# 状态
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
# 关联
|
||||
user: Mapped["User"] = relationship("User", back_populates="agency")
|
||||
brands: Mapped[list["Brand"]] = relationship(
|
||||
"Brand",
|
||||
secondary=brand_agency_association,
|
||||
back_populates="agencies",
|
||||
)
|
||||
creators: Mapped[list["Creator"]] = relationship(
|
||||
"Creator",
|
||||
secondary=agency_creator_association,
|
||||
back_populates="agencies",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Agency(id={self.id}, name={self.name})>"
|
||||
|
||||
|
||||
class Creator(Base, TimestampMixin):
|
||||
"""达人表"""
|
||||
__tablename__ = "creators"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True) # 格式: CR123456
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# 达人信息
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
avatar: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
bio: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
# 社交账号
|
||||
douyin_account: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
xiaohongshu_account: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
bilibili_account: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
|
||||
# 状态
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
# 关联
|
||||
user: Mapped["User"] = relationship("User", back_populates="creator")
|
||||
agencies: Mapped[list["Agency"]] = relationship(
|
||||
"Agency",
|
||||
secondary=agency_creator_association,
|
||||
back_populates="creators",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Creator(id={self.id}, name={self.name})>"
|
||||
74
backend/app/models/project.py
Normal file
74
backend/app/models/project.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""
|
||||
项目模型
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Text, ForeignKey, DateTime, Table, Column, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.organization import Brand, Agency
|
||||
from app.models.task import Task
|
||||
from app.models.brief import Brief
|
||||
|
||||
|
||||
# 项目-代理商 关联表(一个项目可以分配给多个代理商)
|
||||
project_agency_association = Table(
|
||||
"project_agency",
|
||||
Base.metadata,
|
||||
Column("project_id", String(64), ForeignKey("projects.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column("agency_id", String(64), ForeignKey("agencies.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column("created_at", DateTime(timezone=True), default=datetime.utcnow),
|
||||
Column("is_active", Boolean, default=True),
|
||||
)
|
||||
|
||||
|
||||
class Project(Base, TimestampMixin):
|
||||
"""项目表"""
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
brand_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# 项目信息
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
# 时间
|
||||
start_date: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
deadline: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# 状态
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
default="active", # active, completed, archived
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# 关联
|
||||
brand: Mapped["Brand"] = relationship("Brand", back_populates="projects")
|
||||
agencies: Mapped[list["Agency"]] = relationship(
|
||||
"Agency",
|
||||
secondary=project_agency_association,
|
||||
backref="projects",
|
||||
)
|
||||
tasks: Mapped[list["Task"]] = relationship(
|
||||
"Task",
|
||||
back_populates="project",
|
||||
)
|
||||
brief: Mapped[Optional["Brief"]] = relationship(
|
||||
"Brief",
|
||||
back_populates="project",
|
||||
uselist=False,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Project(id={self.id}, name={self.name})>"
|
||||
148
backend/app/models/task.py
Normal file
148
backend/app/models/task.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""
|
||||
任务模型
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Integer, Text, ForeignKey, DateTime, Enum as SQLEnum, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
import enum
|
||||
|
||||
from app.models.base import Base, TimestampMixin
|
||||
from app.models.types import JSONType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.project import Project
|
||||
from app.models.organization import Agency, Creator
|
||||
|
||||
|
||||
class TaskStage(str, enum.Enum):
|
||||
"""任务阶段"""
|
||||
SCRIPT_UPLOAD = "script_upload" # 待上传脚本
|
||||
SCRIPT_AI_REVIEW = "script_ai_review" # 脚本 AI 审核中
|
||||
SCRIPT_AGENCY_REVIEW = "script_agency_review" # 脚本代理商审核中
|
||||
SCRIPT_BRAND_REVIEW = "script_brand_review" # 脚本品牌方终审中
|
||||
VIDEO_UPLOAD = "video_upload" # 待上传视频
|
||||
VIDEO_AI_REVIEW = "video_ai_review" # 视频 AI 审核中
|
||||
VIDEO_AGENCY_REVIEW = "video_agency_review" # 视频代理商审核中
|
||||
VIDEO_BRAND_REVIEW = "video_brand_review" # 视频品牌方终审中
|
||||
COMPLETED = "completed" # 已完成
|
||||
REJECTED = "rejected" # 已驳回
|
||||
|
||||
|
||||
class TaskStatus(str, enum.Enum):
|
||||
"""任务状态"""
|
||||
PENDING = "pending" # 待处理
|
||||
PROCESSING = "processing" # 处理中
|
||||
PASSED = "passed" # 通过
|
||||
REJECTED = "rejected" # 驳回
|
||||
FORCE_PASSED = "force_passed" # 强制通过
|
||||
|
||||
|
||||
class Task(Base, TimestampMixin):
|
||||
"""任务表"""
|
||||
__tablename__ = "tasks"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
|
||||
# 关联
|
||||
project_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
agency_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("agencies.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
creator_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("creators.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# 任务信息
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False) # 如 "宣传任务(1)"
|
||||
sequence: Mapped[int] = mapped_column(Integer, default=1, nullable=False) # 序号
|
||||
|
||||
# 当前阶段
|
||||
stage: Mapped[TaskStage] = mapped_column(
|
||||
SQLEnum(TaskStage, name="task_stage_enum"),
|
||||
default=TaskStage.SCRIPT_UPLOAD,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# ===== 脚本相关 =====
|
||||
script_file_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
script_file_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
script_uploaded_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# 脚本 AI 审核结果
|
||||
script_ai_score: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
script_ai_result: Mapped[Optional[dict]] = mapped_column(JSONType, nullable=True)
|
||||
script_ai_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# 脚本代理商审核
|
||||
script_agency_status: Mapped[Optional[TaskStatus]] = mapped_column(
|
||||
SQLEnum(TaskStatus, name="task_status_enum"),
|
||||
nullable=True,
|
||||
)
|
||||
script_agency_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
script_agency_reviewer_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
script_agency_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# 脚本品牌方终审
|
||||
script_brand_status: Mapped[Optional[TaskStatus]] = mapped_column(
|
||||
SQLEnum(TaskStatus, name="task_status_enum", create_type=False),
|
||||
nullable=True,
|
||||
)
|
||||
script_brand_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
script_brand_reviewer_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
script_brand_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# ===== 视频相关 =====
|
||||
video_file_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
video_file_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
video_duration: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) # 秒
|
||||
video_thumbnail_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
video_uploaded_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# 视频 AI 审核结果
|
||||
video_ai_score: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
video_ai_result: Mapped[Optional[dict]] = mapped_column(JSONType, nullable=True)
|
||||
video_ai_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# 视频代理商审核
|
||||
video_agency_status: Mapped[Optional[TaskStatus]] = mapped_column(
|
||||
SQLEnum(TaskStatus, name="task_status_enum", create_type=False),
|
||||
nullable=True,
|
||||
)
|
||||
video_agency_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
video_agency_reviewer_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
video_agency_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# 视频品牌方终审
|
||||
video_brand_status: Mapped[Optional[TaskStatus]] = mapped_column(
|
||||
SQLEnum(TaskStatus, name="task_status_enum", create_type=False),
|
||||
nullable=True,
|
||||
)
|
||||
video_brand_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
video_brand_reviewer_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
video_brand_reviewed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# ===== 申诉相关 =====
|
||||
appeal_count: Mapped[int] = mapped_column(Integer, default=1, nullable=False) # 剩余申诉次数
|
||||
is_appeal: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # 是否为申诉
|
||||
appeal_reason: Mapped[Optional[str]] = mapped_column(Text, nullable=True) # 申诉理由
|
||||
|
||||
# 关联
|
||||
project: Mapped["Project"] = relationship("Project", back_populates="tasks")
|
||||
agency: Mapped["Agency"] = relationship("Agency", foreign_keys=[agency_id])
|
||||
creator: Mapped["Creator"] = relationship("Creator", foreign_keys=[creator_id])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Task(id={self.id}, name={self.name}, stage={self.stage})>"
|
||||
80
backend/app/models/user.py
Normal file
80
backend/app/models/user.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
用户模型
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Boolean, DateTime, Enum as SQLEnum
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
import enum
|
||||
|
||||
from app.models.base import Base, TimestampMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.organization import Brand, Agency, Creator
|
||||
|
||||
|
||||
class UserRole(str, enum.Enum):
|
||||
"""用户角色"""
|
||||
BRAND = "brand" # 品牌方
|
||||
AGENCY = "agency" # 代理商
|
||||
CREATOR = "creator" # 达人
|
||||
|
||||
|
||||
class User(Base, TimestampMixin):
|
||||
"""用户表"""
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
|
||||
# 登录凭证(邮箱和手机号都可以登录)
|
||||
email: Mapped[Optional[str]] = mapped_column(String(255), unique=True, nullable=True, index=True)
|
||||
phone: Mapped[Optional[str]] = mapped_column(String(20), unique=True, nullable=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
# 用户信息
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
avatar: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
|
||||
# 角色
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
SQLEnum(UserRole, name="user_role_enum"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# 状态
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
is_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# 最后登录
|
||||
last_login_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Refresh Token(用于 JWT 刷新)
|
||||
refresh_token: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
|
||||
refresh_token_expires_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# 关联的组织(根据角色不同,关联到不同的组织)
|
||||
brand: Mapped[Optional["Brand"]] = relationship(
|
||||
"Brand",
|
||||
back_populates="user",
|
||||
uselist=False,
|
||||
)
|
||||
agency: Mapped[Optional["Agency"]] = relationship(
|
||||
"Agency",
|
||||
back_populates="user",
|
||||
uselist=False,
|
||||
)
|
||||
creator: Mapped[Optional["Creator"]] = relationship(
|
||||
"Creator",
|
||||
back_populates="user",
|
||||
uselist=False,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<User(id={self.id}, email={self.email}, role={self.role})>"
|
||||
120
backend/app/schemas/auth.py
Normal file
120
backend/app/schemas/auth.py
Normal file
@ -0,0 +1,120 @@
|
||||
"""
|
||||
认证相关 Schema
|
||||
"""
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from app.models.user import UserRole
|
||||
|
||||
|
||||
# ===== 请求 =====
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""注册请求"""
|
||||
email: Optional[EmailStr] = None
|
||||
phone: Optional[str] = Field(None, pattern=r"^1[3-9]\d{9}$")
|
||||
password: str = Field(..., min_length=6, max_length=128)
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
role: UserRole
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"email": "user@example.com",
|
||||
"password": "password123",
|
||||
"name": "张三",
|
||||
"role": "creator"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求(邮箱或手机号)"""
|
||||
email: Optional[EmailStr] = None
|
||||
phone: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
sms_code: Optional[str] = None # 短信验证码
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"email": "user@example.com",
|
||||
"password": "password123"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""刷新 Token 请求"""
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class SendSmsCodeRequest(BaseModel):
|
||||
"""发送短信验证码请求"""
|
||||
phone: str = Field(..., pattern=r"^1[3-9]\d{9}$")
|
||||
|
||||
|
||||
class BindPhoneRequest(BaseModel):
|
||||
"""绑定手机号请求"""
|
||||
phone: str = Field(..., pattern=r"^1[3-9]\d{9}$")
|
||||
sms_code: str
|
||||
|
||||
|
||||
class BindEmailRequest(BaseModel):
|
||||
"""绑定邮箱请求"""
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=6, max_length=128)
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
"""修改密码请求"""
|
||||
old_password: str
|
||||
new_password: str = Field(..., min_length=6, max_length=128)
|
||||
|
||||
|
||||
# ===== 响应 =====
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""用户信息响应"""
|
||||
id: str
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
name: str
|
||||
avatar: Optional[str] = None
|
||||
role: UserRole
|
||||
is_verified: bool
|
||||
|
||||
# 根据角色返回对应的组织 ID
|
||||
brand_id: Optional[str] = None
|
||||
agency_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
|
||||
# 当前所属租户(品牌方)- 用于数据隔离
|
||||
tenant_id: Optional[str] = None
|
||||
tenant_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Token 响应"""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = 900 # 15 分钟 = 900 秒
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""登录响应"""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = 900
|
||||
user: UserResponse
|
||||
|
||||
|
||||
class RefreshTokenResponse(BaseModel):
|
||||
"""刷新 Token 响应"""
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = 900
|
||||
215
backend/app/services/auth.py
Normal file
215
backend/app/services/auth.py
Normal file
@ -0,0 +1,215 @@
|
||||
"""
|
||||
认证服务
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
import secrets
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import settings
|
||||
from app.models.user import User, UserRole
|
||||
from app.models.organization import Brand, Agency, Creator
|
||||
|
||||
# 密码加密上下文
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""哈希密码"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def generate_id(prefix: str) -> str:
|
||||
"""生成语义化 ID"""
|
||||
# 格式: BR123456, AG123456, CR123456
|
||||
random_part = secrets.randbelow(900000) + 100000 # 100000-999999
|
||||
return f"{prefix}{random_part}"
|
||||
|
||||
|
||||
def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建访问 Token"""
|
||||
if expires_delta is None:
|
||||
expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode = {
|
||||
"sub": user_id,
|
||||
"exp": expire,
|
||||
"type": "access",
|
||||
}
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: str, expires_days: int = 7) -> tuple[str, datetime]:
|
||||
"""创建刷新 Token"""
|
||||
expire = datetime.utcnow() + timedelta(days=expires_days)
|
||||
to_encode = {
|
||||
"sub": user_id,
|
||||
"exp": expire,
|
||||
"type": "refresh",
|
||||
}
|
||||
token = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return token, expire
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
|
||||
"""解码 Token"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
|
||||
"""通过邮箱获取用户"""
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_user_by_phone(db: AsyncSession, phone: str) -> Optional[User]:
|
||||
"""通过手机号获取用户"""
|
||||
result = await db.execute(
|
||||
select(User).where(User.phone == phone)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_user_by_id(db: AsyncSession, user_id: str) -> Optional[User]:
|
||||
"""通过 ID 获取用户"""
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def create_user(
|
||||
db: AsyncSession,
|
||||
email: Optional[str],
|
||||
phone: Optional[str],
|
||||
password: str,
|
||||
name: str,
|
||||
role: UserRole,
|
||||
) -> User:
|
||||
"""创建用户"""
|
||||
user_id = generate_id("U")
|
||||
|
||||
user = User(
|
||||
id=user_id,
|
||||
email=email,
|
||||
phone=phone,
|
||||
password_hash=hash_password(password),
|
||||
name=name,
|
||||
role=role,
|
||||
is_active=True,
|
||||
is_verified=False,
|
||||
)
|
||||
db.add(user)
|
||||
|
||||
# 根据角色创建对应的组织实体
|
||||
if role == UserRole.BRAND:
|
||||
brand = Brand(
|
||||
id=generate_id("BR"),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
)
|
||||
db.add(brand)
|
||||
elif role == UserRole.AGENCY:
|
||||
agency = Agency(
|
||||
id=generate_id("AG"),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
)
|
||||
db.add(agency)
|
||||
elif role == UserRole.CREATOR:
|
||||
creator = Creator(
|
||||
id=generate_id("CR"),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
)
|
||||
db.add(creator)
|
||||
|
||||
await db.flush()
|
||||
return user
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
db: AsyncSession,
|
||||
email: Optional[str] = None,
|
||||
phone: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
) -> Optional[User]:
|
||||
"""验证用户登录"""
|
||||
user = None
|
||||
|
||||
if email:
|
||||
user = await get_user_by_email(db, email)
|
||||
elif phone:
|
||||
user = await get_user_by_phone(db, phone)
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if password and not verify_password(password, user.password_hash):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def update_refresh_token(db: AsyncSession, user: User, refresh_token: str, expires_at: datetime) -> None:
|
||||
"""更新用户的刷新 Token"""
|
||||
user.refresh_token = refresh_token
|
||||
user.refresh_token_expires_at = expires_at
|
||||
user.last_login_at = datetime.utcnow()
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def get_user_organization_info(db: AsyncSession, user: User) -> dict:
|
||||
"""获取用户的组织信息"""
|
||||
info = {
|
||||
"brand_id": None,
|
||||
"agency_id": None,
|
||||
"creator_id": None,
|
||||
"tenant_id": None,
|
||||
"tenant_name": None,
|
||||
}
|
||||
|
||||
if user.role == UserRole.BRAND:
|
||||
result = await db.execute(
|
||||
select(Brand).where(Brand.user_id == user.id)
|
||||
)
|
||||
brand = result.scalar_one_or_none()
|
||||
if brand:
|
||||
info["brand_id"] = brand.id
|
||||
info["tenant_id"] = brand.id
|
||||
info["tenant_name"] = brand.name
|
||||
|
||||
elif user.role == UserRole.AGENCY:
|
||||
result = await db.execute(
|
||||
select(Agency).where(Agency.user_id == user.id)
|
||||
)
|
||||
agency = result.scalar_one_or_none()
|
||||
if agency:
|
||||
info["agency_id"] = agency.id
|
||||
# 代理商可能服务多个品牌,这里暂时不设置 tenant
|
||||
|
||||
elif user.role == UserRole.CREATOR:
|
||||
result = await db.execute(
|
||||
select(Creator).where(Creator.user_id == user.id)
|
||||
)
|
||||
creator = result.scalar_one_or_none()
|
||||
if creator:
|
||||
info["creator_id"] = creator.id
|
||||
# 达人可能服务多个代理商,这里暂时不设置 tenant
|
||||
|
||||
return info
|
||||
152
backend/app/services/oss.py
Normal file
152
backend/app/services/oss.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""
|
||||
阿里云 OSS 服务
|
||||
"""
|
||||
import time
|
||||
import hmac
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def generate_upload_policy(
|
||||
max_size_mb: int = 500,
|
||||
expire_seconds: int = 3600,
|
||||
upload_dir: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
生成前端直传 OSS 所需的 Policy 和签名
|
||||
|
||||
Returns:
|
||||
{
|
||||
"accessKeyId": "...",
|
||||
"policy": "base64 encoded policy",
|
||||
"signature": "...",
|
||||
"host": "https://bucket.oss-cn-hangzhou.aliyuncs.com",
|
||||
"dir": "uploads/2026/02/",
|
||||
"expire": 1234567890
|
||||
}
|
||||
"""
|
||||
if not settings.OSS_ACCESS_KEY_ID or not settings.OSS_ACCESS_KEY_SECRET:
|
||||
raise ValueError("OSS 配置未设置")
|
||||
|
||||
# 计算过期时间
|
||||
expire_time = int(time.time()) + expire_seconds
|
||||
expire_date = datetime.utcfromtimestamp(expire_time).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
# 默认上传目录:uploads/年/月/
|
||||
if upload_dir is None:
|
||||
now = datetime.now()
|
||||
upload_dir = f"uploads/{now.year}/{now.month:02d}/"
|
||||
|
||||
# 构建 Policy
|
||||
policy_dict = {
|
||||
"expiration": expire_date,
|
||||
"conditions": [
|
||||
{"bucket": settings.OSS_BUCKET_NAME},
|
||||
["starts-with", "$key", upload_dir],
|
||||
["content-length-range", 0, max_size_mb * 1024 * 1024],
|
||||
]
|
||||
}
|
||||
|
||||
# Base64 编码 Policy
|
||||
policy_json = json.dumps(policy_dict)
|
||||
policy_base64 = base64.b64encode(policy_json.encode()).decode()
|
||||
|
||||
# 计算签名
|
||||
signature = base64.b64encode(
|
||||
hmac.new(
|
||||
settings.OSS_ACCESS_KEY_SECRET.encode(),
|
||||
policy_base64.encode(),
|
||||
hashlib.sha1
|
||||
).digest()
|
||||
).decode()
|
||||
|
||||
# 构建 Host
|
||||
host = settings.OSS_BUCKET_DOMAIN
|
||||
if not host:
|
||||
host = f"https://{settings.OSS_BUCKET_NAME}.{settings.OSS_ENDPOINT}"
|
||||
|
||||
return {
|
||||
"accessKeyId": settings.OSS_ACCESS_KEY_ID,
|
||||
"policy": policy_base64,
|
||||
"signature": signature,
|
||||
"host": host,
|
||||
"dir": upload_dir,
|
||||
"expire": expire_time,
|
||||
}
|
||||
|
||||
|
||||
def generate_sts_token(
|
||||
role_arn: str,
|
||||
session_name: str = "miaosi-upload",
|
||||
duration_seconds: int = 3600,
|
||||
) -> dict:
|
||||
"""
|
||||
生成 STS 临时凭证(需要配置 RAM 角色)
|
||||
|
||||
注意:此方法需要安装 aliyun-python-sdk-sts
|
||||
如果不使用 STS,可以使用上面的 generate_upload_policy 方法
|
||||
"""
|
||||
# TODO: 实现 STS 临时凭证生成
|
||||
# 需要安装 aliyun-python-sdk-core 和 aliyun-python-sdk-sts
|
||||
raise NotImplementedError("STS 临时凭证生成暂未实现,请使用 generate_upload_policy")
|
||||
|
||||
|
||||
def get_file_url(file_key: str) -> str:
|
||||
"""
|
||||
获取文件的公开访问 URL
|
||||
|
||||
Args:
|
||||
file_key: 文件在 OSS 中的 key,如 "uploads/2026/02/video.mp4"
|
||||
|
||||
Returns:
|
||||
完整的访问 URL
|
||||
"""
|
||||
host = settings.OSS_BUCKET_DOMAIN
|
||||
if not host:
|
||||
host = f"https://{settings.OSS_BUCKET_NAME}.{settings.OSS_ENDPOINT}"
|
||||
|
||||
# 确保 host 以 https:// 开头
|
||||
if not host.startswith("http"):
|
||||
host = f"https://{host}"
|
||||
|
||||
# 确保 host 不以 / 结尾
|
||||
host = host.rstrip("/")
|
||||
|
||||
# 确保 file_key 不以 / 开头
|
||||
file_key = file_key.lstrip("/")
|
||||
|
||||
return f"{host}/{file_key}"
|
||||
|
||||
|
||||
def parse_file_key_from_url(url: str) -> str:
|
||||
"""
|
||||
从完整 URL 解析出文件 key
|
||||
|
||||
Args:
|
||||
url: 完整的 OSS URL
|
||||
|
||||
Returns:
|
||||
文件 key
|
||||
"""
|
||||
host = settings.OSS_BUCKET_DOMAIN
|
||||
if not host:
|
||||
host = f"https://{settings.OSS_BUCKET_NAME}.{settings.OSS_ENDPOINT}"
|
||||
|
||||
# 移除 host 前缀
|
||||
if url.startswith(host):
|
||||
return url[len(host):].lstrip("/")
|
||||
|
||||
# 尝试其他格式
|
||||
if settings.OSS_BUCKET_NAME in url:
|
||||
# 格式: https://bucket.endpoint/key
|
||||
parts = url.split(settings.OSS_BUCKET_NAME + ".")
|
||||
if len(parts) > 1:
|
||||
key_part = parts[1].split("/", 1)
|
||||
if len(key_part) > 1:
|
||||
return key_part[1]
|
||||
|
||||
return url
|
||||
@ -12,7 +12,7 @@ dependencies = [
|
||||
"asyncpg>=0.29.0",
|
||||
"greenlet>=3.0.0",
|
||||
"httpx>=0.26.0",
|
||||
"pydantic>=2.5.0",
|
||||
"pydantic[email]>=2.5.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"python-jose>=3.3.0",
|
||||
"passlib>=1.7.4",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user