主要更新: - 更新代理商端文档,明确项目由品牌方分配流程 - 新增Brief配置详情页(已配置)设计稿 - 完善工作台紧急待办中品牌新任务功能 - 整理Pencil设计文件中代理商端页面顺序 - 新增后端FastAPI框架及核心API - 新增前端Next.js页面和组件库 - 添加.gitignore排除构建和缓存文件 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
183 lines
4.6 KiB
Python
183 lines
4.6 KiB
Python
"""
|
||
AI 服务工厂
|
||
根据租户配置创建和管理 AI 客户端
|
||
"""
|
||
from typing import Optional
|
||
from cachetools import TTLCache
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.ai_config import AIConfig
|
||
from app.services.ai_client import OpenAICompatibleClient
|
||
from app.utils.crypto import decrypt_api_key
|
||
|
||
|
||
class AIServiceFactory:
|
||
"""
|
||
AI 服务工厂
|
||
|
||
根据租户的 AI 配置创建对应的 AI 客户端
|
||
使用 TTL 缓存避免频繁创建客户端
|
||
"""
|
||
|
||
# 客户端缓存,TTL 10 分钟
|
||
_cache: TTLCache = TTLCache(maxsize=100, ttl=600)
|
||
|
||
@classmethod
|
||
async def get_client(
|
||
cls,
|
||
tenant_id: str,
|
||
db: AsyncSession,
|
||
) -> Optional[OpenAICompatibleClient]:
|
||
"""
|
||
获取租户的 AI 客户端
|
||
|
||
Args:
|
||
tenant_id: 租户 ID
|
||
db: 数据库会话
|
||
|
||
Returns:
|
||
AI 客户端实例,未配置返回 None
|
||
"""
|
||
# 检查缓存
|
||
cache_key = f"ai_client:{tenant_id}"
|
||
if cache_key in cls._cache:
|
||
return cls._cache[cache_key]
|
||
|
||
# 从数据库获取配置
|
||
result = await db.execute(
|
||
select(AIConfig).where(
|
||
AIConfig.tenant_id == tenant_id,
|
||
AIConfig.is_configured == True,
|
||
)
|
||
)
|
||
config = result.scalar_one_or_none()
|
||
|
||
if not config:
|
||
return None
|
||
|
||
# 解密 API Key
|
||
api_key = decrypt_api_key(config.api_key_encrypted)
|
||
|
||
# 创建客户端
|
||
client = OpenAICompatibleClient(
|
||
base_url=config.base_url,
|
||
api_key=api_key,
|
||
provider=config.provider,
|
||
)
|
||
|
||
# 缓存客户端
|
||
cls._cache[cache_key] = client
|
||
|
||
return client
|
||
|
||
@classmethod
|
||
def invalidate_cache(cls, tenant_id: str) -> None:
|
||
"""
|
||
使缓存失效
|
||
|
||
当租户更新 AI 配置时调用
|
||
"""
|
||
cache_key = f"ai_client:{tenant_id}"
|
||
if cache_key in cls._cache:
|
||
del cls._cache[cache_key]
|
||
|
||
@classmethod
|
||
def clear_cache(cls) -> None:
|
||
"""清空所有缓存"""
|
||
cls._cache.clear()
|
||
|
||
@classmethod
|
||
async def get_config(
|
||
cls,
|
||
tenant_id: str,
|
||
db: AsyncSession,
|
||
) -> Optional[AIConfig]:
|
||
"""
|
||
获取租户的 AI 配置
|
||
|
||
Args:
|
||
tenant_id: 租户 ID
|
||
db: 数据库会话
|
||
|
||
Returns:
|
||
AI 配置模型,未配置返回 None
|
||
"""
|
||
result = await db.execute(
|
||
select(AIConfig).where(AIConfig.tenant_id == tenant_id)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
@classmethod
|
||
async def create_or_update_config(
|
||
cls,
|
||
tenant_id: str,
|
||
provider: str,
|
||
base_url: str,
|
||
api_key_encrypted: str,
|
||
models: dict,
|
||
temperature: float,
|
||
max_tokens: int,
|
||
db: AsyncSession,
|
||
) -> AIConfig:
|
||
"""
|
||
创建或更新 AI 配置
|
||
|
||
Args:
|
||
tenant_id: 租户 ID
|
||
provider: 提供商
|
||
base_url: API 地址
|
||
api_key_encrypted: 加密的 API Key
|
||
models: 模型配置
|
||
temperature: 温度参数
|
||
max_tokens: 最大 token 数
|
||
db: 数据库会话
|
||
|
||
Returns:
|
||
更新后的配置
|
||
"""
|
||
# 查找现有配置
|
||
result = await db.execute(
|
||
select(AIConfig).where(AIConfig.tenant_id == tenant_id)
|
||
)
|
||
config = result.scalar_one_or_none()
|
||
|
||
if config:
|
||
# 更新现有配置
|
||
config.provider = provider
|
||
config.base_url = base_url
|
||
config.api_key_encrypted = api_key_encrypted
|
||
config.models = models
|
||
config.temperature = temperature
|
||
config.max_tokens = max_tokens
|
||
config.is_configured = True
|
||
else:
|
||
# 创建新配置
|
||
config = AIConfig(
|
||
tenant_id=tenant_id,
|
||
provider=provider,
|
||
base_url=base_url,
|
||
api_key_encrypted=api_key_encrypted,
|
||
models=models,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
is_configured=True,
|
||
)
|
||
db.add(config)
|
||
|
||
await db.flush()
|
||
|
||
# 使缓存失效
|
||
cls.invalidate_cache(tenant_id)
|
||
|
||
return config
|
||
|
||
|
||
# 便捷函数
|
||
async def get_ai_client_for_tenant(
|
||
tenant_id: str,
|
||
db: AsyncSession,
|
||
) -> Optional[OpenAICompatibleClient]:
|
||
"""获取租户的 AI 客户端"""
|
||
return await AIServiceFactory.get_client(tenant_id, db)
|