diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 86c228e..cbd2766 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -14,14 +14,8 @@ from alembic import context # 导入配置和模型 from app.config import settings from app.models.base import Base -from app.models import ( - Tenant, - AIConfig, - ReviewTask, - ForbiddenWord, - WhitelistItem, - Competitor, -) +# 导入所有模型,确保 autogenerate 能检测到全部表 +from app.models import * # noqa: F401,F403 # Alembic Config 对象 config = context.config diff --git a/backend/alembic/versions/004_add_audit_logs.py b/backend/alembic/versions/004_add_audit_logs.py new file mode 100644 index 0000000..978965c --- /dev/null +++ b/backend/alembic/versions/004_add_audit_logs.py @@ -0,0 +1,37 @@ +"""添加审计日志表 + +Revision ID: 004 +Revises: 003 +Create Date: 2026-02-09 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '004' +down_revision: Union[str, None] = '003' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'audit_logs', + sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True), + sa.Column('action', sa.String(50), nullable=False, index=True), + sa.Column('resource_type', sa.String(50), nullable=False, index=True), + sa.Column('resource_id', sa.String(64), nullable=True, index=True), + sa.Column('user_id', sa.String(64), nullable=True, index=True), + sa.Column('user_name', sa.String(255), nullable=True), + sa.Column('user_role', sa.String(20), nullable=True), + sa.Column('detail', sa.Text(), nullable=True), + sa.Column('ip_address', sa.String(45), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, index=True), + ) + + +def downgrade() -> None: + op.drop_table('audit_logs') diff --git a/backend/alembic/versions/005_add_messages.py b/backend/alembic/versions/005_add_messages.py new file mode 100644 index 0000000..8dd220c --- /dev/null +++ b/backend/alembic/versions/005_add_messages.py @@ -0,0 +1,42 @@ +"""添加消息表 + +Revision ID: 005 +Revises: 004 +Create Date: 2026-02-09 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '005' +down_revision: Union[str, None] = '004' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'messages', + sa.Column('id', sa.String(64), primary_key=True), + sa.Column('user_id', sa.String(64), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False), + sa.Column('type', sa.String(50), nullable=False), + sa.Column('title', sa.String(255), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('is_read', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('related_task_id', sa.String(64), nullable=True), + sa.Column('related_project_id', sa.String(64), nullable=True), + sa.Column('sender_name', sa.String(100), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + ) + op.create_index('idx_messages_user_id', 'messages', ['user_id']) + op.create_index('idx_messages_user_read', 'messages', ['user_id', 'is_read']) + + +def downgrade() -> None: + op.drop_index('idx_messages_user_read', table_name='messages') + op.drop_index('idx_messages_user_id', table_name='messages') + op.drop_table('messages') diff --git a/backend/app/api/messages.py b/backend/app/api/messages.py new file mode 100644 index 0000000..141f4a7 --- /dev/null +++ b/backend/app/api/messages.py @@ -0,0 +1,97 @@ +""" +消息/通知 API +""" +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.models.user import User +from app.api.deps import get_current_user +from app.schemas.message import MessageResponse, MessageListResponse, UnreadCountResponse +from app.services.message_service import ( + list_messages, + get_unread_count, + mark_as_read, + mark_all_as_read, +) + +router = APIRouter(prefix="/messages", tags=["消息"]) + + +@router.get("", response_model=MessageListResponse) +async def get_messages( + page: int = Query(1, ge=1), + page_size: int = Query(20, ge=1, le=100), + is_read: Optional[bool] = Query(None), + type: Optional[str] = Query(None), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """获取消息列表""" + messages, total = await list_messages( + db=db, + user_id=current_user.id, + page=page, + page_size=page_size, + is_read=is_read, + type=type, + ) + + return MessageListResponse( + items=[ + MessageResponse( + id=m.id, + type=m.type, + title=m.title, + content=m.content, + is_read=m.is_read, + related_task_id=m.related_task_id, + related_project_id=m.related_project_id, + sender_name=m.sender_name, + created_at=m.created_at, + ) + for m in messages + ], + total=total, + page=page, + page_size=page_size, + ) + + +@router.get("/unread-count", response_model=UnreadCountResponse) +async def get_message_unread_count( + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """获取未读消息数""" + count = await get_unread_count(db, current_user.id) + return UnreadCountResponse(count=count) + + +@router.put("/{message_id}/read") +async def mark_message_as_read( + message_id: str, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """标记消息已读""" + success = await mark_as_read(db, message_id, current_user.id) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="消息不存在", + ) + await db.commit() + return {"message": "已标记为已读"} + + +@router.put("/read-all") +async def mark_all_messages_as_read( + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """标记所有消息已读""" + count = await mark_all_as_read(db, current_user.id) + await db.commit() + return {"message": f"已标记 {count} 条消息为已读", "count": count} diff --git a/backend/app/api/profile.py b/backend/app/api/profile.py new file mode 100644 index 0000000..b8a596a --- /dev/null +++ b/backend/app/api/profile.py @@ -0,0 +1,173 @@ +""" +用户资料 API +""" +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from app.database import get_db +from app.models.user import User, UserRole +from app.models.organization import Brand, Agency, Creator +from app.api.deps import get_current_user +from app.services.auth import verify_password, hash_password +from app.schemas.profile import ( + ProfileResponse, + ProfileUpdateRequest, + ChangePasswordRequest, + BrandProfile, + AgencyProfile, + CreatorProfile, +) + +router = APIRouter(prefix="/profile", tags=["用户资料"]) + + +def _build_profile_response(user: User, brand=None, agency=None, creator=None) -> ProfileResponse: + """构建资料响应""" + resp = ProfileResponse( + id=user.id, + email=user.email, + phone=user.phone, + name=user.name, + avatar=user.avatar, + role=user.role.value, + is_verified=user.is_verified, + created_at=user.created_at, + ) + if brand: + resp.brand = BrandProfile( + id=brand.id, + name=brand.name, + logo=brand.logo, + description=brand.description, + contact_name=brand.contact_name, + contact_phone=brand.contact_phone, + contact_email=brand.contact_email, + ) + if agency: + resp.agency = AgencyProfile( + id=agency.id, + name=agency.name, + logo=agency.logo, + description=agency.description, + contact_name=agency.contact_name, + contact_phone=agency.contact_phone, + contact_email=agency.contact_email, + ) + if creator: + resp.creator = CreatorProfile( + id=creator.id, + name=creator.name, + avatar=creator.avatar, + bio=creator.bio, + douyin_account=creator.douyin_account, + xiaohongshu_account=creator.xiaohongshu_account, + bilibili_account=creator.bilibili_account, + ) + return resp + + +async def _get_role_entity(db: AsyncSession, user: User): + """根据角色获取对应实体""" + if user.role == UserRole.BRAND: + result = await db.execute(select(Brand).where(Brand.user_id == user.id)) + return result.scalar_one_or_none(), None, None + elif user.role == UserRole.AGENCY: + result = await db.execute(select(Agency).where(Agency.user_id == user.id)) + return None, result.scalar_one_or_none(), None + elif user.role == UserRole.CREATOR: + result = await db.execute(select(Creator).where(Creator.user_id == user.id)) + return None, None, result.scalar_one_or_none() + return None, None, None + + +@router.get("", response_model=ProfileResponse) +async def get_profile( + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """获取当前用户资料""" + brand, agency, creator = await _get_role_entity(db, current_user) + return _build_profile_response(current_user, brand, agency, creator) + + +@router.put("", response_model=ProfileResponse) +async def update_profile( + request: ProfileUpdateRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """更新当前用户资料""" + # 更新 User 表通用字段 + if request.name is not None: + current_user.name = request.name + if request.avatar is not None: + current_user.avatar = request.avatar + if request.phone is not None: + current_user.phone = request.phone + + # 更新角色表字段 + brand, agency, creator = await _get_role_entity(db, current_user) + + if current_user.role == UserRole.BRAND and brand: + if request.name is not None: + brand.name = request.name + if request.description is not None: + brand.description = request.description + if request.contact_name is not None: + brand.contact_name = request.contact_name + if request.contact_phone is not None: + brand.contact_phone = request.contact_phone + if request.contact_email is not None: + brand.contact_email = request.contact_email + + elif current_user.role == UserRole.AGENCY and agency: + if request.name is not None: + agency.name = request.name + if request.description is not None: + agency.description = request.description + if request.contact_name is not None: + agency.contact_name = request.contact_name + if request.contact_phone is not None: + agency.contact_phone = request.contact_phone + if request.contact_email is not None: + agency.contact_email = request.contact_email + + elif current_user.role == UserRole.CREATOR and creator: + if request.name is not None: + creator.name = request.name + if request.avatar is not None: + creator.avatar = request.avatar + if request.bio is not None: + creator.bio = request.bio + if request.douyin_account is not None: + creator.douyin_account = request.douyin_account + if request.xiaohongshu_account is not None: + creator.xiaohongshu_account = request.xiaohongshu_account + if request.bilibili_account is not None: + creator.bilibili_account = request.bilibili_account + + await db.commit() + + # 重新查询返回最新数据 + brand, agency, creator = await _get_role_entity(db, current_user) + return _build_profile_response(current_user, brand, agency, creator) + + +@router.put("/password") +async def change_password( + request: ChangePasswordRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """修改密码""" + if not verify_password(request.old_password, current_user.password_hash): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="原密码不正确", + ) + + current_user.password_hash = hash_password(request.new_password) + await db.commit() + + return {"message": "密码修改成功"} diff --git a/backend/app/api/tasks.py b/backend/app/api/tasks.py index 1bea47f..904e1c3 100644 --- a/backend/app/api/tasks.py +++ b/backend/app/api/tasks.py @@ -51,6 +51,8 @@ from app.services.task_service import ( list_pending_reviews_for_agency, list_pending_reviews_for_brand, ) +from app.api.sse import notify_new_task, notify_task_updated, notify_review_decision +from app.services.message_service import create_message router = APIRouter(prefix="/tasks", tags=["任务"]) @@ -172,6 +174,31 @@ async def create_new_task( # 重新加载关联 task = await get_task_by_id(db, task.id) + # 创建消息 + SSE 通知达人有新任务 + try: + await create_message( + db=db, + user_id=creator.user_id, + type="new_task", + title="新任务分配", + content=f"您有新的任务「{task.name}」,来自项目「{task.project.name}」", + related_task_id=task.id, + related_project_id=task.project.id, + sender_name=agency.name, + ) + await db.commit() + except Exception: + pass + try: + await notify_new_task( + task_id=task.id, + creator_user_id=creator.user_id, + task_name=task.name, + project_name=task.project.name, + ) + except Exception: + pass + return _task_to_response(task) @@ -367,6 +394,21 @@ async def upload_task_script( # 重新加载关联 task = await get_task_by_id(db, task.id) + # SSE 通知代理商脚本已上传 + try: + result = await db.execute( + select(Agency).where(Agency.id == task.agency_id) + ) + agency_obj = result.scalar_one_or_none() + if agency_obj: + await notify_task_updated( + task_id=task.id, + user_ids=[agency_obj.user_id], + data={"action": "script_uploaded", "stage": task.stage.value}, + ) + except Exception: + pass + return _task_to_response(task) @@ -415,6 +457,21 @@ async def upload_task_video( # 重新加载关联 task = await get_task_by_id(db, task.id) + # SSE 通知代理商视频已上传 + try: + result = await db.execute( + select(Agency).where(Agency.id == task.agency_id) + ) + agency_obj = result.scalar_one_or_none() + if agency_obj: + await notify_task_updated( + task_id=task.id, + user_ids=[agency_obj.user_id], + data={"action": "video_uploaded", "stage": task.stage.value}, + ) + except Exception: + pass + return _task_to_response(task) @@ -523,6 +580,41 @@ async def review_script( # 重新加载关联 task = await get_task_by_id(db, task.id) + # 创建消息 + SSE 通知达人脚本审核结果 + try: + result = await db.execute( + select(Creator).where(Creator.id == task.creator_id) + ) + creator_obj = result.scalar_one_or_none() + if creator_obj: + reviewer_type = "agency" if current_user.role == UserRole.AGENCY else "brand" + action_text = {"pass": "通过", "reject": "驳回", "force_pass": "强制通过"}.get(request.action, request.action) + await create_message( + db=db, + user_id=creator_obj.user_id, + type=request.action, + title=f"脚本审核{action_text}", + content=f"您的任务「{task.name}」脚本已被{action_text}" + (f",评语:{request.comment}" if request.comment else ""), + related_task_id=task.id, + sender_name=current_user.name, + ) + await db.commit() + await notify_review_decision( + task_id=task.id, + creator_user_id=creator_obj.user_id, + review_type="script", + reviewer_type=reviewer_type, + action=request.action, + comment=request.comment, + ) + await notify_task_updated( + task_id=task.id, + user_ids=[creator_obj.user_id], + data={"action": f"script_{request.action}", "stage": task.stage.value}, + ) + except Exception: + pass + return _task_to_response(task) @@ -628,6 +720,41 @@ async def review_video( # 重新加载关联 task = await get_task_by_id(db, task.id) + # 创建消息 + SSE 通知达人视频审核结果 + try: + result = await db.execute( + select(Creator).where(Creator.id == task.creator_id) + ) + creator_obj = result.scalar_one_or_none() + if creator_obj: + reviewer_type = "agency" if current_user.role == UserRole.AGENCY else "brand" + action_text = {"pass": "通过", "reject": "驳回", "force_pass": "强制通过"}.get(request.action, request.action) + await create_message( + db=db, + user_id=creator_obj.user_id, + type=request.action, + title=f"视频审核{action_text}", + content=f"您的任务「{task.name}」视频已被{action_text}" + (f",评语:{request.comment}" if request.comment else ""), + related_task_id=task.id, + sender_name=current_user.name, + ) + await db.commit() + await notify_review_decision( + task_id=task.id, + creator_user_id=creator_obj.user_id, + review_type="video", + reviewer_type=reviewer_type, + action=request.action, + comment=request.comment, + ) + await notify_task_updated( + task_id=task.id, + user_ids=[creator_obj.user_id], + data={"action": f"video_{request.action}", "stage": task.stage.value}, + ) + except Exception: + pass + return _task_to_response(task) @@ -676,6 +803,21 @@ async def submit_task_appeal( # 重新加载关联 task = await get_task_by_id(db, task.id) + # SSE 通知代理商有新申诉 + try: + result = await db.execute( + select(Agency).where(Agency.id == task.agency_id) + ) + agency_obj = result.scalar_one_or_none() + if agency_obj: + await notify_task_updated( + task_id=task.id, + user_ids=[agency_obj.user_id], + data={"action": "appeal_submitted", "stage": task.stage.value}, + ) + except Exception: + pass + return _task_to_response(task) diff --git a/backend/app/main.py b/backend/app/main.py index 8a5495b..6aa590b 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -5,7 +5,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from app.config import settings from app.logging_config import setup_logging from app.middleware.rate_limit import RateLimitMiddleware -from app.api import health, auth, upload, scripts, videos, tasks, rules, ai_config, sse, projects, briefs, organizations, dashboard, export +from app.api import health, auth, upload, scripts, videos, tasks, rules, ai_config, sse, projects, briefs, organizations, dashboard, export, profile, messages # Initialize logging logger = setup_logging() @@ -72,6 +72,8 @@ app.include_router(briefs.router, prefix="/api/v1") app.include_router(organizations.router, prefix="/api/v1") app.include_router(dashboard.router, prefix="/api/v1") app.include_router(export.router, prefix="/api/v1") +app.include_router(profile.router, prefix="/api/v1") +app.include_router(messages.router, prefix="/api/v1") @app.on_event("startup") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index c1b510b..840ac88 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -12,6 +12,7 @@ from app.models.ai_config import AIConfig from app.models.review import ReviewTask, Platform from app.models.rule import ForbiddenWord, WhitelistItem, Competitor from app.models.audit_log import AuditLog +from app.models.message import Message # 保留 Tenant 兼容旧代码,但新代码应使用 Brand from app.models.tenant import Tenant @@ -45,6 +46,8 @@ __all__ = [ "Competitor", # 审计日志 "AuditLog", + # 消息 + "Message", # 兼容 "Tenant", ] diff --git a/backend/app/models/message.py b/backend/app/models/message.py new file mode 100644 index 0000000..5be1408 --- /dev/null +++ b/backend/app/models/message.py @@ -0,0 +1,45 @@ +""" +消息/通知模型 +""" +from typing import Optional +from sqlalchemy import String, Boolean, Text, ForeignKey, Index +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base, TimestampMixin + + +class Message(Base, TimestampMixin): + """消息表""" + __tablename__ = "messages" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + + # 接收者 + user_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ) + + # 消息类型: invite, new_task, pass, reject, appeal, system 等 + type: Mapped[str] = mapped_column(String(50), nullable=False) + + # 消息内容 + title: Mapped[str] = mapped_column(String(255), nullable=False) + content: Mapped[str] = mapped_column(Text, nullable=False) + + # 已读状态 + is_read: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + + # 关联信息(可选) + related_task_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + related_project_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + sender_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + + __table_args__ = ( + Index("idx_messages_user_id", "user_id"), + Index("idx_messages_user_read", "user_id", "is_read"), + ) + + def __repr__(self) -> str: + return f"" diff --git a/backend/app/schemas/message.py b/backend/app/schemas/message.py new file mode 100644 index 0000000..4992df7 --- /dev/null +++ b/backend/app/schemas/message.py @@ -0,0 +1,29 @@ +""" +消息相关 Schema +""" +from typing import Optional, List +from datetime import datetime +from pydantic import BaseModel + + +class MessageResponse(BaseModel): + id: str + type: str + title: str + content: str + is_read: bool + related_task_id: Optional[str] = None + related_project_id: Optional[str] = None + sender_name: Optional[str] = None + created_at: Optional[datetime] = None + + +class MessageListResponse(BaseModel): + items: List[MessageResponse] + total: int + page: int + page_size: int + + +class UnreadCountResponse(BaseModel): + count: int diff --git a/backend/app/schemas/profile.py b/backend/app/schemas/profile.py new file mode 100644 index 0000000..26f9558 --- /dev/null +++ b/backend/app/schemas/profile.py @@ -0,0 +1,77 @@ +""" +用户资料相关 Schema +""" +from typing import Optional +from datetime import datetime +from pydantic import BaseModel, Field + + +# ===== 角色附加信息 ===== + +class BrandProfile(BaseModel): + id: str + name: str + logo: Optional[str] = None + description: Optional[str] = None + contact_name: Optional[str] = None + contact_phone: Optional[str] = None + contact_email: Optional[str] = None + + +class AgencyProfile(BaseModel): + id: str + name: str + logo: Optional[str] = None + description: Optional[str] = None + contact_name: Optional[str] = None + contact_phone: Optional[str] = None + contact_email: Optional[str] = None + + +class CreatorProfile(BaseModel): + id: str + name: str + avatar: Optional[str] = None + bio: Optional[str] = None + douyin_account: Optional[str] = None + xiaohongshu_account: Optional[str] = None + bilibili_account: Optional[str] = None + + +# ===== 响应 ===== + +class ProfileResponse(BaseModel): + id: str + email: Optional[str] = None + phone: Optional[str] = None + name: str + avatar: Optional[str] = None + role: str + is_verified: bool = False + created_at: Optional[datetime] = None + brand: Optional[BrandProfile] = None + agency: Optional[AgencyProfile] = None + creator: Optional[CreatorProfile] = None + + +# ===== 请求 ===== + +class ProfileUpdateRequest(BaseModel): + name: Optional[str] = Field(None, max_length=100) + avatar: Optional[str] = Field(None, max_length=2048) + phone: Optional[str] = Field(None, max_length=20) + # 品牌方/代理商字段 + description: Optional[str] = None + contact_name: Optional[str] = Field(None, max_length=100) + contact_phone: Optional[str] = Field(None, max_length=20) + contact_email: Optional[str] = Field(None, max_length=255) + # 达人字段 + bio: Optional[str] = None + douyin_account: Optional[str] = Field(None, max_length=100) + xiaohongshu_account: Optional[str] = Field(None, max_length=100) + bilibili_account: Optional[str] = Field(None, max_length=100) + + +class ChangePasswordRequest(BaseModel): + old_password: str = Field(..., min_length=6) + new_password: str = Field(..., min_length=6) diff --git a/backend/app/services/message_service.py b/backend/app/services/message_service.py new file mode 100644 index 0000000..fb5d795 --- /dev/null +++ b/backend/app/services/message_service.py @@ -0,0 +1,116 @@ +""" +消息服务 +""" +import secrets +from typing import Optional, Tuple, List + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, func, update + +from app.models.message import Message + + +def _generate_message_id() -> str: + """生成消息 ID""" + random_part = secrets.randbelow(900000) + 100000 + return f"MSG{random_part}" + + +async def create_message( + db: AsyncSession, + user_id: str, + type: str, + title: str, + content: str, + related_task_id: Optional[str] = None, + related_project_id: Optional[str] = None, + sender_name: Optional[str] = None, +) -> Message: + """创建消息""" + message = Message( + id=_generate_message_id(), + user_id=user_id, + type=type, + title=title, + content=content, + is_read=False, + related_task_id=related_task_id, + related_project_id=related_project_id, + sender_name=sender_name, + ) + db.add(message) + await db.flush() + return message + + +async def list_messages( + db: AsyncSession, + user_id: str, + page: int = 1, + page_size: int = 20, + is_read: Optional[bool] = None, + type: Optional[str] = None, +) -> Tuple[List[Message], int]: + """查询消息列表""" + query = select(Message).where(Message.user_id == user_id) + count_query = select(func.count()).select_from(Message).where(Message.user_id == user_id) + + if is_read is not None: + query = query.where(Message.is_read == is_read) + count_query = count_query.where(Message.is_read == is_read) + + if type is not None: + query = query.where(Message.type == type) + count_query = count_query.where(Message.type == type) + + # 总数 + total_result = await db.execute(count_query) + total = total_result.scalar() or 0 + + # 分页 + query = query.order_by(Message.created_at.desc()) + query = query.offset((page - 1) * page_size).limit(page_size) + + result = await db.execute(query) + messages = list(result.scalars().all()) + + return messages, total + + +async def get_unread_count(db: AsyncSession, user_id: str) -> int: + """获取未读消息数""" + result = await db.execute( + select(func.count()).select_from(Message).where( + Message.user_id == user_id, + Message.is_read == False, + ) + ) + return result.scalar() or 0 + + +async def mark_as_read(db: AsyncSession, message_id: str, user_id: str) -> bool: + """标记单条消息已读""" + result = await db.execute( + select(Message).where( + Message.id == message_id, + Message.user_id == user_id, + ) + ) + message = result.scalar_one_or_none() + if not message: + return False + + message.is_read = True + await db.flush() + return True + + +async def mark_all_as_read(db: AsyncSession, user_id: str) -> int: + """标记所有消息已读,返回更新数量""" + result = await db.execute( + update(Message) + .where(Message.user_id == user_id, Message.is_read == False) + .values(is_read=True) + ) + await db.flush() + return result.rowcount diff --git a/backend/pyproject.toml b/backend/pyproject.toml index e69440d..8a0abe0 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "cryptography>=42.0.0", "openai>=1.12.0", "cachetools>=5.3.0", + "sse-starlette>=2.0.0", ] [project.optional-dependencies]