"""
AI service configuration API.

Lets a brand (tenant) manage its AI provider configuration, choose models,
and run connectivity tests against the provider.
"""
import asyncio
import logging
from datetime import datetime, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Header, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models.ai_config import AIConfig
from app.models.tenant import Tenant
from app.schemas.ai_config import (
    AIProvider,
    AIConfigUpdate,
    AIConfigResponse,
    AIModelsConfig,
    AIParametersConfig,
    GetModelsRequest,
    TestConnectionRequest,
    ModelsListResponse,
    ConnectionTestResponse,
    ModelTestResult,
    ModelInfo,
    ModelCapability,
    mask_api_key,
)
from app.services.ai_client import OpenAICompatibleClient
from app.services.ai_service import AIServiceFactory
from app.utils.crypto import encrypt_api_key, decrypt_api_key

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/ai-config", tags=["ai-config"])

# Maps keys of the request's ``models`` payload to the capability the
# connectivity test should exercise; unknown keys fall back to TEXT.
_CAPABILITY_MAP = {
    "text": ModelCapability.TEXT,
    "vision": ModelCapability.VISION,
    "audio": ModelCapability.AUDIO,
}


async def _ensure_tenant_exists(tenant_id: str, db: AsyncSession) -> Tenant:
    """Return the tenant with ``tenant_id``, creating a placeholder if missing.

    A newly created tenant is added and flushed to the session; committing is
    not done here.
    """
    result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    tenant = result.scalar_one_or_none()
    if tenant is None:
        tenant = Tenant(id=tenant_id, name=f"租户-{tenant_id}")
        db.add(tenant)
        await db.flush()
    return tenant


async def _load_config(
    tenant_id: str,
    db: AsyncSession,
    *,
    configured_only: bool = False,
) -> Optional[AIConfig]:
    """Fetch the tenant's AIConfig row, or None if there is none.

    Args:
        tenant_id: Tenant whose configuration to load.
        db: Async database session.
        configured_only: When True, only return a row whose
            ``is_configured`` flag is true.
    """
    query = select(AIConfig).where(AIConfig.tenant_id == tenant_id)
    if configured_only:
        # `== True` builds a SQL expression here, so it cannot be `is True`.
        query = query.where(AIConfig.is_configured == True)  # noqa: E712
    result = await db.execute(query)
    return result.scalar_one_or_none()


def _build_config_response(
    config: AIConfig,
    plain_api_key: str,
    *,
    is_configured: bool,
) -> AIConfigResponse:
    """Serialize an AIConfig row into the API response.

    Args:
        config: The stored configuration row.
        plain_api_key: The decrypted API key; only its masked form is returned.
        is_configured: Value to report in the response's ``is_configured``.
    """
    return AIConfigResponse(
        provider=config.provider,
        base_url=config.base_url,
        api_key_masked=mask_api_key(plain_api_key),
        models=AIModelsConfig(**config.models),
        parameters=AIParametersConfig(
            temperature=config.temperature,
            max_tokens=config.max_tokens,
        ),
        available_models=config.available_models or {},
        is_configured=is_configured,
        last_test_at=config.last_test_at.isoformat() if config.last_test_at else None,
        last_test_result=config.last_test_result,
    )


async def _close_client_quietly(client: Optional[OpenAICompatibleClient]) -> None:
    """Close ``client`` if it was created; log instead of raising on failure.

    Closing is best-effort cleanup and must not mask the request's outcome.
    """
    if client is None:
        return
    try:
        await client.close()
    except Exception:
        logger.warning("Failed to close AI client", exc_info=True)


@router.get("", response_model=AIConfigResponse)
async def get_ai_config(
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> AIConfigResponse:
    """
    Get the current AI configuration.

    - Responds 404 when the tenant has no configured AI service.
    - Otherwise returns the configuration with the API key masked.
    """
    config = await _load_config(x_tenant_id, db, configured_only=True)
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="AI 服务未配置,请先完成配置",
        )

    # Decrypt only so the key can be masked for display.
    api_key = decrypt_api_key(config.api_key_encrypted)
    return _build_config_response(config, api_key, is_configured=config.is_configured)


