#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
OpenRouter API client.

Supports text and multimodal (image) tasks.
"""
import os
import json
import requests
import logging
import time
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)


class OpenRouterClient:
    """OpenRouter API client."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "google/gemini-2.5-flash",
        max_tokens: int = 8192,
        temperature: float = 0.3,
        retry_delay: int = 3
    ):
        """
        Initialize the client.

        Args:
            api_key: API key; read from the OPENROUTER_API_KEY environment
                variable when not provided.
            model: Model name.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            retry_delay: Default delay between retries, in seconds.
        """
        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
        if not self.api_key:
            raise ValueError("OPENROUTER_API_KEY not found in environment variables")
        self.base_url = "https://openrouter.ai/api/v1"
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.retry_delay = retry_delay
        logger.info(
            f"OpenRouter client initialized: model={model}, "
            f"max_tokens={max_tokens}, retry_delay={retry_delay}s"
        )
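
    # Construction sketch (illustrative only, not executed): the key can be
    # passed explicitly or picked up from the OPENROUTER_API_KEY environment
    # variable; "YOUR_KEY" below is a placeholder, not a real credential.
    #
    #     client = OpenRouterClient()                    # key from environment
    #     client = OpenRouterClient(api_key="YOUR_KEY",  # explicit key
    #                               temperature=0.0)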

    def chat(
        self,
        prompt: str,
        images: Optional[List[str]] = None,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Send a chat request to the LLM.

        Args:
            prompt: User prompt.
            images: Optional list of image URLs for multimodal tasks.
            system_prompt: Optional system prompt.
            max_retries: Maximum number of attempts.
            retry_delay: Delay between retries in seconds; the instance
                default is used when None.

        Returns:
            A dict describing the LLM response.
        """
        # Fall back to the instance-level retry delay if none was given.
        if retry_delay is None:
            retry_delay = self.retry_delay

        # Build the message list.
        messages = []

        # Optional system prompt.
        if system_prompt:
            messages.append({
                "role": "system",
                "content": system_prompt
            })

        # Build the user message.
        if images:
            # Multimodal message: one text part followed by image parts.
            content = [{"type": "text", "text": prompt}]
            for img_url in images:
                content.append({
                    "type": "image_url",
                    "image_url": {"url": img_url}
                })
            messages.append({
                "role": "user",
                "content": content
            })
        else:
            # Plain text message.
            messages.append({
                "role": "user",
                "content": prompt
            })

        # Build the request.
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # Retry loop.
        last_exception = None
        for attempt in range(1, max_retries + 1):
            try:
                if attempt > 1:
                    logger.info(f"  Retry {attempt - 1}/{max_retries - 1}")
                    time.sleep(retry_delay)
                response = requests.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    headers=headers,
                    timeout=60
                )
                response.raise_for_status()
                result = response.json()

                # Extract the response content.
                if "choices" in result and len(result["choices"]) > 0:
                    content = result["choices"][0]["message"]["content"]
                    # If the response looks like bare JSON, parse it as well.
                    try:
                        if content.strip().startswith('{'):
                            parsed = json.loads(content)
                            return {
                                "success": True,
                                "content": content,
                                "parsed": parsed,
                                "raw_response": result
                            }
                    except json.JSONDecodeError:
                        pass
                    return {
                        "success": True,
                        "content": content,
                        "raw_response": result
                    }
                else:
                    # Treat a malformed response body as a retryable error.
                    raise ValueError(f"Invalid API response: {result}")
            except (requests.exceptions.RequestException, ValueError) as e:
                last_exception = e
                logger.error(f"  API call failed (attempt {attempt}): {e}")
                if attempt >= max_retries:
                    logger.error(f"  Reached maximum retries ({max_retries})")

        # All attempts failed.
        return {
            "success": False,
            "error": str(last_exception),
            "content": None
        }
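
    # chat() usage sketch (illustrative prompt; assumes a `client` instance).
    # It shows the shape of the dicts returned above:
    #
    #     reply = client.chat(
    #         prompt="Describe this picture in one sentence.",
    #         images=["http://example.com/cat.jpg"],
    #         system_prompt="You are a concise assistant.",
    #     )
    #     if reply["success"]:
    #         print(reply["content"])     # model text
    #         print("parsed" in reply)    # True only if the text was bare JSON
    #     else:
    #         print(reply["error"])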

    def chat_json(
        self,
        prompt: str,
        images: Optional[List[str]] = None,
        system_prompt: Optional[str] = None,
        max_retries: int = 3
    ) -> Optional[Dict[str, Any]]:
        """
        Call the LLM and expect a JSON response.

        Args:
            prompt: User prompt (should instruct the model to return JSON).
            images: Optional list of image URLs.
            system_prompt: Optional system prompt.
            max_retries: Maximum number of attempts.

        Returns:
            The parsed JSON object, or None on failure.
        """
        result = self.chat(
            prompt=prompt,
            images=images,
            system_prompt=system_prompt,
            max_retries=max_retries
        )
        if not result["success"]:
            logger.error(f"LLM call failed: {result.get('error')}")
            return None

        # chat() already parsed the JSON.
        if "parsed" in result:
            return result["parsed"]

        # Otherwise try to parse JSON out of the content, which may be wrapped
        # in a markdown code block.
        content = result["content"]
        if "```json" in content:
            # JSON inside a ```json fenced block.
            start = content.find("```json") + 7
            end = content.find("```", start)
            json_str = content[start:end].strip()
        elif "```" in content:
            # Plain fenced code block.
            start = content.find("```") + 3
            end = content.find("```", start)
            json_str = content[start:end].strip()
        else:
            # No fences: parse the content directly.
            json_str = content.strip()

        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error(f"JSON parsing failed: {e}")
            logger.error(f"Raw content: {content[:500]}")
            return None
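

# chat_json() fence handling, sketched: a model reply wrapped in a markdown
# code block is unwrapped before json.loads(), so all three of these response
# styles yield the same dict {"score": 0.8}:
#
#     '{"score": 0.8}'
#     '```json\n{"score": 0.8}\n```'
#     '```\n{"score": 0.8}\n```'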


def test_client():
    """Smoke-test the client."""
    # Requires the OPENROUTER_API_KEY environment variable to be set.
    client = OpenRouterClient()

    # Text task.
    print("\n=== Text task ===")
    result = client.chat_json(
        prompt="""
        Evaluate whether the search query "猫咪 宠物" (cat / pet) can surface
        content containing elements related to "拟人" (anthropomorphism).
        Return JSON in the format:
        {
            "score": 0.0-1.0,
            "reasoning": "explanation of the score"
        }
        """
    )
    print(json.dumps(result, ensure_ascii=False, indent=2))

    # Multimodal task.
    print("\n=== Multimodal task ===")
    result = client.chat_json(
        prompt="""
        Does this image contain elements related to "拟人" (anthropomorphism)?
        Return JSON in the format:
        {
            "has_element": true/false,
            "elements": ["element 1", "element 2"],
            "reasoning": "explanation"
        }
        """,
        images=["http://example.com/cat.jpg"]  # example image URL
    )
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    test_client()