主要更新: - 更新代理商端文档,明确项目由品牌方分配流程 - 新增Brief配置详情页(已配置)设计稿 - 完善工作台紧急待办中品牌新任务功能 - 整理Pencil设计文件中代理商端页面顺序 - 新增后端FastAPI框架及核心API - 新增前端Next.js页面和组件库 - 添加.gitignore排除构建和缓存文件 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
315 lines
9.3 KiB
Python
315 lines
9.3 KiB
Python
"""
|
||
AI 服务配置 API
|
||
品牌方管理 AI 提供商配置、模型选择、连通性测试
|
||
"""
|
||
import asyncio
import logging
from datetime import datetime, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Header, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models.ai_config import AIConfig
from app.models.tenant import Tenant
from app.schemas.ai_config import (
    AIProvider,
    AIConfigUpdate,
    AIConfigResponse,
    AIModelsConfig,
    AIParametersConfig,
    GetModelsRequest,
    TestConnectionRequest,
    ModelsListResponse,
    ConnectionTestResponse,
    ModelTestResult,
    ModelInfo,
    ModelCapability,
    mask_api_key,
)
from app.services.ai_client import OpenAICompatibleClient
from app.services.ai_service import AIServiceFactory
from app.utils.crypto import encrypt_api_key, decrypt_api_key
|
||
|
||
# Router for the AI-service configuration endpoints defined in this module.
# Every route is registered under the "/ai-config" prefix and grouped under
# the "ai-config" tag (e.g. in the generated OpenAPI docs).
router = APIRouter(prefix="/ai-config", tags=["ai-config"])
|
||
|
||
|
||
async def _ensure_tenant_exists(tenant_id: str, db: AsyncSession) -> Tenant:
    """Return the tenant with ``tenant_id``, creating it if it does not exist.

    A newly created tenant gets the placeholder name ``租户-<tenant_id>``.
    It is added to the session and flushed, so the INSERT is issued within
    the current transaction. Committing is left to whoever owns the session.
    """
    existing = (
        await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    ).scalar_one_or_none()
    if existing:
        return existing

    created = Tenant(id=tenant_id, name=f"租户-{tenant_id}")
    db.add(created)
    await db.flush()
    return created
|
||
|
||
|
||
@router.get("", response_model=AIConfigResponse)
async def get_ai_config(
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> AIConfigResponse:
    """
    Return the current AI configuration of the tenant.

    Only a row whose ``is_configured`` flag is true counts as a configuration.
    The stored API key is decrypted solely to build its masked form; the
    plaintext key is never included in the response.

    Raises:
        HTTPException: 404 when the tenant has no completed configuration.
    """
    stmt = (
        select(AIConfig)
        .where(AIConfig.tenant_id == x_tenant_id)
        # `== True` builds a SQL comparison here; it is not a Python bool test.
        .where(AIConfig.is_configured == True)  # noqa: E712
    )
    stored = (await db.execute(stmt)).scalar_one_or_none()

    if not stored:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="AI 服务未配置,请先完成配置",
        )

    masked_key = mask_api_key(decrypt_api_key(stored.api_key_encrypted))
    tested_at = stored.last_test_at
    tested_at_iso = tested_at.isoformat() if tested_at else None

    return AIConfigResponse(
        provider=stored.provider,
        base_url=stored.base_url,
        api_key_masked=masked_key,
        models=AIModelsConfig(**stored.models),
        parameters=AIParametersConfig(
            temperature=stored.temperature,
            max_tokens=stored.max_tokens,
        ),
        available_models=stored.available_models or {},
        is_configured=stored.is_configured,
        last_test_at=tested_at_iso,
        last_test_result=stored.last_test_result,
    )
