gemini.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. """
  2. Gemini Provider (HTTP API)
  3. 使用 httpx 直接调用 Gemini REST API,避免 google-generativeai SDK 的兼容性问题
  4. 参考:Resonote/llm/providers/gemini.py
  5. """
  6. import os
  7. import json
  8. import sys
  9. import httpx
  10. from typing import List, Dict, Any, Optional
  11. from .usage import TokenUsage
  12. from .pricing import calculate_cost
  13. def _dump_llm_request(endpoint: str, payload: Dict[str, Any], model: str):
  14. """
  15. Dump完整的LLM请求用于调试(需要设置 AGENT_DEBUG=1)
  16. 特别处理:
  17. - 图片base64数据:只显示前50字符 + 长度信息
  18. - Tools schema:完整显示
  19. - 输出到stderr,避免污染正常输出
  20. """
  21. if not os.getenv("AGENT_DEBUG"):
  22. return
  23. def truncate_images(obj):
  24. """递归处理对象,truncate图片base64数据"""
  25. if isinstance(obj, dict):
  26. result = {}
  27. for key, value in obj.items():
  28. # 处理 inline_data 中的 base64 图片
  29. if key == "inline_data" and isinstance(value, dict):
  30. mime_type = value.get("mime_type", "unknown")
  31. data = value.get("data", "")
  32. data_size_kb = len(data) / 1024 if data else 0
  33. result[key] = {
  34. "mime_type": mime_type,
  35. "data": f"<BASE64_IMAGE: {data_size_kb:.1f}KB, preview: {data[:50]}...>"
  36. }
  37. else:
  38. result[key] = truncate_images(value)
  39. return result
  40. elif isinstance(obj, list):
  41. return [truncate_images(item) for item in obj]
  42. else:
  43. return obj
  44. # 构造完整的调试信息
  45. debug_info = {
  46. "endpoint": endpoint,
  47. "model": model,
  48. "payload": truncate_images(payload)
  49. }
  50. # 输出到stderr
  51. print("\n" + "="*80, file=sys.stderr)
  52. print("[AGENT_DEBUG] LLM Request Dump", file=sys.stderr)
  53. print("="*80, file=sys.stderr)
  54. print(json.dumps(debug_info, indent=2, ensure_ascii=False), file=sys.stderr)
  55. print("="*80 + "\n", file=sys.stderr)
  56. def _convert_messages_to_gemini(messages: List[Dict]) -> tuple[List[Dict], Optional[str]]:
  57. """
  58. 将 OpenAI 格式消息转换为 Gemini 格式
  59. Returns:
  60. (gemini_contents, system_instruction)
  61. """
  62. contents = []
  63. system_instruction = None
  64. tool_parts_buffer = []
  65. def flush_tool_buffer():
  66. """合并连续的 tool 消息为单个 user 消息"""
  67. if tool_parts_buffer:
  68. contents.append({
  69. "role": "user",
  70. "parts": tool_parts_buffer.copy()
  71. })
  72. tool_parts_buffer.clear()
  73. for msg in messages:
  74. role = msg.get("role")
  75. # System 消息 -> system_instruction
  76. if role == "system":
  77. system_instruction = msg.get("content", "")
  78. continue
  79. # Tool 消息 -> functionResponse
  80. if role == "tool":
  81. tool_name = msg.get("name")
  82. content_text = msg.get("content", "")
  83. if not tool_name:
  84. print(f"[WARNING] Tool message missing 'name' field, skipping")
  85. continue
  86. # 尝试解析为 JSON
  87. try:
  88. parsed = json.loads(content_text) if content_text else {}
  89. if isinstance(parsed, list):
  90. response_data = {"result": parsed}
  91. else:
  92. response_data = parsed
  93. except (json.JSONDecodeError, ValueError):
  94. response_data = {"result": content_text}
  95. # 添加到 buffer
  96. tool_parts_buffer.append({
  97. "functionResponse": {
  98. "name": tool_name,
  99. "response": response_data
  100. }
  101. })
  102. continue
  103. # 非 tool 消息:先 flush buffer
  104. flush_tool_buffer()
  105. content = msg.get("content", "")
  106. tool_calls = msg.get("tool_calls")
  107. # Assistant 消息 + tool_calls
  108. if role == "assistant" and tool_calls:
  109. parts = []
  110. if content and (isinstance(content, str) and content.strip()):
  111. parts.append({"text": content})
  112. # 转换 tool_calls 为 functionCall
  113. for tc in tool_calls:
  114. func = tc.get("function", {})
  115. func_name = func.get("name", "")
  116. func_args_str = func.get("arguments", "{}")
  117. try:
  118. func_args = json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
  119. except json.JSONDecodeError:
  120. func_args = {}
  121. parts.append({
  122. "functionCall": {
  123. "name": func_name,
  124. "args": func_args
  125. }
  126. })
  127. if parts:
  128. contents.append({
  129. "role": "model",
  130. "parts": parts
  131. })
  132. continue
  133. # 处理多模态消息(content 为数组)
  134. if isinstance(content, list):
  135. parts = []
  136. for item in content:
  137. item_type = item.get("type")
  138. # 文本部分
  139. if item_type == "text":
  140. text = item.get("text", "")
  141. if text.strip():
  142. parts.append({"text": text})
  143. # 图片部分(OpenAI format -> Gemini format)
  144. elif item_type == "image_url":
  145. image_url = item.get("image_url", {})
  146. url = image_url.get("url", "")
  147. # 处理 data URL (data:image/png;base64,...)
  148. if url.startswith("data:"):
  149. # 解析 MIME type 和 base64 数据
  150. # 格式:data:image/png;base64,<base64_data>
  151. try:
  152. header, base64_data = url.split(",", 1)
  153. mime_type = header.split(";")[0].replace("data:", "")
  154. parts.append({
  155. "inline_data": {
  156. "mime_type": mime_type,
  157. "data": base64_data
  158. }
  159. })
  160. except Exception as e:
  161. print(f"[WARNING] Failed to parse image data URL: {e}")
  162. if parts:
  163. gemini_role = "model" if role == "assistant" else "user"
  164. contents.append({
  165. "role": gemini_role,
  166. "parts": parts
  167. })
  168. continue
  169. # 普通文本消息(content 为字符串)
  170. if isinstance(content, str):
  171. # 跳过空消息
  172. if not content.strip():
  173. continue
  174. gemini_role = "model" if role == "assistant" else "user"
  175. contents.append({
  176. "role": gemini_role,
  177. "parts": [{"text": content}]
  178. })
  179. continue
  180. # Flush 剩余的 tool messages
  181. flush_tool_buffer()
  182. # 合并连续的 user 消息(Gemini 要求严格交替)
  183. merged_contents = []
  184. i = 0
  185. while i < len(contents):
  186. current = contents[i]
  187. if current["role"] == "user":
  188. merged_parts = current["parts"].copy()
  189. j = i + 1
  190. while j < len(contents) and contents[j]["role"] == "user":
  191. merged_parts.extend(contents[j]["parts"])
  192. j += 1
  193. merged_contents.append({
  194. "role": "user",
  195. "parts": merged_parts
  196. })
  197. i = j
  198. else:
  199. merged_contents.append(current)
  200. i += 1
  201. return merged_contents, system_instruction
  202. def _convert_tools_to_gemini(tools: List[Dict]) -> List[Dict]:
  203. """
  204. 将 OpenAI 工具格式转换为 Gemini REST API 格式
  205. OpenAI: [{"type": "function", "function": {"name": "...", "parameters": {...}}}]
  206. Gemini API: [{"functionDeclarations": [{"name": "...", "parameters": {...}}]}]
  207. """
  208. if not tools:
  209. return []
  210. function_declarations = []
  211. for tool in tools:
  212. if tool.get("type") == "function":
  213. func = tool.get("function", {})
  214. # 清理不支持的字段
  215. parameters = func.get("parameters", {})
  216. if "properties" in parameters:
  217. cleaned_properties = {}
  218. for prop_name, prop_def in parameters["properties"].items():
  219. # 移除 default 字段
  220. cleaned_prop = {k: v for k, v in prop_def.items() if k != "default"}
  221. cleaned_properties[prop_name] = cleaned_prop
  222. # Gemini API 需要完整的 schema
  223. cleaned_parameters = {
  224. "type": "object",
  225. "properties": cleaned_properties
  226. }
  227. if "required" in parameters:
  228. cleaned_parameters["required"] = parameters["required"]
  229. parameters = cleaned_parameters
  230. function_declarations.append({
  231. "name": func.get("name"),
  232. "description": func.get("description", ""),
  233. "parameters": parameters
  234. })
  235. return [{"functionDeclarations": function_declarations}] if function_declarations else []
