"""
Gemini Provider (HTTP API)

Calls the Gemini REST API directly through httpx, avoiding compatibility
problems with the google-generativeai SDK.

Reference: Resonote/llm/providers/gemini.py
"""
import json
import os
from typing import List, Dict, Any, Optional

import httpx
  10. def _convert_messages_to_gemini(messages: List[Dict]) -> tuple[List[Dict], Optional[str]]:
  11. """
  12. 将 OpenAI 格式消息转换为 Gemini 格式
  13. Returns:
  14. (gemini_contents, system_instruction)
  15. """
  16. contents = []
  17. system_instruction = None
  18. tool_parts_buffer = []
  19. def flush_tool_buffer():
  20. """合并连续的 tool 消息为单个 user 消息"""
  21. if tool_parts_buffer:
  22. contents.append({
  23. "role": "user",
  24. "parts": tool_parts_buffer.copy()
  25. })
  26. tool_parts_buffer.clear()
  27. for msg in messages:
  28. role = msg.get("role")
  29. # System 消息 -> system_instruction
  30. if role == "system":
  31. system_instruction = msg.get("content", "")
  32. continue
  33. # Tool 消息 -> functionResponse
  34. if role == "tool":
  35. tool_name = msg.get("name")
  36. content_text = msg.get("content", "")
  37. if not tool_name:
  38. print(f"[WARNING] Tool message missing 'name' field, skipping")
  39. continue
  40. # 尝试解析为 JSON
  41. try:
  42. parsed = json.loads(content_text) if content_text else {}
  43. if isinstance(parsed, list):
  44. response_data = {"result": parsed}
  45. else:
  46. response_data = parsed
  47. except (json.JSONDecodeError, ValueError):
  48. response_data = {"result": content_text}
  49. # 添加到 buffer
  50. tool_parts_buffer.append({
  51. "functionResponse": {
  52. "name": tool_name,
  53. "response": response_data
  54. }
  55. })
  56. continue
  57. # 非 tool 消息:先 flush buffer
  58. flush_tool_buffer()
  59. content = msg.get("content", "")
  60. tool_calls = msg.get("tool_calls")
  61. # Assistant 消息 + tool_calls
  62. if role == "assistant" and tool_calls:
  63. parts = []
  64. if content and (isinstance(content, str) and content.strip()):
  65. parts.append({"text": content})
  66. # 转换 tool_calls 为 functionCall
  67. for tc in tool_calls:
  68. func = tc.get("function", {})
  69. func_name = func.get("name", "")
  70. func_args_str = func.get("arguments", "{}")
  71. try:
  72. func_args = json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
  73. except json.JSONDecodeError:
  74. func_args = {}
  75. parts.append({
  76. "functionCall": {
  77. "name": func_name,
  78. "args": func_args
  79. }
  80. })
  81. if parts:
  82. contents.append({
  83. "role": "model",
  84. "parts": parts
  85. })
  86. continue
  87. # 处理多模态消息(content 为数组)
  88. if isinstance(content, list):
  89. parts = []
  90. for item in content:
  91. item_type = item.get("type")
  92. # 文本部分
  93. if item_type == "text":
  94. text = item.get("text", "")
  95. if text.strip():
  96. parts.append({"text": text})
  97. # 图片部分(OpenAI format -> Gemini format)
  98. elif item_type == "image_url":
  99. image_url = item.get("image_url", {})
  100. url = image_url.get("url", "")
  101. # 处理 data URL (data:image/png;base64,...)
  102. if url.startswith("data:"):
  103. # 解析 MIME type 和 base64 数据
  104. # 格式:data:image/png;base64,<base64_data>
  105. try:
  106. header, base64_data = url.split(",", 1)
  107. mime_type = header.split(";")[0].replace("data:", "")
  108. parts.append({
  109. "inline_data": {
  110. "mime_type": mime_type,
  111. "data": base64_data
  112. }
  113. })
  114. except Exception as e:
  115. print(f"[WARNING] Failed to parse image data URL: {e}")
  116. if parts:
  117. gemini_role = "model" if role == "assistant" else "user"
  118. contents.append({
  119. "role": gemini_role,
  120. "parts": parts
  121. })
  122. continue
  123. # 普通文本消息(content 为字符串)
  124. if isinstance(content, str):
  125. # 跳过空消息
  126. if not content.strip():
  127. continue
  128. gemini_role = "model" if role == "assistant" else "user"
  129. contents.append({
  130. "role": gemini_role,
  131. "parts": [{"text": content}]
  132. })
  133. continue
  134. # Flush 剩余的 tool messages
  135. flush_tool_buffer()
  136. # 合并连续的 user 消息(Gemini 要求严格交替)
  137. merged_contents = []
  138. i = 0
  139. while i < len(contents):
  140. current = contents[i]
  141. if current["role"] == "user":
  142. merged_parts = current["parts"].copy()
  143. j = i + 1
  144. while j < len(contents) and contents[j]["role"] == "user":
  145. merged_parts.extend(contents[j]["parts"])
  146. j += 1
  147. merged_contents.append({
  148. "role": "user",
  149. "parts": merged_parts
  150. })
  151. i = j
  152. else:
  153. merged_contents.append(current)
  154. i += 1
  155. return merged_contents, system_instruction
  156. def _convert_tools_to_gemini(tools: List[Dict]) -> List[Dict]:
  157. """
  158. 将 OpenAI 工具格式转换为 Gemini REST API 格式
  159. OpenAI: [{"type": "function", "function": {"name": "...", "parameters": {...}}}]
  160. Gemini API: [{"functionDeclarations": [{"name": "...", "parameters": {...}}]}]
  161. """
  162. if not tools:
  163. return []
  164. function_declarations = []
  165. for tool in tools:
  166. if tool.get("type") == "function":
  167. func = tool.get("function", {})
  168. # 清理不支持的字段
  169. parameters = func.get("parameters", {})
  170. if "properties" in parameters:
  171. cleaned_properties = {}
  172. for prop_name, prop_def in parameters["properties"].items():
  173. # 移除 default 字段
  174. cleaned_prop = {k: v for k, v in prop_def.items() if k != "default"}
  175. cleaned_properties[prop_name] = cleaned_prop
  176. # Gemini API 需要完整的 schema
  177. cleaned_parameters = {
  178. "type": "object",
  179. "properties": cleaned_properties
  180. }
  181. if "required" in parameters:
  182. cleaned_parameters["required"] = parameters["required"]
  183. parameters = cleaned_parameters
  184. function_declarations.append({
  185. "name": func.get("name"),
  186. "description": func.get("description", ""),
  187. "parameters": parameters
  188. })
  189. return [{"functionDeclarations": function_declarations}] if function_declarations else []
  190. def create_gemini_llm_call(
  191. base_url: Optional[str] = None,
  192. api_key: Optional[str] = None
  193. ):
  194. """
  195. 创建 Gemini LLM 调用函数(HTTP API)
  196. Args:
  197. base_url: Gemini API base URL(默认使用 Google 官方)
  198. api_key: API key(默认从环境变量读取)
  199. Returns:
  200. async 函数
  201. """
  202. base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
  203. api_key = api_key or os.getenv("GEMINI_API_KEY")
  204. if not api_key:
  205. raise ValueError("GEMINI_API_KEY not found")
  206. # 创建 HTTP 客户端
  207. client = httpx.AsyncClient(
  208. headers={"x-goog-api-key": api_key},
  209. timeout=httpx.Timeout(120.0, connect=10.0)
  210. )
  211. async def gemini_llm_call(
  212. messages: List[Dict[str, Any]],
  213. model: str = "gemini-2.5-pro",
  214. tools: Optional[List[Dict]] = None,
  215. **kwargs
  216. ) -> Dict[str, Any]:
  217. """
  218. 调用 Gemini REST API
  219. Args:
  220. messages: OpenAI 格式消息
  221. model: 模型名称
  222. tools: OpenAI 格式工具列表
  223. **kwargs: 其他参数
  224. Returns:
  225. {
  226. "content": str,
  227. "tool_calls": List[Dict] | None,
  228. "prompt_tokens": int,
  229. "completion_tokens": int,
  230. "cost": float
  231. }
  232. """
  233. # 转换消息
  234. contents, system_instruction = _convert_messages_to_gemini(messages)
  235. print(f"\n[Gemini HTTP] Converted {len(contents)} messages: {[c['role'] for c in contents]}")
  236. # 构建请求
  237. endpoint = f"{base_url}/models/{model}:generateContent"
  238. payload = {"contents": contents}
  239. # 添加 system instruction
  240. if system_instruction:
  241. payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
  242. # 添加工具
  243. if tools:
  244. gemini_tools = _convert_tools_to_gemini(tools)
  245. if gemini_tools:
  246. payload["tools"] = gemini_tools
  247. # 调用 API
  248. try:
  249. response = await client.post(endpoint, json=payload)
  250. response.raise_for_status()
  251. gemini_resp = response.json()
  252. except httpx.HTTPStatusError as e:
  253. error_body = e.response.text
  254. print(f"[Gemini HTTP] Error {e.response.status_code}: {error_body}")
  255. raise
  256. except Exception as e:
  257. print(f"[Gemini HTTP] Request failed: {e}")
  258. raise
  259. # 解析响应
  260. content = ""
  261. tool_calls = None
  262. candidates = gemini_resp.get("candidates", [])
  263. if candidates:
  264. parts = candidates[0].get("content", {}).get("parts", [])
  265. # 提取文本
  266. for part in parts:
  267. if "text" in part:
  268. content += part.get("text", "")
  269. # 提取 functionCall
  270. for i, part in enumerate(parts):
  271. if "functionCall" in part:
  272. if tool_calls is None:
  273. tool_calls = []
  274. fc = part["functionCall"]
  275. name = fc.get("name", "")
  276. args = fc.get("args", {})
  277. tool_calls.append({
  278. "id": f"call_{i}",
  279. "type": "function",
  280. "function": {
  281. "name": name,
  282. "arguments": json.dumps(args, ensure_ascii=False)
  283. }
  284. })
  285. # 提取 usage
  286. usage_meta = gemini_resp.get("usageMetadata", {})
  287. prompt_tokens = usage_meta.get("promptTokenCount", 0)
  288. completion_tokens = usage_meta.get("candidatesTokenCount", 0)
  289. return {
  290. "content": content,
  291. "tool_calls": tool_calls,
  292. "prompt_tokens": prompt_tokens,
  293. "completion_tokens": completion_tokens,
  294. "cost": 0.0
  295. }
  296. return gemini_llm_call