|
|
@@ -0,0 +1,317 @@
|
|
|
+"""
|
|
|
+Gemini Provider (HTTP API)
|
|
|
+
|
|
|
+使用 httpx 直接调用 Gemini REST API,避免 google-generativeai SDK 的兼容性问题
|
|
|
+
|
|
|
+参考:Resonote/llm/providers/gemini.py
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import json
|
|
|
+import httpx
|
|
|
+from typing import List, Dict, Any, Optional
|
|
|
+
|
|
|
+
|
|
|
+def _convert_messages_to_gemini(messages: List[Dict]) -> tuple[List[Dict], Optional[str]]:
|
|
|
+ """
|
|
|
+ 将 OpenAI 格式消息转换为 Gemini 格式
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (gemini_contents, system_instruction)
|
|
|
+ """
|
|
|
+ contents = []
|
|
|
+ system_instruction = None
|
|
|
+ tool_parts_buffer = []
|
|
|
+
|
|
|
+ def flush_tool_buffer():
|
|
|
+ """合并连续的 tool 消息为单个 user 消息"""
|
|
|
+ if tool_parts_buffer:
|
|
|
+ contents.append({
|
|
|
+ "role": "user",
|
|
|
+ "parts": tool_parts_buffer.copy()
|
|
|
+ })
|
|
|
+ tool_parts_buffer.clear()
|
|
|
+
|
|
|
+ for msg in messages:
|
|
|
+ role = msg.get("role")
|
|
|
+
|
|
|
+ # System 消息 -> system_instruction
|
|
|
+ if role == "system":
|
|
|
+ system_instruction = msg.get("content", "")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # Tool 消息 -> functionResponse
|
|
|
+ if role == "tool":
|
|
|
+ tool_name = msg.get("name")
|
|
|
+ content_text = msg.get("content", "")
|
|
|
+
|
|
|
+ if not tool_name:
|
|
|
+ print(f"[WARNING] Tool message missing 'name' field, skipping")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 尝试解析为 JSON
|
|
|
+ try:
|
|
|
+ parsed = json.loads(content_text) if content_text else {}
|
|
|
+ if isinstance(parsed, list):
|
|
|
+ response_data = {"result": parsed}
|
|
|
+ else:
|
|
|
+ response_data = parsed
|
|
|
+ except (json.JSONDecodeError, ValueError):
|
|
|
+ response_data = {"result": content_text}
|
|
|
+
|
|
|
+ # 添加到 buffer
|
|
|
+ tool_parts_buffer.append({
|
|
|
+ "functionResponse": {
|
|
|
+ "name": tool_name,
|
|
|
+ "response": response_data
|
|
|
+ }
|
|
|
+ })
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 非 tool 消息:先 flush buffer
|
|
|
+ flush_tool_buffer()
|
|
|
+
|
|
|
+ content_text = msg.get("content", "")
|
|
|
+ tool_calls = msg.get("tool_calls")
|
|
|
+
|
|
|
+ # Assistant 消息 + tool_calls
|
|
|
+ if role == "assistant" and tool_calls:
|
|
|
+ parts = []
|
|
|
+ if content_text and content_text.strip():
|
|
|
+ parts.append({"text": content_text})
|
|
|
+
|
|
|
+ # 转换 tool_calls 为 functionCall
|
|
|
+ for tc in tool_calls:
|
|
|
+ func = tc.get("function", {})
|
|
|
+ func_name = func.get("name", "")
|
|
|
+ func_args_str = func.get("arguments", "{}")
|
|
|
+ try:
|
|
|
+ func_args = json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ func_args = {}
|
|
|
+
|
|
|
+ parts.append({
|
|
|
+ "functionCall": {
|
|
|
+ "name": func_name,
|
|
|
+ "args": func_args
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ if parts:
|
|
|
+ contents.append({
|
|
|
+ "role": "model",
|
|
|
+ "parts": parts
|
|
|
+ })
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 跳过空消息
|
|
|
+ if not content_text or not content_text.strip():
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 普通消息
|
|
|
+ gemini_role = "model" if role == "assistant" else "user"
|
|
|
+ contents.append({
|
|
|
+ "role": gemini_role,
|
|
|
+ "parts": [{"text": content_text}]
|
|
|
+ })
|
|
|
+
|
|
|
+ # Flush 剩余的 tool messages
|
|
|
+ flush_tool_buffer()
|
|
|
+
|
|
|
+ # 合并连续的 user 消息(Gemini 要求严格交替)
|
|
|
+ merged_contents = []
|
|
|
+ i = 0
|
|
|
+ while i < len(contents):
|
|
|
+ current = contents[i]
|
|
|
+
|
|
|
+ if current["role"] == "user":
|
|
|
+ merged_parts = current["parts"].copy()
|
|
|
+ j = i + 1
|
|
|
+ while j < len(contents) and contents[j]["role"] == "user":
|
|
|
+ merged_parts.extend(contents[j]["parts"])
|
|
|
+ j += 1
|
|
|
+
|
|
|
+ merged_contents.append({
|
|
|
+ "role": "user",
|
|
|
+ "parts": merged_parts
|
|
|
+ })
|
|
|
+ i = j
|
|
|
+ else:
|
|
|
+ merged_contents.append(current)
|
|
|
+ i += 1
|
|
|
+
|
|
|
+ return merged_contents, system_instruction
|
|
|
+
|
|
|
+
|
|
|
+def _convert_tools_to_gemini(tools: List[Dict]) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 将 OpenAI 工具格式转换为 Gemini REST API 格式
|
|
|
+
|
|
|
+ OpenAI: [{"type": "function", "function": {"name": "...", "parameters": {...}}}]
|
|
|
+ Gemini API: [{"functionDeclarations": [{"name": "...", "parameters": {...}}]}]
|
|
|
+ """
|
|
|
+ if not tools:
|
|
|
+ return []
|
|
|
+
|
|
|
+ function_declarations = []
|
|
|
+ for tool in tools:
|
|
|
+ if tool.get("type") == "function":
|
|
|
+ func = tool.get("function", {})
|
|
|
+
|
|
|
+ # 清理不支持的字段
|
|
|
+ parameters = func.get("parameters", {})
|
|
|
+ if "properties" in parameters:
|
|
|
+ cleaned_properties = {}
|
|
|
+ for prop_name, prop_def in parameters["properties"].items():
|
|
|
+ # 移除 default 字段
|
|
|
+ cleaned_prop = {k: v for k, v in prop_def.items() if k != "default"}
|
|
|
+ cleaned_properties[prop_name] = cleaned_prop
|
|
|
+
|
|
|
+ # Gemini API 需要完整的 schema
|
|
|
+ cleaned_parameters = {
|
|
|
+ "type": "object",
|
|
|
+ "properties": cleaned_properties
|
|
|
+ }
|
|
|
+ if "required" in parameters:
|
|
|
+ cleaned_parameters["required"] = parameters["required"]
|
|
|
+
|
|
|
+ parameters = cleaned_parameters
|
|
|
+
|
|
|
+ function_declarations.append({
|
|
|
+ "name": func.get("name"),
|
|
|
+ "description": func.get("description", ""),
|
|
|
+ "parameters": parameters
|
|
|
+ })
|
|
|
+
|
|
|
+ return [{"functionDeclarations": function_declarations}] if function_declarations else []
|
|
|
+
|
|
|
+
|
|
|
def create_gemini_llm_call(
    base_url: Optional[str] = None,
    api_key: Optional[str] = None
):
    """
    Create a Gemini LLM call function (HTTP API).

    Args:
        base_url: Gemini API base URL (defaults to the official Google endpoint).
        api_key: API key (defaults to the GEMINI_API_KEY environment variable).

    Returns:
        An async function with an OpenAI-compatible messages/tools interface.

    Raises:
        ValueError: if no API key is available.
    """
    base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
    api_key = api_key or os.getenv("GEMINI_API_KEY")

    if not api_key:
        raise ValueError("GEMINI_API_KEY not found")

    # Shared HTTP client, reused across calls for connection pooling.
    # NOTE(review): the client is never closed explicitly — fine for a
    # long-lived process, but creating many factories would leak
    # connections; consider exposing an aclose hook if that matters.
    client = httpx.AsyncClient(
        headers={"x-goog-api-key": api_key},
        timeout=httpx.Timeout(120.0, connect=10.0)
    )

    async def gemini_llm_call(
        messages: List[Dict[str, Any]],
        model: str = "gemini-2.5-pro",
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Call the Gemini REST generateContent API.

        Args:
            messages: OpenAI-format messages.
            model: Model name.
            tools: OpenAI-format tool list.
            **kwargs: Optional generation parameters. Recognized keys
                (temperature, top_p, max_tokens) are forwarded via
                generationConfig; unrecognized keys are ignored.

        Returns:
            {
                "content": str,
                "tool_calls": List[Dict] | None,
                "prompt_tokens": int,
                "completion_tokens": int,
                "cost": float
            }
        """
        # Convert messages to Gemini format
        contents, system_instruction = _convert_messages_to_gemini(messages)

        print(f"\n[Gemini HTTP] Converted {len(contents)} messages: {[c['role'] for c in contents]}")

        # Build the request payload
        endpoint = f"{base_url}/models/{model}:generateContent"
        payload = {"contents": contents}

        # System instruction
        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}

        # Tools
        if tools:
            gemini_tools = _convert_tools_to_gemini(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools

        # Forward recognized sampling parameters. Previously **kwargs was
        # accepted but silently dropped, so temperature/max_tokens settings
        # never reached the API.
        generation_config = {}
        if kwargs.get("temperature") is not None:
            generation_config["temperature"] = kwargs["temperature"]
        if kwargs.get("top_p") is not None:
            generation_config["topP"] = kwargs["top_p"]
        if kwargs.get("max_tokens") is not None:
            generation_config["maxOutputTokens"] = kwargs["max_tokens"]
        if generation_config:
            payload["generationConfig"] = generation_config

        # Call the API
        try:
            response = await client.post(endpoint, json=payload)
            response.raise_for_status()
            gemini_resp = response.json()
        except httpx.HTTPStatusError as e:
            # Surface the API error body before re-raising for the caller
            print(f"[Gemini HTTP] Error {e.response.status_code}: {e.response.text}")
            raise
        except Exception as e:
            print(f"[Gemini HTTP] Request failed: {e}")
            raise

        # Parse the response: concatenate text parts and collect
        # functionCall parts (ids are synthesized from the part index).
        content = ""
        tool_calls = None

        candidates = gemini_resp.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for i, part in enumerate(parts):
                if "text" in part:
                    content += part.get("text", "")
                if "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append({
                        "id": f"call_{i}",
                        "type": "function",
                        "function": {
                            "name": fc.get("name", ""),
                            "arguments": json.dumps(fc.get("args", {}), ensure_ascii=False)
                        }
                    })

        # Token usage (Gemini returns no pricing info, so cost stays 0.0)
        usage_meta = gemini_resp.get("usageMetadata", {})
        return {
            "content": content,
            "tool_calls": tool_calls,
            "prompt_tokens": usage_meta.get("promptTokenCount", 0),
            "completion_tokens": usage_meta.get("candidatesTokenCount", 0),
            "cost": 0.0
        }

    return gemini_llm_call
|