gemini.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. """
  2. Gemini Provider (HTTP API)
  3. 使用 httpx 直接调用 Gemini REST API,避免 google-generativeai SDK 的兼容性问题
  4. 参考:Resonote/llm/providers/gemini.py
  5. """
  6. import os
  7. import json
  8. import sys
  9. import httpx
  10. from typing import List, Dict, Any, Optional
  11. from .usage import TokenUsage
  12. from .pricing import calculate_cost
  13. def _dump_llm_request(endpoint: str, payload: Dict[str, Any], model: str):
  14. """
  15. Dump完整的LLM请求用于调试(需要设置 AGENT_DEBUG=1)
  16. 特别处理:
  17. - 图片base64数据:只显示前50字符 + 长度信息
  18. - Tools schema:完整显示
  19. - 输出到stderr,避免污染正常输出
  20. """
  21. if not os.getenv("AGENT_DEBUG"):
  22. return
  23. def truncate_images(obj):
  24. """递归处理对象,truncate图片base64数据"""
  25. if isinstance(obj, dict):
  26. result = {}
  27. for key, value in obj.items():
  28. # 处理 inline_data 中的 base64 图片
  29. if key == "inline_data" and isinstance(value, dict):
  30. mime_type = value.get("mime_type", "unknown")
  31. data = value.get("data", "")
  32. data_size_kb = len(data) / 1024 if data else 0
  33. result[key] = {
  34. "mime_type": mime_type,
  35. "data": f"<BASE64_IMAGE: {data_size_kb:.1f}KB, preview: {data[:50]}...>"
  36. }
  37. else:
  38. result[key] = truncate_images(value)
  39. return result
  40. elif isinstance(obj, list):
  41. return [truncate_images(item) for item in obj]
  42. else:
  43. return obj
  44. # 构造完整的调试信息
  45. debug_info = {
  46. "endpoint": endpoint,
  47. "model": model,
  48. "payload": truncate_images(payload)
  49. }
  50. # 输出到stderr
  51. print("\n" + "="*80, file=sys.stderr)
  52. print("[AGENT_DEBUG] LLM Request Dump", file=sys.stderr)
  53. print("="*80, file=sys.stderr)
  54. print(json.dumps(debug_info, indent=2, ensure_ascii=False), file=sys.stderr)
  55. print("="*80 + "\n", file=sys.stderr)
  56. def _convert_messages_to_gemini(messages: List[Dict]) -> tuple[List[Dict], Optional[str]]:
  57. """
  58. 将 OpenAI 格式消息转换为 Gemini 格式
  59. Returns:
  60. (gemini_contents, system_instruction)
  61. """
  62. contents = []
  63. system_instruction = None
  64. tool_parts_buffer = []
  65. def flush_tool_buffer():
  66. """合并连续的 tool 消息为单个 user 消息"""
  67. if tool_parts_buffer:
  68. contents.append({
  69. "role": "user",
  70. "parts": tool_parts_buffer.copy()
  71. })
  72. tool_parts_buffer.clear()
  73. for msg in messages:
  74. role = msg.get("role")
  75. # System 消息 -> system_instruction
  76. if role == "system":
  77. system_instruction = msg.get("content", "")
  78. continue
  79. # Tool 消息 -> functionResponse
  80. if role == "tool":
  81. tool_name = msg.get("name")
  82. content_text = msg.get("content", "")
  83. if not tool_name:
  84. print(f"[WARNING] Tool message missing 'name' field, skipping")
  85. continue
  86. # 尝试解析为 JSON
  87. try:
  88. parsed = json.loads(content_text) if content_text else {}
  89. if isinstance(parsed, list):
  90. response_data = {"result": parsed}
  91. else:
  92. response_data = parsed
  93. except (json.JSONDecodeError, ValueError):
  94. response_data = {"result": content_text}
  95. # 添加到 buffer
  96. tool_parts_buffer.append({
  97. "functionResponse": {
  98. "name": tool_name,
  99. "response": response_data
  100. }
  101. })
  102. continue
  103. # 非 tool 消息:先 flush buffer
  104. flush_tool_buffer()
  105. content = msg.get("content", "")
  106. tool_calls = msg.get("tool_calls")
  107. raw_gemini_parts = msg.get("raw_gemini_parts")
  108. # Assistant 消息 + tool_calls
  109. if role == "assistant" and tool_calls:
  110. if raw_gemini_parts:
  111. contents.append({
  112. "role": "model",
  113. "parts": raw_gemini_parts
  114. })
  115. continue
  116. parts = []
  117. if content and (isinstance(content, str) and content.strip()):
  118. parts.append({"text": content})
  119. # 转换 tool_calls 为 functionCall
  120. for tc in tool_calls:
  121. func = tc.get("function", {})
  122. func_name = func.get("name", "")
  123. func_args_str = func.get("arguments", "{}")
  124. try:
  125. func_args = json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
  126. except json.JSONDecodeError:
  127. func_args = {}
  128. parts.append({
  129. "functionCall": {
  130. "name": func_name,
  131. "args": func_args
  132. }
  133. })
  134. if parts:
  135. contents.append({
  136. "role": "model",
  137. "parts": parts
  138. })
  139. continue
  140. # 处理多模态消息(content 为数组)
  141. if isinstance(content, list):
  142. parts = []
  143. for item in content:
  144. item_type = item.get("type")
  145. # 文本部分
  146. if item_type == "text":
  147. text = item.get("text", "")
  148. if text.strip():
  149. parts.append({"text": text})
  150. # 图片部分(OpenAI format -> Gemini format)
  151. elif item_type == "image_url":
  152. image_url = item.get("image_url", {})
  153. url = image_url.get("url", "")
  154. # 处理 data URL (data:image/png;base64,...)
  155. if url.startswith("data:"):
  156. # 解析 MIME type 和 base64 数据
  157. # 格式:data:image/png;base64,<base64_data>
  158. try:
  159. header, base64_data = url.split(",", 1)
  160. mime_type = header.split(";")[0].replace("data:", "")
  161. parts.append({
  162. "inline_data": {
  163. "mime_type": mime_type,
  164. "data": base64_data
  165. }
  166. })
  167. except Exception as e:
  168. print(f"[WARNING] Failed to parse image data URL: {e}")
  169. if parts:
  170. gemini_role = "model" if role == "assistant" else "user"
  171. contents.append({
  172. "role": gemini_role,
  173. "parts": parts
  174. })
  175. continue
  176. # 普通文本消息(content 为字符串)
  177. if isinstance(content, str):
  178. # 跳过空消息
  179. if not content.strip():
  180. continue
  181. gemini_role = "model" if role == "assistant" else "user"
  182. contents.append({
  183. "role": gemini_role,
  184. "parts": [{"text": content}]
  185. })
  186. continue
  187. # Flush 剩余的 tool messages
  188. flush_tool_buffer()
  189. # 合并连续的 user 消息(Gemini 要求严格交替)
  190. merged_contents = []
  191. i = 0
  192. while i < len(contents):
  193. current = contents[i]
  194. if current["role"] == "user":
  195. merged_parts = current["parts"].copy()
  196. j = i + 1
  197. while j < len(contents) and contents[j]["role"] == "user":
  198. merged_parts.extend(contents[j]["parts"])
  199. j += 1
  200. merged_contents.append({
  201. "role": "user",
  202. "parts": merged_parts
  203. })
  204. i = j
  205. else:
  206. merged_contents.append(current)
  207. i += 1
  208. return merged_contents, system_instruction
  209. def _convert_tools_to_gemini(tools: List[Dict]) -> List[Dict]:
  210. """
  211. 将 OpenAI 工具格式转换为 Gemini REST API 格式
  212. OpenAI: [{"type": "function", "function": {"name": "...", "parameters": {...}}}]
  213. Gemini API: [{"functionDeclarations": [{"name": "...", "parameters": {...}}]}]
  214. """
  215. if not tools:
  216. return []
  217. function_declarations = []
  218. for tool in tools:
  219. if tool.get("type") == "function":
  220. func = tool.get("function", {})
  221. # 清理不支持的字段
  222. parameters = func.get("parameters", {})
  223. if "properties" in parameters:
  224. cleaned_properties = {}
  225. for prop_name, prop_def in parameters["properties"].items():
  226. # 移除 default 字段
  227. cleaned_prop = {k: v for k, v in prop_def.items() if k != "default"}
  228. cleaned_properties[prop_name] = cleaned_prop
  229. # Gemini API 需要完整的 schema
  230. cleaned_parameters = {
  231. "type": "object",
  232. "properties": cleaned_properties
  233. }
  234. if "required" in parameters:
  235. cleaned_parameters["required"] = parameters["required"]
  236. parameters = cleaned_parameters
  237. function_declarations.append({
  238. "name": func.get("name"),
  239. "description": func.get("description", ""),
  240. "parameters": parameters
  241. })
  242. return [{"functionDeclarations": function_declarations}] if function_declarations else []