def create_gemini_llm_call(
    base_url: Optional[str] = None,
    api_key: Optional[str] = None
):
    """
    Create a Gemini LLM call function (HTTP API).

    Args:
        base_url: Gemini API base URL (defaults to Google's official endpoint)
        api_key: API key (defaults to the GEMINI_API_KEY environment variable)

    Returns:
        An async function that performs the actual generateContent call.

    Raises:
        ValueError: if no API key is available.

    NOTE(review): the httpx.AsyncClient created here is never closed; a
    clean-shutdown hook (aclose) may be needed — confirm with callers.
    """
    base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
    api_key = api_key or os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY not found")
    # Shared HTTP client: auth via header; 120s overall / 10s connect timeout.
    client = httpx.AsyncClient(
        headers={"x-goog-api-key": api_key},
        timeout=httpx.Timeout(120.0, connect=10.0)
    )
    async def gemini_llm_call(
        messages: List[Dict[str, Any]],
        model: str = "gemini-2.5-pro",
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Call the Gemini REST API (generateContent).

        Args:
            messages: OpenAI-format messages
            model: model name
            tools: OpenAI-format tool list
            **kwargs: extra parameters (currently unused)

        Returns:
            {
                "content": str,
                "tool_calls": List[Dict] | None,
                "prompt_tokens": int,
                "completion_tokens": int,
                "finish_reason": str,
                "cost": float
            }
            plus "reasoning_tokens", "cached_content_tokens" and the full
            TokenUsage object under "usage".
        """
        # Convert messages to Gemini format
        contents, system_instruction = _convert_messages_to_gemini(messages)
        print(f"\n[Gemini HTTP] Converted {len(contents)} messages: {[c['role'] for c in contents]}")
        # Build the request
        endpoint = f"{base_url}/models/{model}:generateContent"
        payload = {"contents": contents}
        # Attach the system instruction, if any
        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        # Attach tools, if any
        if tools:
            gemini_tools = _convert_tools_to_gemini(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools
        # Debug: dump the full request (requires AGENT_DEBUG=1)
        _dump_llm_request(endpoint, payload, model)
        # Call the API; HTTP errors are logged and re-raised for the caller
        try:
            response = await client.post(endpoint, json=payload)
            response.raise_for_status()
            gemini_resp = response.json()
        except httpx.HTTPStatusError as e:
            error_body = e.response.text
            print(f"[Gemini HTTP] Error {e.response.status_code}: {error_body}")
            raise
        except Exception as e:
            print(f"[Gemini HTTP] Request failed: {e}")
            raise
        # Debug: print the raw response (truncated to 2000 chars), if enabled
        if os.getenv("AGENT_DEBUG"):
            print("\n[AGENT_DEBUG] Gemini Response:", file=sys.stderr)
            print(json.dumps(gemini_resp, ensure_ascii=False, indent=2)[:2000], file=sys.stderr)
            print("\n", file=sys.stderr)
        # Parse the response
        content = ""
        tool_calls = None
        finish_reason = "stop"  # default when no candidates come back
        candidates = gemini_resp.get("candidates", [])
        if candidates:
            candidate = candidates[0]
            # Map finish reason (Gemini -> OpenAI vocabulary)
            gemini_finish_reason = candidate.get("finishReason", "STOP")
            if gemini_finish_reason == "STOP":
                finish_reason = "stop"
            elif gemini_finish_reason == "MAX_TOKENS":
                finish_reason = "length"
            elif gemini_finish_reason in ("SAFETY", "RECITATION"):
                finish_reason = "content_filter"
            elif gemini_finish_reason == "MALFORMED_FUNCTION_CALL":
                finish_reason = "stop"  # mapped to stop; error details go into content
            else:
                finish_reason = gemini_finish_reason.lower()  # keep unknown values, lowercased
            # Error path: Gemini produced a malformed function call
            if gemini_finish_reason == "MALFORMED_FUNCTION_CALL":
                # Surface finishMessage as the textual content so callers see why
                finish_message = candidate.get("finishMessage", "")
                print(f"[Gemini HTTP] Warning: MALFORMED_FUNCTION_CALL\n{finish_message}")
                content = f"[模型尝试调用工具但格式错误]\n\n{finish_message}"
            else:
                # Normal parsing
                parts = candidate.get("content", {}).get("parts", [])
                # Concatenate all text parts
                for part in parts:
                    if "text" in part:
                        content += part.get("text", "")
                # Collect functionCall parts as OpenAI-style tool_calls;
                # ids are synthesized from the part index (Gemini has none)
                for i, part in enumerate(parts):
                    if "functionCall" in part:
                        if tool_calls is None:
                            tool_calls = []
                        fc = part["functionCall"]
                        name = fc.get("name", "")
                        args = fc.get("args", {})
                        tool_calls.append({
                            "id": f"call_{i}",
                            "type": "function",
                            "function": {
                                "name": name,
                                "arguments": json.dumps(args, ensure_ascii=False)
                            }
                        })
        # Extract token usage (full version)
        usage_meta = gemini_resp.get("usageMetadata", {})
        usage = TokenUsage.from_gemini(usage_meta)
        # Compute cost from model + usage
        cost = calculate_cost(model, usage)
        return {
            "content": content,
            "tool_calls": tool_calls,
            "prompt_tokens": usage.input_tokens,
            "completion_tokens": usage.output_tokens,
            "reasoning_tokens": usage.reasoning_tokens,
            "cached_content_tokens": usage.cached_content_tokens,
            "finish_reason": finish_reason,
            "cost": cost,
            "usage": usage,  # full TokenUsage object
        }
    return gemini_llm_call