openrouter_client.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. OpenRouter API 客户端
  5. 支持文本和多模态(图片)任务
  6. """
  7. import os
  8. import json
  9. import requests
  10. import logging
  11. import time
  12. from typing import List, Dict, Any, Optional
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  14. class OpenRouterClient:
  15. """OpenRouter API客户端"""
  16. def __init__(
  17. self,
  18. api_key: Optional[str] = None,
  19. model: str = "google/gemini-2.5-flash",
  20. max_tokens: int = 8192,
  21. temperature: float = 0.3,
  22. retry_delay: int = 3
  23. ):
  24. """
  25. 初始化客户端
  26. Args:
  27. api_key: API密钥,默认从环境变量读取
  28. model: 模型名称
  29. max_tokens: 最大token数
  30. temperature: 温度参数
  31. retry_delay: 默认重试延迟(秒)
  32. """
  33. self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
  34. if not self.api_key:
  35. raise ValueError("OPENROUTER_API_KEY not found in environment variables")
  36. self.base_url = "https://openrouter.ai/api/v1"
  37. self.model = model
  38. self.max_tokens = max_tokens
  39. self.temperature = temperature
  40. self.retry_delay = retry_delay
  41. logger.info(f"OpenRouter客户端已初始化: model={model}, max_tokens={max_tokens}, retry_delay={retry_delay}s")
  42. def chat(
  43. self,
  44. prompt: str,
  45. images: Optional[List[str]] = None,
  46. system_prompt: Optional[str] = None,
  47. max_retries: int = 3,
  48. retry_delay: Optional[int] = None
  49. ) -> Dict[str, Any]:
  50. """
  51. 调用LLM进行对话
  52. Args:
  53. prompt: 用户提示词
  54. images: 图片URL列表(可选,用于多模态任务)
  55. system_prompt: 系统提示词(可选)
  56. max_retries: 最大重试次数
  57. retry_delay: 重试延迟(秒),None则使用实例默认值
  58. Returns:
  59. LLM响应
  60. """
  61. # 使用实例默认retry_delay(如果未指定)
  62. if retry_delay is None:
  63. retry_delay = self.retry_delay
  64. # 构建消息
  65. messages = []
  66. # 添加系统提示词
  67. if system_prompt:
  68. messages.append({
  69. "role": "system",
  70. "content": system_prompt
  71. })
  72. # 构建用户消息
  73. if images:
  74. # 多模态消息
  75. content = [{"type": "text", "text": prompt}]
  76. for img_url in images:
  77. content.append({
  78. "type": "image_url",
  79. "image_url": {"url": img_url}
  80. })
  81. messages.append({
  82. "role": "user",
  83. "content": content
  84. })
  85. else:
  86. # 纯文本消息
  87. messages.append({
  88. "role": "user",
  89. "content": prompt
  90. })
  91. # 构建请求
  92. payload = {
  93. "model": self.model,
  94. "messages": messages,
  95. "max_tokens": self.max_tokens,
  96. "temperature": self.temperature
  97. }
  98. headers = {
  99. "Authorization": f"Bearer {self.api_key}",
  100. "Content-Type": "application/json"
  101. }
  102. # 重试循环
  103. last_exception = None
  104. for attempt in range(1, max_retries + 1):
  105. try:
  106. if attempt > 1:
  107. logger.info(f" 重试第 {attempt - 1}/{max_retries - 1} 次")
  108. time.sleep(retry_delay)
  109. response = requests.post(
  110. f"{self.base_url}/chat/completions",
  111. json=payload,
  112. headers=headers,
  113. timeout=60
  114. )
  115. response.raise_for_status()
  116. result = response.json()
  117. # 提取响应内容
  118. if "choices" in result and len(result["choices"]) > 0:
  119. content = result["choices"][0]["message"]["content"]
  120. # 尝试解析JSON
  121. try:
  122. # 如果响应是JSON格式,解析它
  123. if content.strip().startswith('{'):
  124. parsed = json.loads(content)
  125. return {
  126. "success": True,
  127. "content": content,
  128. "parsed": parsed,
  129. "raw_response": result
  130. }
  131. except json.JSONDecodeError:
  132. pass
  133. return {
  134. "success": True,
  135. "content": content,
  136. "raw_response": result
  137. }
  138. else:
  139. raise Exception(f"Invalid API response: {result}")
  140. except requests.exceptions.RequestException as e:
  141. last_exception = e
  142. logger.error(f" API调用失败 (第{attempt}次尝试): {e}")
  143. if attempt >= max_retries:
  144. logger.error(f" 已达最大重试次数 {max_retries}")
  145. # 所有重试都失败
  146. return {
  147. "success": False,
  148. "error": str(last_exception),
  149. "content": None
  150. }
  151. def chat_json(
  152. self,
  153. prompt: str,
  154. images: Optional[List[str]] = None,
  155. system_prompt: Optional[str] = None,
  156. max_retries: int = 3
  157. ) -> Optional[Dict[str, Any]]:
  158. """
  159. 调用LLM并期望返回JSON格式
  160. Args:
  161. prompt: 用户提示词(应包含返回JSON的指示)
  162. images: 图片URL列表
  163. system_prompt: 系统提示词
  164. max_retries: 最大重试次数
  165. Returns:
  166. 解析后的JSON对象,失败返回None
  167. """
  168. result = self.chat(
  169. prompt=prompt,
  170. images=images,
  171. system_prompt=system_prompt,
  172. max_retries=max_retries
  173. )
  174. if not result["success"]:
  175. logger.error(f"LLM调用失败: {result.get('error')}")
  176. return None
  177. # 如果已经解析了JSON
  178. if "parsed" in result:
  179. return result["parsed"]
  180. # 尝试从content中解析JSON
  181. content = result["content"]
  182. # 尝试提取JSON(可能包含在markdown代码块中)
  183. if "```json" in content:
  184. # 提取代码块中的JSON
  185. start = content.find("```json") + 7
  186. end = content.find("```", start)
  187. json_str = content[start:end].strip()
  188. elif "```" in content:
  189. # 普通代码块
  190. start = content.find("```") + 3
  191. end = content.find("```", start)
  192. json_str = content[start:end].strip()
  193. else:
  194. # 直接尝试解析
  195. json_str = content.strip()
  196. try:
  197. return json.loads(json_str)
  198. except json.JSONDecodeError as e:
  199. logger.error(f"JSON解析失败: {e}")
  200. logger.error(f"原始内容: {content[:500]}")
  201. return None
  202. def test_client():
  203. """测试客户端"""
  204. # 需要设置环境变量 OPENROUTER_API_KEY
  205. client = OpenRouterClient()
  206. # 测试文本任务
  207. print("\n=== 测试文本任务 ===")
  208. result = client.chat_json(
  209. prompt="""
  210. 评估搜索词"猫咪 宠物"能否找到包含"拟人"相关元素的内容。
  211. 返回JSON格式:
  212. {
  213. "score": 0.0-1.0,
  214. "reasoning": "评估理由"
  215. }
  216. """
  217. )
  218. print(json.dumps(result, ensure_ascii=False, indent=2))
  219. # 测试多模态任务
  220. print("\n=== 测试多模态任务 ===")
  221. result = client.chat_json(
  222. prompt="""
  223. 这张图片中是否包含与"拟人"相关的元素?
  224. 返回JSON格式:
  225. {
  226. "has_element": true/false,
  227. "elements": ["元素1", "元素2"],
  228. "reasoning": "理由"
  229. }
  230. """,
  231. images=["http://example.com/cat.jpg"] # 示例图片
  232. )
  233. print(json.dumps(result, ensure_ascii=False, indent=2))
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging, then run the manual smoke test.
    logging.basicConfig(level=logging.INFO)
    test_client()