def create_gemini_llm_call(
    base_url: Optional[str] = None,
    api_key: Optional[str] = None
):
    """
    Create a Gemini LLM call function (HTTP API).

    Args:
        base_url: Gemini API base URL (defaults to Google's official endpoint)
        api_key: API key (defaults to the GEMINI_API_KEY environment variable)

    Returns:
        An async callable ``gemini_llm_call`` (see its docstring below).

    Raises:
        ValueError: if no API key is given and GEMINI_API_KEY is unset.
    """
    base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
    api_key = api_key or os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY not found")
    # Shared HTTP client, captured by the closure below.
    # NOTE(review): this client is never closed; each factory call leaks a
    # connection pool — consider exposing an aclose() hook. TODO confirm
    # callers only create one factory per process.
    client = httpx.AsyncClient(
        headers={"x-goog-api-key": api_key},
        timeout=httpx.Timeout(120.0, connect=10.0)
    )
    async def gemini_llm_call(
        messages: List[Dict[str, Any]],
        model: str = "gemini-2.5-pro",
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Call the Gemini REST API.

        Args:
            messages: OpenAI-format messages
            model: model name
            tools: OpenAI-format tool list
            **kwargs: extra parameters (currently unused)

        Returns:
            {
                "content": str,
                "tool_calls": List[Dict] | None,
                "prompt_tokens": int,
                "completion_tokens": int,
                "finish_reason": str,
                "cost": float
            }
            (plus "raw_gemini_parts", "reasoning_tokens",
            "cached_content_tokens", and the full "usage" TokenUsage object)

        Raises:
            httpx.HTTPStatusError: on non-2xx API responses (logged, re-raised).
        """
        # Convert messages to Gemini "contents" + optional system instruction.
        contents, system_instruction = _convert_messages_to_gemini(messages)
        print(f"\n[Gemini HTTP] Converted {len(contents)} messages: {[c['role'] for c in contents]}")
        # Build the request.
        endpoint = f"{base_url}/models/{model}:generateContent"
        payload = {"contents": contents}
        # Attach the system instruction.
        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        # Attach tools.
        if tools:
            gemini_tools = _convert_tools_to_gemini(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools
        # Debug: dump the full request (requires AGENT_DEBUG=1).
        _dump_llm_request(endpoint, payload, model)
        # Call the API.
        try:
            response = await client.post(endpoint, json=payload)
            response.raise_for_status()
            gemini_resp = response.json()
        except httpx.HTTPStatusError as e:
            error_body = e.response.text
            print(f"[Gemini HTTP] Error {e.response.status_code}: {error_body}")
            raise
        except Exception as e:
            print(f"[Gemini HTTP] Request failed: {e}")
            raise
        # Debug: dump the raw response (truncated to 2000 chars) if enabled.
        if os.getenv("AGENT_DEBUG"):
            print("\n[AGENT_DEBUG] Gemini Response:", file=sys.stderr)
            print(json.dumps(gemini_resp, ensure_ascii=False, indent=2)[:2000], file=sys.stderr)
            print("\n", file=sys.stderr)
        # Parse the response.
        content = ""
        tool_calls = None
        finish_reason = "stop"  # default
        candidates = gemini_resp.get("candidates", [])
        if candidates:
            # Only the first candidate is used.
            candidate = candidates[0]
            # Map Gemini finishReason -> OpenAI-style finish_reason.
            gemini_finish_reason = candidate.get("finishReason", "STOP")
            if gemini_finish_reason == "STOP":
                finish_reason = "stop"
            elif gemini_finish_reason == "MAX_TOKENS":
                finish_reason = "length"
            elif gemini_finish_reason in ("SAFETY", "RECITATION"):
                finish_reason = "content_filter"
            elif gemini_finish_reason == "MALFORMED_FUNCTION_CALL":
                finish_reason = "stop"  # mapped to "stop"; the error detail goes into content
            else:
                finish_reason = gemini_finish_reason.lower()  # keep unknown values, lowercased
            # Error branch: the model produced a malformed function call.
            if gemini_finish_reason == "MALFORMED_FUNCTION_CALL":
                # Surface finishMessage as the textual content instead.
                finish_message = candidate.get("finishMessage", "")
                print(f"[Gemini HTTP] Warning: MALFORMED_FUNCTION_CALL\n{finish_message}")
                content = f"[模型尝试调用工具但格式错误]\n\n{finish_message}"
            else:
                # Normal parse: concatenate text parts, collect functionCalls.
                parts = candidate.get("content", {}).get("parts", [])
                # Extract text.
                for part in parts:
                    if "text" in part:
                        content += part.get("text", "")
                # Extract functionCall parts into OpenAI tool_calls format.
                # Gemini does not supply call ids, so synthesize them from
                # the part index.
                for i, part in enumerate(parts):
                    if "functionCall" in part:
                        if tool_calls is None:
                            tool_calls = []
                        fc = part["functionCall"]
                        name = fc.get("name", "")
                        args = fc.get("args", {})
                        tool_calls.append({
                            "id": f"call_{i}",
                            "type": "function",
                            "function": {
                                "name": name,
                                "arguments": json.dumps(args, ensure_ascii=False)
                            }
                        })
        # Extract token usage (full metadata) and compute cost.
        usage_meta = gemini_resp.get("usageMetadata", {})
        usage = TokenUsage.from_gemini(usage_meta)
        cost = calculate_cost(model, usage)
        return {
            "content": content,
            "tool_calls": tool_calls,
            # Raw parts kept so a later turn can round-trip them losslessly
            # (see _convert_messages_to_gemini's raw_gemini_parts handling).
            "raw_gemini_parts": candidate.get("content", {}).get("parts", []) if candidates else [],
            "prompt_tokens": usage.input_tokens,
            "completion_tokens": usage.output_tokens,
            "reasoning_tokens": usage.reasoning_tokens,
            "cached_content_tokens": usage.cached_content_tokens,
            "finish_reason": finish_reason,
            "cost": cost,
            "usage": usage,  # the full TokenUsage object
        }
    return gemini_llm_call