- 后端新增: Project CRUD / Brief CRUD / 组织关系管理 / 工作台统计 / SSE 推送 / 认证依赖注入 - 后端完善: 任务 API 全流程(创建/审核/申诉) + Task Service + Task Schema - 前端修复: login 页面 localStorage key 错误 (miaosi_auth -> miaosi_user) - 前端对齐: types/task.ts 与后端 TaskStage/TaskResponse 完全对齐 - 前端新增: project/brief/organization/dashboard 类型定义 - 前端补全: api.ts 新增 30+ API 方法覆盖所有后端接口 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
236 lines
6.2 KiB
Python
236 lines
6.2 KiB
Python
"""
|
|
SSE (Server-Sent Events) 实时推送 API
|
|
用于推送审核进度等实时通知
|
|
"""
|
|
import asyncio
|
|
import json
|
|
from typing import AsyncGenerator, Optional, Set
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sse_starlette.sse import EventSourceResponse
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.database import get_db
|
|
from app.models.user import User, UserRole
|
|
from app.models.organization import Brand, Agency, Creator
|
|
from app.api.deps import get_current_user
|
|
from sqlalchemy import select
|
|
|
|
router = APIRouter(prefix="/sse", tags=["实时推送"])
|
|
|
|
# 存储活跃的客户端连接
|
|
# 结构: {user_id: set of AsyncGenerator}
|
|
active_connections: dict[str, Set[asyncio.Queue]] = {}
|
|
|
|
|
|
async def add_connection(user_id: str, queue: asyncio.Queue):
|
|
"""添加客户端连接"""
|
|
if user_id not in active_connections:
|
|
active_connections[user_id] = set()
|
|
active_connections[user_id].add(queue)
|
|
|
|
|
|
async def remove_connection(user_id: str, queue: asyncio.Queue):
|
|
"""移除客户端连接"""
|
|
if user_id in active_connections:
|
|
active_connections[user_id].discard(queue)
|
|
if not active_connections[user_id]:
|
|
del active_connections[user_id]
|
|
|
|
|
|
async def send_to_user(user_id: str, event: str, data: dict):
|
|
"""发送消息给指定用户的所有连接"""
|
|
if user_id in active_connections:
|
|
message = {
|
|
"event": event,
|
|
"data": data,
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
}
|
|
for queue in active_connections[user_id]:
|
|
await queue.put(message)
|
|
|
|
|
|
async def broadcast_to_role(role: UserRole, event: str, data: dict, db: AsyncSession):
|
|
"""广播消息给指定角色的所有用户"""
|
|
# 这里简化处理,实际应该批量查询
|
|
# 在生产环境中应该使用 Redis 等消息队列
|
|
pass
|
|
|
|
|
|
async def event_generator(user_id: str, queue: asyncio.Queue) -> AsyncGenerator[dict, None]:
|
|
"""SSE 事件生成器"""
|
|
try:
|
|
await add_connection(user_id, queue)
|
|
|
|
# 发送连接成功消息
|
|
yield {
|
|
"event": "connected",
|
|
"data": json.dumps({
|
|
"message": "连接成功",
|
|
"user_id": user_id,
|
|
}),
|
|
}
|
|
|
|
while True:
|
|
try:
|
|
# 等待消息,超时后发送心跳
|
|
message = await asyncio.wait_for(queue.get(), timeout=30.0)
|
|
yield {
|
|
"event": message["event"],
|
|
"data": json.dumps(message["data"]),
|
|
}
|
|
except asyncio.TimeoutError:
|
|
# 发送心跳保持连接
|
|
yield {
|
|
"event": "heartbeat",
|
|
"data": json.dumps({"timestamp": datetime.utcnow().isoformat()}),
|
|
}
|
|
|
|
except asyncio.CancelledError:
|
|
pass
|
|
finally:
|
|
await remove_connection(user_id, queue)
|
|
|
|
|
|
@router.get("/events")
|
|
async def sse_events(
|
|
current_user: User = Depends(get_current_user),
|
|
):
|
|
"""
|
|
SSE 事件流
|
|
|
|
- 客户端通过此端点订阅实时事件
|
|
- 支持的事件类型:
|
|
- connected: 连接成功
|
|
- heartbeat: 心跳
|
|
- task_updated: 任务状态更新
|
|
- review_progress: AI 审核进度
|
|
- review_completed: AI 审核完成
|
|
- new_task: 新任务分配
|
|
"""
|
|
queue = asyncio.Queue()
|
|
|
|
return EventSourceResponse(
|
|
event_generator(current_user.id, queue),
|
|
media_type="text/event-stream",
|
|
)
|
|
|
|
|
|
# ===== 推送工具函数(供其他模块调用) =====
|
|
|
|
|
|
async def notify_task_updated(task_id: str, user_ids: list[str], data: dict):
|
|
"""
|
|
通知任务状态更新
|
|
|
|
Args:
|
|
task_id: 任务 ID
|
|
user_ids: 需要通知的用户 ID 列表
|
|
data: 推送数据
|
|
"""
|
|
for user_id in user_ids:
|
|
await send_to_user(user_id, "task_updated", {
|
|
"task_id": task_id,
|
|
**data,
|
|
})
|
|
|
|
|
|
async def notify_review_progress(
|
|
task_id: str,
|
|
user_id: str,
|
|
progress: int,
|
|
current_step: str,
|
|
review_type: str, # "script" or "video"
|
|
):
|
|
"""
|
|
通知 AI 审核进度
|
|
|
|
Args:
|
|
task_id: 任务 ID
|
|
user_id: 达人用户 ID
|
|
progress: 进度百分比 (0-100)
|
|
current_step: 当前步骤描述
|
|
review_type: 审核类型
|
|
"""
|
|
await send_to_user(user_id, "review_progress", {
|
|
"task_id": task_id,
|
|
"review_type": review_type,
|
|
"progress": progress,
|
|
"current_step": current_step,
|
|
})
|
|
|
|
|
|
async def notify_review_completed(
|
|
task_id: str,
|
|
user_id: str,
|
|
review_type: str,
|
|
score: int,
|
|
violations_count: int,
|
|
):
|
|
"""
|
|
通知 AI 审核完成
|
|
|
|
Args:
|
|
task_id: 任务 ID
|
|
user_id: 达人用户 ID
|
|
review_type: 审核类型
|
|
score: 审核分数
|
|
violations_count: 违规数量
|
|
"""
|
|
await send_to_user(user_id, "review_completed", {
|
|
"task_id": task_id,
|
|
"review_type": review_type,
|
|
"score": score,
|
|
"violations_count": violations_count,
|
|
})
|
|
|
|
|
|
async def notify_new_task(
|
|
task_id: str,
|
|
creator_user_id: str,
|
|
task_name: str,
|
|
project_name: str,
|
|
):
|
|
"""
|
|
通知新任务分配
|
|
|
|
Args:
|
|
task_id: 任务 ID
|
|
creator_user_id: 达人用户 ID
|
|
task_name: 任务名称
|
|
project_name: 项目名称
|
|
"""
|
|
await send_to_user(creator_user_id, "new_task", {
|
|
"task_id": task_id,
|
|
"task_name": task_name,
|
|
"project_name": project_name,
|
|
})
|
|
|
|
|
|
async def notify_review_decision(
|
|
task_id: str,
|
|
creator_user_id: str,
|
|
review_type: str, # "script" or "video"
|
|
reviewer_type: str, # "agency" or "brand"
|
|
action: str, # "pass", "reject", "force_pass"
|
|
comment: Optional[str] = None,
|
|
):
|
|
"""
|
|
通知审核决策
|
|
|
|
Args:
|
|
task_id: 任务 ID
|
|
creator_user_id: 达人用户 ID
|
|
review_type: 审核类型
|
|
reviewer_type: 审核者类型
|
|
action: 审核动作
|
|
comment: 审核意见
|
|
"""
|
|
await send_to_user(creator_user_id, "review_decision", {
|
|
"task_id": task_id,
|
|
"review_type": review_type,
|
|
"reviewer_type": reviewer_type,
|
|
"action": action,
|
|
"comment": comment,
|
|
})
|