""" AI 服务工厂 根据租户配置创建和管理 AI 客户端 """ from typing import Optional from cachetools import TTLCache from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models.ai_config import AIConfig from app.services.ai_client import OpenAICompatibleClient from app.utils.crypto import decrypt_api_key class AIServiceFactory: """ AI 服务工厂 根据租户的 AI 配置创建对应的 AI 客户端 使用 TTL 缓存避免频繁创建客户端 """ # 客户端缓存,TTL 10 分钟 _cache: TTLCache = TTLCache(maxsize=100, ttl=600) @classmethod async def get_client( cls, tenant_id: str, db: AsyncSession, ) -> Optional[OpenAICompatibleClient]: """ 获取租户的 AI 客户端 Args: tenant_id: 租户 ID db: 数据库会话 Returns: AI 客户端实例,未配置返回 None """ # 检查缓存 cache_key = f"ai_client:{tenant_id}" if cache_key in cls._cache: return cls._cache[cache_key] # 从数据库获取配置 result = await db.execute( select(AIConfig).where( AIConfig.tenant_id == tenant_id, AIConfig.is_configured == True, ) ) config = result.scalar_one_or_none() if not config: return None # 解密 API Key api_key = decrypt_api_key(config.api_key_encrypted) # 创建客户端 client = OpenAICompatibleClient( base_url=config.base_url, api_key=api_key, provider=config.provider, ) # 缓存客户端 cls._cache[cache_key] = client return client @classmethod def invalidate_cache(cls, tenant_id: str) -> None: """ 使缓存失效 当租户更新 AI 配置时调用 """ cache_key = f"ai_client:{tenant_id}" if cache_key in cls._cache: del cls._cache[cache_key] @classmethod def clear_cache(cls) -> None: """清空所有缓存""" cls._cache.clear() @classmethod async def get_config( cls, tenant_id: str, db: AsyncSession, ) -> Optional[AIConfig]: """ 获取租户的 AI 配置 Args: tenant_id: 租户 ID db: 数据库会话 Returns: AI 配置模型,未配置返回 None """ result = await db.execute( select(AIConfig).where(AIConfig.tenant_id == tenant_id) ) return result.scalar_one_or_none() @classmethod async def create_or_update_config( cls, tenant_id: str, provider: str, base_url: str, api_key_encrypted: str, models: dict, temperature: float, max_tokens: int, db: AsyncSession, ) -> AIConfig: """ 创建或更新 AI 配置 Args: tenant_id: 租户 ID provider: 提供商 base_url: API 地址 api_key_encrypted: 加密的 API Key models: 模型配置 temperature: 温度参数 max_tokens: 最大 token 数 db: 数据库会话 Returns: 更新后的配置 """ # 查找现有配置 result = await db.execute( select(AIConfig).where(AIConfig.tenant_id == tenant_id) ) config = result.scalar_one_or_none() if config: # 更新现有配置 config.provider = provider config.base_url = base_url config.api_key_encrypted = api_key_encrypted config.models = models config.temperature = temperature config.max_tokens = max_tokens config.is_configured = True else: # 创建新配置 config = AIConfig( tenant_id=tenant_id, provider=provider, base_url=base_url, api_key_encrypted=api_key_encrypted, models=models, temperature=temperature, max_tokens=max_tokens, is_configured=True, ) db.add(config) await db.flush() # 使缓存失效 cls.invalidate_cache(tenant_id) return config # 便捷函数 async def get_ai_client_for_tenant( tenant_id: str, db: AsyncSession, ) -> Optional[OpenAICompatibleClient]: """获取租户的 AI 客户端""" return await AIServiceFactory.get_client(tenant_id, db)