# gemini.py
"""
Gemini Provider (HTTP API)

Calls the Gemini REST API directly with httpx, avoiding compatibility
issues with the google-generativeai SDK.
Reference: Resonote/llm/providers/gemini.py
"""
import json
import os
import sys
from typing import List, Dict, Any, Optional

import httpx
  11. def _dump_llm_request(endpoint: str, payload: Dict[str, Any], model: str):
  12. """
  13. Dump完整的LLM请求用于调试(需要设置 AGENT_DEBUG=1)
  14. 特别处理:
  15. - 图片base64数据:只显示前50字符 + 长度信息
  16. - Tools schema:完整显示
  17. - 输出到stderr,避免污染正常输出
  18. """
  19. if not os.getenv("AGENT_DEBUG"):
  20. return
  21. def truncate_images(obj):
  22. """递归处理对象,truncate图片base64数据"""
  23. if isinstance(obj, dict):
  24. result = {}
  25. for key, value in obj.items():
  26. # 处理 inline_data 中的 base64 图片
  27. if key == "inline_data" and isinstance(value, dict):
  28. mime_type = value.get("mime_type", "unknown")
  29. data = value.get("data", "")
  30. data_size_kb = len(data) / 1024 if data else 0
  31. result[key] = {
  32. "mime_type": mime_type,
  33. "data": f"<BASE64_IMAGE: {data_size_kb:.1f}KB, preview: {data[:50]}...>"
  34. }
  35. else:
  36. result[key] = truncate_images(value)
  37. return result
  38. elif isinstance(obj, list):
  39. return [truncate_images(item) for item in obj]
  40. else:
  41. return obj
  42. # 构造完整的调试信息
  43. debug_info = {
  44. "endpoint": endpoint,
  45. "model": model,
  46. "payload": truncate_images(payload)
  47. }
  48. # 输出到stderr
  49. print("\n" + "="*80, file=sys.stderr)
  50. print("[AGENT_DEBUG] LLM Request Dump", file=sys.stderr)
  51. print("="*80, file=sys.stderr)
  52. print(json.dumps(debug_info, indent=2, ensure_ascii=False), file=sys.stderr)
  53. print("="*80 + "\n", file=sys.stderr)
  54. def _convert_messages_to_gemini(messages: List[Dict]) -> tuple[List[Dict], Optional[str]]:
  55. """
  56. 将 OpenAI 格式消息转换为 Gemini 格式
  57. Returns:
  58. (gemini_contents, system_instruction)
  59. """
  60. contents = []
  61. system_instruction = None
  62. tool_parts_buffer = []
  63. def flush_tool_buffer():
  64. """合并连续的 tool 消息为单个 user 消息"""
  65. if tool_parts_buffer:
  66. contents.append({
  67. "role": "user",
  68. "parts": tool_parts_buffer.copy()
  69. })
  70. tool_parts_buffer.clear()
  71. for msg in messages:
  72. role = msg.get("role")
  73. # System 消息 -> system_instruction
  74. if role == "system":
  75. system_instruction = msg.get("content", "")
  76. continue
  77. # Tool 消息 -> functionResponse
  78. if role == "tool":
  79. tool_name = msg.get("name")
  80. content_text = msg.get("content", "")
  81. if not tool_name:
  82. print(f"[WARNING] Tool message missing 'name' field, skipping")
  83. continue
  84. # 尝试解析为 JSON
  85. try:
  86. parsed = json.loads(content_text) if content_text else {}
  87. if isinstance(parsed, list):
  88. response_data = {"result": parsed}
  89. else:
  90. response_data = parsed
  91. except (json.JSONDecodeError, ValueError):
  92. response_data = {"result": content_text}
  93. # 添加到 buffer
  94. tool_parts_buffer.append({
  95. "functionResponse": {
  96. "name": tool_name,
  97. "response": response_data
  98. }
  99. })
  100. continue
  101. # 非 tool 消息:先 flush buffer
  102. flush_tool_buffer()
  103. content = msg.get("content", "")
  104. tool_calls = msg.get("tool_calls")
  105. # Assistant 消息 + tool_calls
  106. if role == "assistant" and tool_calls:
  107. parts = []
  108. if content and (isinstance(content, str) and content.strip()):
  109. parts.append({"text": content})
  110. # 转换 tool_calls 为 functionCall
  111. for tc in tool_calls:
  112. func = tc.get("function", {})
  113. func_name = func.get("name", "")
  114. func_args_str = func.get("arguments", "{}")
  115. try:
  116. func_args = json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
  117. except json.JSONDecodeError:
  118. func_args = {}
  119. parts.append({
  120. "functionCall": {
  121. "name": func_name,
  122. "args": func_args
  123. }
  124. })
  125. if parts:
  126. contents.append({
  127. "role": "model",
  128. "parts": parts
  129. })
  130. continue
  131. # 处理多模态消息(content 为数组)
  132. if isinstance(content, list):
  133. parts = []
  134. for item in content:
  135. item_type = item.get("type")
  136. # 文本部分
  137. if item_type == "text":
  138. text = item.get("text", "")
  139. if text.strip():
  140. parts.append({"text": text})
  141. # 图片部分(OpenAI format -> Gemini format)
  142. elif item_type == "image_url":
  143. image_url = item.get("image_url", {})
  144. url = image_url.get("url", "")
  145. # 处理 data URL (data:image/png;base64,...)
  146. if url.startswith("data:"):
  147. # 解析 MIME type 和 base64 数据
  148. # 格式:data:image/png;base64,<base64_data>
  149. try:
  150. header, base64_data = url.split(",", 1)
  151. mime_type = header.split(";")[0].replace("data:", "")
  152. parts.append({
  153. "inline_data": {
  154. "mime_type": mime_type,
  155. "data": base64_data
  156. }
  157. })
  158. except Exception as e:
  159. print(f"[WARNING] Failed to parse image data URL: {e}")
  160. if parts:
  161. gemini_role = "model" if role == "assistant" else "user"
  162. contents.append({
  163. "role": gemini_role,
  164. "parts": parts
  165. })
  166. continue
  167. # 普通文本消息(content 为字符串)
  168. if isinstance(content, str):
  169. # 跳过空消息
  170. if not content.strip():
  171. continue
  172. gemini_role = "model" if role == "assistant" else "user"
  173. contents.append({
  174. "role": gemini_role,
  175. "parts": [{"text": content}]
  176. })
  177. continue
  178. # Flush 剩余的 tool messages
  179. flush_tool_buffer()
  180. # 合并连续的 user 消息(Gemini 要求严格交替)
  181. merged_contents = []
  182. i = 0
  183. while i < len(contents):
  184. current = contents[i]
  185. if current["role"] == "user":
  186. merged_parts = current["parts"].copy()
  187. j = i + 1
  188. while j < len(contents) and contents[j]["role"] == "user":
  189. merged_parts.extend(contents[j]["parts"])
  190. j += 1
  191. merged_contents.append({
  192. "role": "user",
  193. "parts": merged_parts
  194. })
  195. i = j
  196. else:
  197. merged_contents.append(current)
  198. i += 1
  199. return merged_contents, system_instruction
  200. def _convert_tools_to_gemini(tools: List[Dict]) -> List[Dict]:
  201. """
  202. 将 OpenAI 工具格式转换为 Gemini REST API 格式
  203. OpenAI: [{"type": "function", "function": {"name": "...", "parameters": {...}}}]
  204. Gemini API: [{"functionDeclarations": [{"name": "...", "parameters": {...}}]}]
  205. """
  206. if not tools:
  207. return []
  208. function_declarations = []
  209. for tool in tools:
  210. if tool.get("type") == "function":
  211. func = tool.get("function", {})
  212. # 清理不支持的字段
  213. parameters = func.get("parameters", {})
  214. if "properties" in parameters:
  215. cleaned_properties = {}
  216. for prop_name, prop_def in parameters["properties"].items():
  217. # 移除 default 字段
  218. cleaned_prop = {k: v for k, v in prop_def.items() if k != "default"}
  219. cleaned_properties[prop_name] = cleaned_prop
  220. # Gemini API 需要完整的 schema
  221. cleaned_parameters = {
  222. "type": "object",
  223. "properties": cleaned_properties
  224. }
  225. if "required" in parameters:
  226. cleaned_parameters["required"] = parameters["required"]
  227. parameters = cleaned_parameters
  228. function_declarations.append({
  229. "name": func.get("name"),
  230. "description": func.get("description", ""),
  231. "parameters": parameters
  232. })
  233. return [{"functionDeclarations": function_declarations}] if function_declarations else []
  234. def create_gemini_llm_call(
  235. base_url: Optional[str] = None,
  236. api_key: Optional[str] = None
  237. ):
  238. """
  239. 创建 Gemini LLM 调用函数(HTTP API)
  240. Args:
  241. base_url: Gemini API base URL(默认使用 Google 官方)
  242. api_key: API key(默认从环境变量读取)
  243. Returns:
  244. async 函数
  245. """
  246. base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
  247. api_key = api_key or os.getenv("GEMINI_API_KEY")
  248. if not api_key:
  249. raise ValueError("GEMINI_API_KEY not found")
  250. # 创建 HTTP 客户端
  251. client = httpx.AsyncClient(
  252. headers={"x-goog-api-key": api_key},
  253. timeout=httpx.Timeout(120.0, connect=10.0)
  254. )
  255. async def gemini_llm_call(
  256. messages: List[Dict[str, Any]],
  257. model: str = "gemini-2.5-pro",
  258. tools: Optional[List[Dict]] = None,
  259. **kwargs
  260. ) -> Dict[str, Any]:
  261. """
  262. 调用 Gemini REST API
  263. Args:
  264. messages: OpenAI 格式消息
  265. model: 模型名称
  266. tools: OpenAI 格式工具列表
  267. **kwargs: 其他参数
  268. Returns:
  269. {
  270. "content": str,
  271. "tool_calls": List[Dict] | None,
  272. "prompt_tokens": int,
  273. "completion_tokens": int,
  274. "finish_reason": str,
  275. "cost": float
  276. }
  277. """
  278. # 转换消息
  279. contents, system_instruction = _convert_messages_to_gemini(messages)
  280. print(f"\n[Gemini HTTP] Converted {len(contents)} messages: {[c['role'] for c in contents]}")
  281. # 构建请求
  282. endpoint = f"{base_url}/models/{model}:generateContent"
  283. payload = {"contents": contents}
  284. # 添加 system instruction
  285. if system_instruction:
  286. payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
  287. # 添加工具
  288. if tools:
  289. gemini_tools = _convert_tools_to_gemini(tools)
  290. if gemini_tools:
  291. payload["tools"] = gemini_tools
  292. # Debug: dump完整请求(需要设置 AGENT_DEBUG=1)
  293. _dump_llm_request(endpoint, payload, model)
  294. # 调用 API
  295. try:
  296. response = await client.post(endpoint, json=payload)
  297. response.raise_for_status()
  298. gemini_resp = response.json()
  299. except httpx.HTTPStatusError as e:
  300. error_body = e.response.text
  301. print(f"[Gemini HTTP] Error {e.response.status_code}: {error_body}")
  302. raise
  303. except Exception as e:
  304. print(f"[Gemini HTTP] Request failed: {e}")
  305. raise
  306. # Debug: 输出原始响应(如果启用)
  307. if os.getenv("AGENT_DEBUG"):
  308. print("\n[AGENT_DEBUG] Gemini Response:", file=sys.stderr)
  309. print(json.dumps(gemini_resp, ensure_ascii=False, indent=2)[:2000], file=sys.stderr)
  310. print("\n", file=sys.stderr)
  311. # 解析响应
  312. content = ""
  313. tool_calls = None
  314. finish_reason = "stop" # 默认值
  315. candidates = gemini_resp.get("candidates", [])
  316. if candidates:
  317. candidate = candidates[0]
  318. # 提取 finish_reason(Gemini -> OpenAI 格式映射)
  319. gemini_finish_reason = candidate.get("finishReason", "STOP")
  320. if gemini_finish_reason == "STOP":
  321. finish_reason = "stop"
  322. elif gemini_finish_reason == "MAX_TOKENS":
  323. finish_reason = "length"
  324. elif gemini_finish_reason in ("SAFETY", "RECITATION"):
  325. finish_reason = "content_filter"
  326. elif gemini_finish_reason == "MALFORMED_FUNCTION_CALL":
  327. finish_reason = "stop" # 映射为 stop,但在 content 中包含错误信息
  328. else:
  329. finish_reason = gemini_finish_reason.lower() # 保持原值,转小写
  330. # 检查是否有错误
  331. if gemini_finish_reason == "MALFORMED_FUNCTION_CALL":
  332. # Gemini 返回了格式错误的函数调用
  333. # 提取 finishMessage 中的内容作为 content
  334. finish_message = candidate.get("finishMessage", "")
  335. print(f"[Gemini HTTP] Warning: MALFORMED_FUNCTION_CALL\n{finish_message}")
  336. content = f"[模型尝试调用工具但格式错误]\n\n{finish_message}"
  337. else:
  338. # 正常解析
  339. parts = candidate.get("content", {}).get("parts", [])
  340. # 提取文本
  341. for part in parts:
  342. if "text" in part:
  343. content += part.get("text", "")
  344. # 提取 functionCall
  345. for i, part in enumerate(parts):
  346. if "functionCall" in part:
  347. if tool_calls is None:
  348. tool_calls = []
  349. fc = part["functionCall"]
  350. name = fc.get("name", "")
  351. args = fc.get("args", {})
  352. tool_calls.append({
  353. "id": f"call_{i}",
  354. "type": "function",
  355. "function": {
  356. "name": name,
  357. "arguments": json.dumps(args, ensure_ascii=False)
  358. }
  359. })
  360. # 提取 usage
  361. usage_meta = gemini_resp.get("usageMetadata", {})
  362. prompt_tokens = usage_meta.get("promptTokenCount", 0)
  363. completion_tokens = usage_meta.get("candidatesTokenCount", 0)
  364. return {
  365. "content": content,
  366. "tool_calls": tool_calls,
  367. "prompt_tokens": prompt_tokens,
  368. "completion_tokens": completion_tokens,
  369. "finish_reason": finish_reason,
  370. "cost": 0.0
  371. }
  372. return gemini_llm_call