|
||
|
||
|
||
@router.put("", response_model=AIConfigResponse)
async def update_ai_config(
    request: AIConfigUpdate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> AIConfigResponse:
    """
    Create or update the tenant's AI configuration.

    - Creates the tenant record first if it does not exist yet.
    - Persists provider, connection info, model selection and parameters
      via ``AIServiceFactory.create_or_update_config``.
    - The API key is stored encrypted; the response only carries a masked
      version of the key that was submitted.
    """
    # The tenant row must exist before its config can be saved.
    await _ensure_tenant_exists(x_tenant_id, db)

    params = request.parameters
    saved = await AIServiceFactory.create_or_update_config(
        tenant_id=x_tenant_id,
        provider=request.provider.value,
        base_url=request.base_url,
        api_key_encrypted=encrypt_api_key(request.api_key),
        models=request.models.model_dump(),
        temperature=params.temperature,
        max_tokens=params.max_tokens,
        db=db,
    )

    tested_at = saved.last_test_at
    return AIConfigResponse(
        provider=saved.provider,
        base_url=saved.base_url,
        api_key_masked=mask_api_key(request.api_key),
        models=AIModelsConfig(**saved.models),
        parameters=AIParametersConfig(
            temperature=saved.temperature,
            max_tokens=saved.max_tokens,
        ),
        available_models=saved.available_models or {},
        is_configured=True,
        last_test_at=tested_at.isoformat() if tested_at else None,
        last_test_result=saved.last_test_result,
    )
|
||
|
||
|
||
@router.post("/models", response_model=ModelsListResponse)
async def get_available_models(
    request: GetModelsRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ModelsListResponse:
    """
    Fetch the list of available models from the provider.

    - Calls the provider through ``OpenAICompatibleClient.list_models`` using
      the credentials in the request body, not the stored ones.
    - Returns the models grouped by the keys the client returns (e.g.
      text/vision/audio), each entry parsed into ``ModelInfo``.
    - If the tenant already has a config row, the raw model dict is cached
      on it as ``available_models``.

    Raises:
        HTTPException: 502 when the client cannot be created, the provider
            call fails, or the returned entries cannot be parsed.
    """
    client = None
    try:
        client = OpenAICompatibleClient(
            base_url=request.base_url,
            api_key=request.api_key,
            provider=request.provider.value,
        )
        models_dict = await client.list_models()

        # Convert the raw dicts into ModelInfo objects.
        models = {
            capability: [ModelInfo(**entry) for entry in entries]
            for capability, entries in models_dict.items()
        }
    except Exception as e:
        # Any failure talking to / parsing the upstream provider is a gateway error.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"获取模型列表失败: {str(e)}",
        ) from e
    finally:
        # Always release the client, even when list_models() or parsing raised.
        # A failing close must not turn an already-fetched result into an error.
        if client is not None:
            try:
                await client.close()
            except Exception:
                logging.getLogger(__name__).warning(
                    "Failed to close AI client after listing models",
                    exc_info=True,
                )

    # Cache the model list on the tenant's config row, if one exists. This is
    # outside the try above so database errors are not misreported as 502.
    result = await db.execute(
        select(AIConfig).where(AIConfig.tenant_id == x_tenant_id)
    )
    config = result.scalar_one_or_none()
    if config:
        config.available_models = models_dict
        await db.flush()

    return ModelsListResponse(
        success=True,
        models=models,
    )
|
||
|
||
|
||
@router.post("/test", response_model=ConnectionTestResponse)
async def test_connection(
    request: TestConnectionRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ConnectionTestResponse:
    """
    Test connectivity to the AI service.

    - Tests every model in ``request.models`` concurrently.
    - Returns one ``ModelTestResult`` per model type. Failures are reported
      in the response body rather than as HTTP errors.
    - If the tenant has a config row, stores the test time (UTC) and the
      per-model results on it.
    """
    client = None
    models = request.models.model_dump()
    try:
        client = OpenAICompatibleClient(
            base_url=request.base_url,
            api_key=request.api_key,
            provider=request.provider.value,
        )

        # Model type -> capability to probe; unknown types fall back to TEXT.
        capability_map = {
            "text": ModelCapability.TEXT,
            "vision": ModelCapability.VISION,
            "audio": ModelCapability.AUDIO,
        }

        async def test_single(model_type: str, model_id: str) -> tuple[str, ModelTestResult]:
            capability = capability_map.get(model_type, ModelCapability.TEXT)
            try:
                result = await client.test_connection(model_id, capability)
            except Exception as exc:
                # Contain the failure to this model: without this, one raising
                # probe aborts gather() and every model is reported as failed.
                return model_type, ModelTestResult(
                    success=False,
                    latency_ms=0,
                    error=str(exc),
                    model=model_id,
                )
            return model_type, ModelTestResult(
                success=result.success,
                latency_ms=result.latency_ms,
                error=result.error,
                model=model_id,
            )

        # Test all models concurrently.
        results_list = await asyncio.gather(
            *(test_single(model_type, model_id) for model_type, model_id in models.items())
        )
        results = {model_type: result for model_type, result in results_list}

        all_success = all(r.success for r in results.values())
        failed_count = sum(1 for r in results.values() if not r.success)

        if all_success:
            message = "所有模型连接成功"
        else:
            message = f"{failed_count} 个模型连接失败,请检查模型名称或 API 权限"

        response = ConnectionTestResponse(
            success=all_success,
            results=results,
            message=message,
        )
    except Exception as exc:
        # Deliberately best-effort: still answer 200 and mark every requested
        # model as failed (e.g. when the client could not be created).
        logging.getLogger(__name__).warning(
            "AI connection test failed for tenant %s", x_tenant_id, exc_info=True
        )
        results = {
            model_type: ModelTestResult(
                success=False,
                latency_ms=0,
                error=str(exc),
                model=model_id,
            )
            for model_type, model_id in models.items()
        }
        response = ConnectionTestResponse(
            success=False,
            results=results,
            message=f"连接测试失败: {str(exc)}",
        )
    finally:
        if client is not None:
            try:
                await client.close()
            except Exception:
                # Cleanup is best-effort and must not mask the test outcome,
                # but the failure is logged rather than silently dropped.
                logging.getLogger(__name__).warning(
                    "Failed to close AI client after connection test",
                    exc_info=True,
                )

    # Persist the test outcome on the tenant's config row, if one exists.
    db_result = await db.execute(
        select(AIConfig).where(AIConfig.tenant_id == x_tenant_id)
    )
    config = db_result.scalar_one_or_none()
    if config:
        config.last_test_at = datetime.now(timezone.utc)
        config.last_test_result = {
            k: v.model_dump() for k, v in response.results.items()
        }
        await db.flush()

    return response
|
||
|
||
|
||
# ==================== 供其他模块调用 ====================
|
||
|
||
async def get_ai_config_for_tenant(
    tenant_id: str,
    db: AsyncSession,
) -> Optional[dict]:
    """Return the tenant's completed AI configuration as a plain dict.

    Intended for other modules (e.g. the review service). Only a row with
    ``is_configured`` set to true is considered; otherwise ``None`` is returned.
    Unlike the HTTP endpoints, the returned dict contains the DECRYPTED API key.
    """
    query = (
        select(AIConfig)
        .where(AIConfig.tenant_id == tenant_id)
        # `== True` builds a SQL comparison here; it is not a Python bool test.
        .where(AIConfig.is_configured == True)  # noqa: E712
    )
    cfg = (await db.execute(query)).scalar_one_or_none()
    if not cfg:
        return None

    return dict(
        tenant_id=cfg.tenant_id,
        provider=cfg.provider,
        base_url=cfg.base_url,
        api_key=decrypt_api_key(cfg.api_key_encrypted),
        models=cfg.models,
        temperature=cfg.temperature,
        max_tokens=cfg.max_tokens,
    )
|