""" OpenAI 兼容 AI 客户端 支持多种 AI 提供商的统一接口 """ import asyncio import time from typing import Optional from dataclasses import dataclass import httpx from openai import AsyncOpenAI from app.schemas.ai_config import AIProvider, ModelCapability @dataclass class AIResponse: """AI 响应""" content: str model: str usage: dict finish_reason: str @dataclass class ConnectionTestResult: """连接测试结果""" success: bool latency_ms: int error: Optional[str] = None class OpenAICompatibleClient: """ OpenAI 兼容 API 客户端 支持: - OpenAI - Azure OpenAI - Anthropic (通过 OpenAI 兼容层) - DeepSeek - Qwen (通义千问) - Doubao (豆包) - 各种中转服务 (OneAPI, OpenRouter) """ def __init__( self, base_url: str, api_key: str, provider: str = "openai", timeout: float = 180.0, ): self.base_url = base_url.rstrip("/") self.api_key = api_key self.provider = provider self.timeout = timeout # 创建 OpenAI 客户端 self.client = AsyncOpenAI( base_url=self.base_url, api_key=self.api_key, timeout=timeout, ) async def chat_completion( self, messages: list[dict], model: str, temperature: float = 0.7, max_tokens: int = 2000, **kwargs, ) -> AIResponse: """ 聊天补全 Args: messages: 消息列表 [{"role": "user", "content": "..."}] model: 模型名称 temperature: 温度参数 max_tokens: 最大 token 数 Returns: AIResponse 包含生成的内容 """ response = await self.client.chat.completions.create( model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, **kwargs, ) choice = response.choices[0] return AIResponse( content=choice.message.content or "", model=response.model, usage={ "prompt_tokens": response.usage.prompt_tokens if response.usage else 0, "completion_tokens": response.usage.completion_tokens if response.usage else 0, "total_tokens": response.usage.total_tokens if response.usage else 0, }, finish_reason=choice.finish_reason or "stop", ) async def vision_analysis( self, image_urls: list[str], prompt: str, model: str, temperature: float = 0.3, max_tokens: int = 2000, ) -> AIResponse: """ 视觉分析(图像理解) Args: image_urls: 图像 URL 列表 prompt: 分析提示 model: 视觉模型名称 Returns: AIResponse 包含分析结果 """ # 构建多模态消息 content = [{"type": "text", "text": prompt}] for url in image_urls: content.append({ "type": "image_url", "image_url": {"url": url}, }) messages = [{"role": "user", "content": content}] return await self.chat_completion( messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, ) async def audio_transcription( self, audio_url: str, model: str = "whisper-1", language: str = "zh", ) -> AIResponse: """ 音频转写 (ASR) Args: audio_url: 音频文件 URL model: 转写模型 language: 语言代码 Returns: AIResponse 包含转写文本 """ # 下载音频文件 async with httpx.AsyncClient() as http_client: response = await http_client.get(audio_url, timeout=30) response.raise_for_status() audio_data = response.content # 调用 Whisper API transcription = await self.client.audio.transcriptions.create( model=model, file=("audio.mp3", audio_data, "audio/mpeg"), language=language, ) return AIResponse( content=transcription.text, model=model, usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, finish_reason="stop", ) async def test_connection( self, model: str, capability: ModelCapability = ModelCapability.TEXT, ) -> ConnectionTestResult: """ 测试模型连接 Args: model: 模型名称 capability: 模型能力类型 Returns: ConnectionTestResult 包含测试结果 """ start_time = time.time() try: if capability == ModelCapability.AUDIO: # 音频模型无法简单测试,只验证 API 可达 async with httpx.AsyncClient() as http_client: response = await http_client.get( f"{self.base_url}/models", headers={"Authorization": f"Bearer {self.api_key}"}, timeout=10, ) response.raise_for_status() latency_ms = int((time.time() - start_time) * 1000) return ConnectionTestResult(success=True, latency_ms=latency_ms) elif capability == ModelCapability.VISION: # 视觉模型测试:发送简单的文本请求 response = await self.chat_completion( messages=[{"role": "user", "content": "Hi"}], model=model, max_tokens=5, ) else: # 文本模型测试 response = await self.chat_completion( messages=[{"role": "user", "content": "Hi"}], model=model, max_tokens=5, ) latency_ms = int((time.time() - start_time) * 1000) return ConnectionTestResult(success=True, latency_ms=latency_ms) except Exception as e: latency_ms = int((time.time() - start_time) * 1000) return ConnectionTestResult( success=False, latency_ms=latency_ms, error=str(e), ) async def list_models(self) -> dict[str, list[dict]]: """ 获取可用模型列表 Returns: 按能力分类的模型列表 {"text": [...], "vision": [...], "audio": [...]} """ try: models = await self.client.models.list() # 已知模型能力映射 known_capabilities = { # OpenAI "gpt-4o": ["text", "vision"], "gpt-4o-mini": ["text", "vision"], "gpt-4-turbo": ["text", "vision"], "gpt-4": ["text"], "gpt-3.5-turbo": ["text"], "whisper-1": ["audio"], # Claude (通过兼容层) "claude-3-opus": ["text", "vision"], "claude-3-sonnet": ["text", "vision"], "claude-3-haiku": ["text", "vision"], # DeepSeek "deepseek-chat": ["text"], "deepseek-coder": ["text"], # Qwen "qwen-turbo": ["text"], "qwen-plus": ["text"], "qwen-max": ["text"], "qwen-vl-plus": ["vision"], "qwen-vl-max": ["vision"], # Doubao "doubao-pro": ["text"], "doubao-lite": ["text"], } result: dict[str, list[dict]] = { "text": [], "vision": [], "audio": [], } for model in models.data: model_id = model.id capabilities = known_capabilities.get(model_id, ["text"]) for cap in capabilities: if cap in result: result[cap].append({ "id": model_id, "name": model_id.replace("-", " ").title(), }) return result except Exception: # 如果无法获取模型列表,返回预设列表 return { "text": [ {"id": "gpt-4o", "name": "GPT-4o"}, {"id": "gpt-4o-mini", "name": "GPT-4o Mini"}, {"id": "deepseek-chat", "name": "DeepSeek Chat"}, ], "vision": [ {"id": "gpt-4o", "name": "GPT-4o"}, {"id": "qwen-vl-max", "name": "Qwen VL Max"}, ], "audio": [ {"id": "whisper-1", "name": "Whisper"}, ], } async def close(self): """关闭客户端""" try: await self.client.close() except Exception: # 关闭失败不应影响主流程 pass # 便捷函数 async def create_ai_client( base_url: str, api_key: str, provider: str = "openai", ) -> OpenAICompatibleClient: """创建 AI 客户端""" return OpenAICompatibleClient( base_url=base_url, api_key=api_key, provider=provider, )