Your Name e4959d584f feat: 完善代理商端业务逻辑与前后端框架
主要更新:
- 更新代理商端文档,明确项目由品牌方分配流程
- 新增Brief配置详情页(已配置)设计稿
- 完善工作台紧急待办中品牌新任务功能
- 整理Pencil设计文件中代理商端页面顺序
- 新增后端FastAPI框架及核心API
- 新增前端Next.js页面和组件库
- 添加.gitignore排除构建和缓存文件

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 19:27:31 +08:00

315 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI 服务配置 API
品牌方管理 AI 提供商配置、模型选择、连通性测试
"""
import asyncio
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Header, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.ai_config import AIConfig
from app.models.tenant import Tenant
from app.schemas.ai_config import (
AIProvider,
AIConfigUpdate,
AIConfigResponse,
AIModelsConfig,
AIParametersConfig,
GetModelsRequest,
TestConnectionRequest,
ModelsListResponse,
ConnectionTestResponse,
ModelTestResult,
ModelInfo,
ModelCapability,
mask_api_key,
)
from app.services.ai_client import OpenAICompatibleClient
from app.services.ai_service import AIServiceFactory
from app.utils.crypto import encrypt_api_key, decrypt_api_key
# Router for the brand-side AI configuration endpoints (path prefix /ai-config, tag "ai-config").
router = APIRouter(tags=["ai-config"], prefix="/ai-config")
async def _ensure_tenant_exists(tenant_id: str, db: AsyncSession) -> Tenant:
    """
    Return the Tenant row for ``tenant_id``, creating it if it does not exist.

    A newly created tenant gets the placeholder name ``租户-<tenant_id>``. It is
    added to the session and flushed, but this function does not commit.
    """
    lookup = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    existing = lookup.scalar_one_or_none()
    if existing:
        return existing

    created = Tenant(id=tenant_id, name=f"租户-{tenant_id}")
    db.add(created)
    # Flush so the new row exists in the DB transaction before callers insert rows that reference it.
    await db.flush()
    return created
@router.get("", response_model=AIConfigResponse)
async def get_ai_config(
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> AIConfigResponse:
    """
    Return the tenant's current AI configuration.

    - Responds 404 when the tenant has no configuration marked ``is_configured``.
    - The stored API key is decrypted only to build a masked version; the
      plaintext key is never returned.
    """
    query = select(AIConfig).where(
        AIConfig.tenant_id == x_tenant_id,
        AIConfig.is_configured == True,  # noqa: E712 - SQLAlchemy column expression, not a Python bool test
    )
    config = (await db.execute(query)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="AI 服务未配置,请先完成配置",
        )

    plain_key = decrypt_api_key(config.api_key_encrypted)
    tested_at = config.last_test_at
    return AIConfigResponse(
        provider=config.provider,
        base_url=config.base_url,
        api_key_masked=mask_api_key(plain_key),
        models=AIModelsConfig(**config.models),
        parameters=AIParametersConfig(
            temperature=config.temperature,
            max_tokens=config.max_tokens,
        ),
        available_models=config.available_models or {},
        is_configured=config.is_configured,
        last_test_at=tested_at.isoformat() if tested_at else None,
        last_test_result=config.last_test_result,
    )
@router.put("", response_model=AIConfigResponse)
async def update_ai_config(
    request: AIConfigUpdate,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> AIConfigResponse:
    """
    Create or replace the tenant's AI configuration.

    - Auto-creates the tenant row if it does not exist yet.
    - Persists provider, connection info, model choices and generation parameters.
    - The API key is stored encrypted; the response carries only a masked form.
    """
    await _ensure_tenant_exists(x_tenant_id, db)

    params = request.parameters
    saved = await AIServiceFactory.create_or_update_config(
        tenant_id=x_tenant_id,
        provider=request.provider.value,
        base_url=request.base_url,
        api_key_encrypted=encrypt_api_key(request.api_key),
        models=request.models.model_dump(),
        temperature=params.temperature,
        max_tokens=params.max_tokens,
        db=db,
    )

    tested_at = saved.last_test_at
    return AIConfigResponse(
        provider=saved.provider,
        base_url=saved.base_url,
        # Mask the plaintext key from the request instead of decrypting the stored one.
        api_key_masked=mask_api_key(request.api_key),
        models=AIModelsConfig(**saved.models),
        parameters=AIParametersConfig(
            temperature=saved.temperature,
            max_tokens=saved.max_tokens,
        ),
        available_models=saved.available_models or {},
        is_configured=True,
        last_test_at=tested_at.isoformat() if tested_at else None,
        last_test_result=saved.last_test_result,
    )
@router.post("/models", response_model=ModelsListResponse)
async def get_available_models(
    request: GetModelsRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ModelsListResponse:
    """
    List the models offered by the provider described in the request.

    - Calls the provider with the base URL / API key from the request
      (not the stored configuration).
    - Returns models grouped by the keys the client produces
      (text / vision / audio).
    - Caches the raw grouped list on the tenant's AIConfig row, if one exists.

    Raises:
        HTTPException(502): the provider call failed, or its response could
            not be converted to ``ModelInfo`` objects.
    """
    client = None
    try:
        client = OpenAICompatibleClient(
            base_url=request.base_url,
            api_key=request.api_key,
            provider=request.provider.value,
        )
        models_dict = await client.list_models()
        models = {
            group: [ModelInfo(**entry) for entry in entries]
            for group, entries in models_dict.items()
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"获取模型列表失败: {str(e)}",
        ) from e
    finally:
        # Close the client on every path; previously it leaked when list_models() raised.
        if client is not None:
            try:
                await client.close()
            except Exception:
                # Best-effort cleanup: a failing close must not mask the real outcome.
                pass

    # Cache the fetched list. This is kept outside the try so that DB errors are
    # not misreported as provider (502) failures.
    result = await db.execute(
        select(AIConfig).where(AIConfig.tenant_id == x_tenant_id)
    )
    config = result.scalar_one_or_none()
    if config:
        config.available_models = models_dict
        await db.flush()

    return ModelsListResponse(
        success=True,
        models=models,
    )
@router.post("/test", response_model=ConnectionTestResponse)
async def test_connection(
    request: TestConnectionRequest,
    x_tenant_id: str = Header(..., alias="X-Tenant-ID"),
    db: AsyncSession = Depends(get_db),
) -> ConnectionTestResponse:
    """
    Test connectivity to the AI service for each configured model.

    - Probes every entry of ``request.models`` concurrently. Each entry is
      probed with the capability matching its slot (text / vision / audio).
    - Always responds normally. Failures are reported per model in
      ``results`` rather than raised.
    - Stores the per-model results and a UTC timestamp on the tenant's
      AIConfig row, if one exists.
    """
    models = request.models.model_dump()
    # Slot name -> capability to probe; unknown slots are probed as plain text.
    capability_map = {
        "text": ModelCapability.TEXT,
        "vision": ModelCapability.VISION,
        "audio": ModelCapability.AUDIO,
    }
    client = None
    try:
        client = OpenAICompatibleClient(
            base_url=request.base_url,
            api_key=request.api_key,
            provider=request.provider.value,
        )

        async def test_single(model_type: str, model_id: str) -> ModelTestResult:
            capability = capability_map.get(model_type, ModelCapability.TEXT)
            result = await client.test_connection(model_id, capability)
            return ModelTestResult(
                success=result.success,
                latency_ms=result.latency_ms,
                error=result.error,
                model=model_id,
            )

        # return_exceptions=True serves two purposes:
        # (a) one probe raising no longer discards the other probes' results;
        # (b) every probe has finished before the client is closed in ``finally``.
        # A plain gather re-raises at once and leaves the remaining probes running.
        model_items = list(models.items())
        outcomes = await asyncio.gather(
            *(test_single(model_type, model_id) for model_type, model_id in model_items),
            return_exceptions=True,
        )

        results = {}
        for (model_type, model_id), outcome in zip(model_items, outcomes):
            if isinstance(outcome, BaseException):
                results[model_type] = ModelTestResult(
                    success=False,
                    latency_ms=0,
                    error=str(outcome),
                    model=model_id,
                )
            else:
                results[model_type] = outcome

        failed_count = sum(1 for r in results.values() if not r.success)
        all_success = failed_count == 0
        if all_success:
            message = "所有模型连接成功"
        else:
            message = f"{failed_count} 个模型连接失败,请检查模型名称或 API 权限"
        response = ConnectionTestResponse(
            success=all_success,
            results=results,
            message=message,
        )
    except Exception as exc:
        # Unexpected failure (e.g. the client could not be constructed). Still
        # respond normally, marking every model as failed with the same error.
        results = {
            model_type: ModelTestResult(
                success=False,
                latency_ms=0,
                error=str(exc),
                model=model_id,
            )
            for model_type, model_id in models.items()
        }
        response = ConnectionTestResponse(
            success=False,
            results=results,
            message=f"连接测试失败: {str(exc)}",
        )
    finally:
        if client is not None:
            try:
                await client.close()
            except Exception:
                # Best-effort cleanup: a close failure must not replace the test outcome.
                pass

    # Persist the outcome on the tenant's config row, if there is one.
    db_result = await db.execute(
        select(AIConfig).where(AIConfig.tenant_id == x_tenant_id)
    )
    config = db_result.scalar_one_or_none()
    if config:
        config.last_test_at = datetime.now(timezone.utc)
        config.last_test_result = {
            k: v.model_dump() for k, v in response.results.items()
        }
        await db.flush()

    return response
# ==================== 供其他模块调用 ====================
async def get_ai_config_for_tenant(
    tenant_id: str,
    db: AsyncSession,
) -> Optional[dict]:
    """
    Fetch a tenant's AI settings for internal callers (e.g. the review service).

    Returns None when the tenant has no configuration marked ``is_configured``.
    Otherwise returns a plain dict whose ``api_key`` holds the DECRYPTED key.
    Do not put this dict into API responses or logs.
    """
    stmt = select(AIConfig).where(
        AIConfig.tenant_id == tenant_id,
        AIConfig.is_configured == True,  # noqa: E712 - SQLAlchemy column expression, not a Python bool test
    )
    config = (await db.execute(stmt)).scalar_one_or_none()
    if config:
        return {
            "tenant_id": config.tenant_id,
            "provider": config.provider,
            "base_url": config.base_url,
            "api_key": decrypt_api_key(config.api_key_encrypted),
            "models": config.models,
            "temperature": config.temperature,
            "max_tokens": config.max_tokens,
        }
    return None