@router.put("", response_model=AIConfigResponse)
async def update_ai_config(
    request: AIConfigUpdate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> AIConfigResponse:
    """
    Create or update the AI configuration.

    - Saves provider, connection details, model selection and parameters.
    - The API key is stored encrypted; the response shows it masked.
    """
    await _ensure_tenant_exists(x_tenant_id, db)

    api_key_encrypted = encrypt_api_key(request.api_key)

    config = await AIServiceFactory.create_or_update_config(
        tenant_id=x_tenant_id,
        provider=request.provider.value,
        base_url=request.base_url,
        api_key_encrypted=api_key_encrypted,
        models=request.models.model_dump(),
        temperature=request.parameters.temperature,
        max_tokens=request.parameters.max_tokens,
        db=db,
    )

    return _build_config_response(config, request.api_key, is_configured=True)


@router.post("/models", response_model=ModelsListResponse)
async def get_available_models(
    request: GetModelsRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ModelsListResponse:
    """
    List the models available from the provider.

    - Calls the provider API with the credentials from the request.
    - Returns models grouped under the keys produced by the client's
      ``list_models`` (documented as capability: text/vision/audio).
    - Caches the raw list on the tenant's existing config, if any.

    Raises:
        HTTPException(502): the provider call failed or its response could
            not be parsed into ``ModelInfo`` objects.
    """
    client = None
    try:
        client = OpenAICompatibleClient(
            base_url=request.base_url,
            api_key=request.api_key,
            provider=request.provider.value,
        )
        models_dict = await client.list_models()
        models = {
            capability: [ModelInfo(**m) for m in entries]
            for capability, entries in models_dict.items()
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"获取模型列表失败: {str(e)}",
        ) from e
    finally:
        # Runs on both success and failure so the client is never leaked.
        await _close_client_quietly(client)

    # Cache outside the provider try-block: DB errors are not gateway errors.
    config = await _load_config(x_tenant_id, db)
    if config is not None:
        config.available_models = models_dict
        await db.flush()

    return ModelsListResponse(
        success=True,
        models=models,
    )


@router.post("/test", response_model=ConnectionTestResponse)
async def test_connection(
    request: TestConnectionRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ConnectionTestResponse:
    """
    Test connectivity for every model in the request.

    - All models are tested concurrently.
    - Always returns a response body (no error status); failures are
      reported per model.
    - The results are stored on the tenant's existing config, if any.
    """
    client = None
    models = request.models.model_dump()
    try:
        client = OpenAICompatibleClient(
            base_url=request.base_url,
            api_key=request.api_key,
            provider=request.provider.value,
        )

        async def test_single(model_type: str, model_id: str) -> tuple[str, ModelTestResult]:
            capability = _CAPABILITY_MAP.get(model_type, ModelCapability.TEXT)
            try:
                result = await client.test_connection(model_id, capability)
            except Exception as exc:
                # Isolate per-model failures: otherwise gather() would raise
                # the first error, discard the other models' results, and
                # leave sibling tasks running after the client is closed.
                return model_type, ModelTestResult(
                    success=False,
                    latency_ms=0,
                    error=str(exc),
                    model=model_id,
                )
            return model_type, ModelTestResult(
                success=result.success,
                latency_ms=result.latency_ms,
                error=result.error,
                model=model_id,
            )

        results = dict(
            await asyncio.gather(
                *(test_single(model_type, model_id) for model_type, model_id in models.items())
            )
        )

        failed_count = sum(1 for r in results.values() if not r.success)
        all_success = failed_count == 0
        if all_success:
            message = "所有模型连接成功"
        else:
            message = f"{failed_count} 个模型连接失败,请检查模型名称或 API 权限"

        response = ConnectionTestResponse(
            success=all_success,
            results=results,
            message=message,
        )
    except Exception as exc:
        # Still return a normal response, marking every model as failed.
        results = {
            model_type: ModelTestResult(
                success=False,
                latency_ms=0,
                error=str(exc),
                model=model_id,
            )
            for model_type, model_id in models.items()
        }
        response = ConnectionTestResponse(
            success=False,
            results=results,
            message=f"连接测试失败: {str(exc)}",
        )
    finally:
        await _close_client_quietly(client)

    # Persist the outcome on the tenant's existing config, if there is one.
    config = await _load_config(x_tenant_id, db)
    if config is not None:
        config.last_test_at = datetime.now(timezone.utc)
        config.last_test_result = {
            k: v.model_dump() for k, v in response.results.items()
        }
        await db.flush()

    return response


# ==================== For use by other modules ====================

async def get_ai_config_for_tenant(
    tenant_id: str,
    db: AsyncSession,
) -> Optional[dict]:
    """Return the tenant's configured AI settings with the API key decrypted.

    Intended for internal callers such as the review service. Returns None
    when the tenant has no configured AI service.
    """
    config = await _load_config(tenant_id, db, configured_only=True)
    if config is None:
        return None

    return {
        "tenant_id": config.tenant_id,
        "provider": config.provider,
        "base_url": config.base_url,
        "api_key": decrypt_api_key(config.api_key_encrypted),
        "models": config.models,
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
    }