# gemini.py
  1. """
  2. Gemini Provider (HTTP API)
  3. 使用 httpx 直接调用 Gemini REST API,避免 google-generativeai SDK 的兼容性问题
  4. 参考:Resonote/llm/providers/gemini.py
  5. """
  6. import os
  7. import json
  8. import httpx
  9. from typing import List, Dict, Any, Optional
  10. def _convert_messages_to_gemini(messages: List[Dict]) -> tuple[List[Dict], Optional[str]]:
  11. """
  12. 将 OpenAI 格式消息转换为 Gemini 格式
  13. Returns:
  14. (gemini_contents, system_instruction)
  15. """
  16. contents = []
  17. system_instruction = None
  18. tool_parts_buffer = []
  19. def flush_tool_buffer():
  20. """合并连续的 tool 消息为单个 user 消息"""
  21. if tool_parts_buffer:
  22. contents.append({
  23. "role": "user",
  24. "parts": tool_parts_buffer.copy()
  25. })
  26. tool_parts_buffer.clear()
  27. for msg in messages:
  28. role = msg.get("role")
  29. # System 消息 -> system_instruction
  30. if role == "system":
  31. system_instruction = msg.get("content", "")
  32. continue
  33. # Tool 消息 -> functionResponse
  34. if role == "tool":
  35. tool_name = msg.get("name")
  36. content_text = msg.get("content", "")
  37. if not tool_name:
  38. print(f"[WARNING] Tool message missing 'name' field, skipping")
  39. continue
  40. # 尝试解析为 JSON
  41. try:
  42. parsed = json.loads(content_text) if content_text else {}
  43. if isinstance(parsed, list):
  44. response_data = {"result": parsed}
  45. else:
  46. response_data = parsed
  47. except (json.JSONDecodeError, ValueError):
  48. response_data = {"result": content_text}
  49. # 添加到 buffer
  50. tool_parts_buffer.append({
  51. "functionResponse": {
  52. "name": tool_name,
  53. "response": response_data
  54. }
  55. })
  56. continue
  57. # 非 tool 消息:先 flush buffer
  58. flush_tool_buffer()
  59. content_text = msg.get("content", "")
  60. tool_calls = msg.get("tool_calls")
  61. # Assistant 消息 + tool_calls
  62. if role == "assistant" and tool_calls:
  63. parts = []
  64. if content_text and content_text.strip():
  65. parts.append({"text": content_text})
  66. # 转换 tool_calls 为 functionCall
  67. for tc in tool_calls:
  68. func = tc.get("function", {})
  69. func_name = func.get("name", "")
  70. func_args_str = func.get("arguments", "{}")
  71. try:
  72. func_args = json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
  73. except json.JSONDecodeError:
  74. func_args = {}
  75. parts.append({
  76. "functionCall": {
  77. "name": func_name,
  78. "args": func_args
  79. }
  80. })
  81. if parts:
  82. contents.append({
  83. "role": "model",
  84. "parts": parts
  85. })
  86. continue
  87. # 跳过空消息
  88. if not content_text or not content_text.strip():
  89. continue
  90. # 普通消息
  91. gemini_role = "model" if role == "assistant" else "user"
  92. contents.append({
  93. "role": gemini_role,
  94. "parts": [{"text": content_text}]
  95. })
  96. # Flush 剩余的 tool messages
  97. flush_tool_buffer()
  98. # 合并连续的 user 消息(Gemini 要求严格交替)
  99. merged_contents = []
  100. i = 0
  101. while i < len(contents):
  102. current = contents[i]
  103. if current["role"] == "user":
  104. merged_parts = current["parts"].copy()
  105. j = i + 1
  106. while j < len(contents) and contents[j]["role"] == "user":
  107. merged_parts.extend(contents[j]["parts"])
  108. j += 1
  109. merged_contents.append({
  110. "role": "user",
  111. "parts": merged_parts
  112. })
  113. i = j
  114. else:
  115. merged_contents.append(current)
  116. i += 1
  117. return merged_contents, system_instruction
  118. def _convert_tools_to_gemini(tools: List[Dict]) -> List[Dict]:
  119. """
  120. 将 OpenAI 工具格式转换为 Gemini REST API 格式
  121. OpenAI: [{"type": "function", "function": {"name": "...", "parameters": {...}}}]
  122. Gemini API: [{"functionDeclarations": [{"name": "...", "parameters": {...}}]}]
  123. """
  124. if not tools:
  125. return []
  126. function_declarations = []
  127. for tool in tools:
  128. if tool.get("type") == "function":
  129. func = tool.get("function", {})
  130. # 清理不支持的字段
  131. parameters = func.get("parameters", {})
  132. if "properties" in parameters:
  133. cleaned_properties = {}
  134. for prop_name, prop_def in parameters["properties"].items():
  135. # 移除 default 字段
  136. cleaned_prop = {k: v for k, v in prop_def.items() if k != "default"}
  137. cleaned_properties[prop_name] = cleaned_prop
  138. # Gemini API 需要完整的 schema
  139. cleaned_parameters = {
  140. "type": "object",
  141. "properties": cleaned_properties
  142. }
  143. if "required" in parameters:
  144. cleaned_parameters["required"] = parameters["required"]
  145. parameters = cleaned_parameters
  146. function_declarations.append({
  147. "name": func.get("name"),
  148. "description": func.get("description", ""),
  149. "parameters": parameters
  150. })
  151. return [{"functionDeclarations": function_declarations}] if function_declarations else []
def create_gemini_llm_call(
    base_url: Optional[str] = None,
    api_key: Optional[str] = None
):
    """
    Create a Gemini LLM call function (HTTP API).

    The returned coroutine talks to the Gemini REST API directly via httpx,
    avoiding the google-generativeai SDK.

    Args:
        base_url: Gemini API base URL (defaults to Google's official endpoint).
        api_key: API key (defaults to the GEMINI_API_KEY environment variable).

    Returns:
        An async function (messages, model, tools, **kwargs) -> dict.

    Raises:
        ValueError: if no API key is given and GEMINI_API_KEY is unset.
    """
    base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
    api_key = api_key or os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY not found")
    # One HTTP client shared by every call made through the returned closure.
    # NOTE(review): the client is never closed here; presumably it lives for
    # the whole process — confirm there is no shutdown path that needs aclose().
    client = httpx.AsyncClient(
        headers={"x-goog-api-key": api_key},
        timeout=httpx.Timeout(120.0, connect=10.0)
    )

    async def gemini_llm_call(
        messages: List[Dict[str, Any]],
        model: str = "gemini-2.5-pro",
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Call the Gemini REST API (models/{model}:generateContent).

        Args:
            messages: OpenAI-format messages.
            model: Model name.
            tools: OpenAI-format tool list.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            {
                "content": str,
                "tool_calls": List[Dict] | None,
                "prompt_tokens": int,
                "completion_tokens": int,
                "cost": float
            }
        """
        # Convert messages from OpenAI format to Gemini "contents"
        contents, system_instruction = _convert_messages_to_gemini(messages)
        print(f"\n[Gemini HTTP] Converted {len(contents)} messages: {[c['role'] for c in contents]}")
        # Build the request
        endpoint = f"{base_url}/models/{model}:generateContent"
        payload = {"contents": contents}
        # Attach the system instruction, if any
        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        # Attach tools, if any
        if tools:
            gemini_tools = _convert_tools_to_gemini(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools
        # Call the API; log and re-raise on failure
        try:
            response = await client.post(endpoint, json=payload)
            response.raise_for_status()
            gemini_resp = response.json()
        except httpx.HTTPStatusError as e:
            error_body = e.response.text
            print(f"[Gemini HTTP] Error {e.response.status_code}: {error_body}")
            raise
        except Exception as e:
            print(f"[Gemini HTTP] Request failed: {e}")
            raise
        # Parse the response
        content = ""
        tool_calls = None
        candidates = gemini_resp.get("candidates", [])
        if candidates:
            # Only the first candidate is inspected
            parts = candidates[0].get("content", {}).get("parts", [])
            # Concatenate all text parts
            for part in parts:
                if "text" in part:
                    content += part.get("text", "")
            # Map functionCall parts back to OpenAI-style tool_calls
            for i, part in enumerate(parts):
                if "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    name = fc.get("name", "")
                    args = fc.get("args", {})
                    tool_calls.append({
                        # Gemini returns no call ids; synthesize one per part index
                        "id": f"call_{i}",
                        "type": "function",
                        "function": {
                            "name": name,
                            "arguments": json.dumps(args, ensure_ascii=False)
                        }
                    })
        # Extract token usage (defaults to 0 when usageMetadata is absent)
        usage_meta = gemini_resp.get("usageMetadata", {})
        prompt_tokens = usage_meta.get("promptTokenCount", 0)
        completion_tokens = usage_meta.get("candidatesTokenCount", 0)
        return {
            "content": content,
            "tool_calls": tool_calls,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "cost": 0.0  # cost accounting not implemented; always 0.0
        }

    return gemini_llm_call