| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317 |
- """
- Gemini Provider (HTTP API)
- 使用 httpx 直接调用 Gemini REST API,避免 google-generativeai SDK 的兼容性问题
- 参考:Resonote/llm/providers/gemini.py
- """
- import os
- import json
- import httpx
- from typing import List, Dict, Any, Optional
def _content_to_text(content: Any) -> str:
    """Return the plain text carried by an OpenAI message ``content`` value.

    ``None`` becomes ``""`` and a string is returned unchanged. An OpenAI-style
    list of content parts is reduced to the concatenation of its text parts
    (bare strings or ``{"type": "text", "text": ...}`` dicts); every other part
    (e.g. images) is dropped. Any other value is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces: List[str] = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(part.get("text") or "")
        return "".join(pieces)
    return str(content)


def _tool_response_payload(raw_content: Any) -> Dict[str, Any]:
    """Build the JSON object placed in ``functionResponse.response``.

    Gemini expects an object here, so a dict is used as-is, text that parses to
    a JSON object is used as that object, and anything else (JSON arrays,
    scalars, ``null``, or unparseable text) is wrapped as ``{"result": value}``.
    Empty content becomes ``{}``.
    """
    if isinstance(raw_content, dict):
        return raw_content
    text = _content_to_text(raw_content)
    if not text:
        return {}
    try:
        parsed = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        return {"result": text}
    return parsed if isinstance(parsed, dict) else {"result": parsed}


def _tool_call_arguments(raw_arguments: Any) -> Dict[str, Any]:
    """Decode OpenAI ``function.arguments`` into the object used as ``functionCall.args``.

    A dict is returned as-is; a string is JSON-decoded. Missing, malformed or
    non-object arguments all yield ``{}``.
    """
    if isinstance(raw_arguments, dict):
        return raw_arguments
    if isinstance(raw_arguments, str):
        try:
            parsed = json.loads(raw_arguments)
        except ValueError:
            return {}
        if isinstance(parsed, dict):
            return parsed
    return {}


def _convert_messages_to_gemini(messages: List[Dict]) -> tuple[List[Dict], Optional[str]]:
    """Convert OpenAI chat-format messages into Gemini ``contents``.

    Mapping rules:
      * ``system`` and ``developer`` messages are removed from the turn list.
        Their non-empty texts are joined with blank lines into the returned
        system instruction.
      * ``tool`` messages become ``functionResponse`` parts. The function name
        comes from ``name`` or, failing that, from ``tool_call_id`` matched
        against the ids of earlier assistant ``tool_calls``. Messages whose name
        cannot be determined are skipped with a printed warning. Consecutive
        tool messages are grouped into one ``user`` turn.
      * ``assistant`` messages with ``tool_calls`` become a ``model`` turn made
        of an optional text part followed by one ``functionCall`` part per call.
      * Other messages with non-blank text become ``model`` turns (assistant)
        or ``user`` turns (everything else). Blank messages are dropped.
    Finally, adjacent turns with the same role are merged so that user and
    model turns alternate.

    Args:
        messages: OpenAI-format message dicts.

    Returns:
        ``(gemini_contents, system_instruction)``. ``system_instruction`` is
        ``None`` when no non-empty system/developer message was present.
    """
    contents: List[Dict] = []
    system_texts: List[str] = []
    tool_parts_buffer: List[Dict] = []
    # tool_call_id -> function name, filled from assistant tool_calls so that
    # tool messages lacking a "name" field can still be attributed.
    call_names: Dict[str, Any] = {}

    def flush_tool_buffer() -> None:
        """Emit buffered functionResponse parts as a single user turn."""
        if tool_parts_buffer:
            contents.append({"role": "user", "parts": tool_parts_buffer.copy()})
            tool_parts_buffer.clear()

    for msg in messages:
        role = msg.get("role")

        # System / developer instructions -> system_instruction
        if role in ("system", "developer"):
            text = _content_to_text(msg.get("content"))
            if text:
                system_texts.append(text)
            continue

        # Tool results -> functionResponse parts (buffered until the next non-tool message)
        if role == "tool":
            tool_name = msg.get("name") or call_names.get(msg.get("tool_call_id"))
            if not tool_name:
                print("[WARNING] Tool message has neither 'name' nor a known 'tool_call_id', skipping")
                continue
            tool_parts_buffer.append({
                "functionResponse": {
                    "name": tool_name,
                    "response": _tool_response_payload(msg.get("content")),
                }
            })
            continue

        # Any non-tool message ends a run of tool results.
        flush_tool_buffer()

        content_text = _content_to_text(msg.get("content"))
        tool_calls = msg.get("tool_calls")

        # Assistant turn that requested tools -> text (optional) + functionCall parts
        if role == "assistant" and tool_calls:
            parts: List[Dict] = []
            if content_text.strip():
                parts.append({"text": content_text})
            for tc in tool_calls:
                func = tc.get("function") or {}
                func_name = func.get("name", "")
                call_id = tc.get("id")
                if call_id:
                    call_names[call_id] = func_name
                parts.append({
                    "functionCall": {
                        "name": func_name,
                        "args": _tool_call_arguments(func.get("arguments")),
                    }
                })
            contents.append({"role": "model", "parts": parts})
            continue

        # Plain text turn; blank messages are skipped.
        if not content_text.strip():
            continue
        contents.append({
            "role": "model" if role == "assistant" else "user",
            "parts": [{"text": content_text}],
        })

    # Trailing tool results (conversation ends with tool messages).
    flush_tool_buffer()

    # Merge adjacent same-role turns (e.g. tool results followed by a user
    # message, or two assistant messages in a row).
    merged_contents: List[Dict] = []
    for turn in contents:
        if merged_contents and merged_contents[-1]["role"] == turn["role"]:
            merged_contents[-1]["parts"].extend(turn["parts"])
        else:
            merged_contents.append({"role": turn["role"], "parts": list(turn["parts"])})

    system_instruction = "\n\n".join(system_texts) if system_texts else None
    return merged_contents, system_instruction
def _clean_schema(schema: Any) -> Any:
    """Return a cleaned version of a JSON-schema node with every ``default`` keyword removed.

    Recurses into ``properties`` (whose keys are property *names*, so a
    property literally called ``default`` is kept), ``items``, and the
    ``anyOf`` / ``oneOf`` / ``allOf`` lists. Non-dict nodes are returned
    unchanged. The input is not modified. Values of keywords that are not
    recursed into (e.g. ``enum``, ``required``) are shared with the input.
    """
    if not isinstance(schema, dict):
        return schema
    cleaned: Dict[str, Any] = {}
    for key, value in schema.items():
        if key == "default":
            continue
        if key == "properties" and isinstance(value, dict):
            cleaned[key] = {name: _clean_schema(sub) for name, sub in value.items()}
        elif key == "items":
            cleaned[key] = _clean_schema(value)
        elif key in ("anyOf", "oneOf", "allOf") and isinstance(value, list):
            cleaned[key] = [_clean_schema(sub) for sub in value]
        else:
            cleaned[key] = value
    return cleaned


def _convert_tools_to_gemini(tools: List[Dict]) -> List[Dict]:
    """Convert OpenAI tool definitions into the Gemini REST ``tools`` format.

    OpenAI: ``[{"type": "function", "function": {"name": ..., "parameters": {...}}}]``
    Gemini: ``[{"functionDeclarations": [{"name": ..., "parameters": {...}}]}]``

    Only ``type == "function"`` tools are converted; others are ignored. The
    parameter schema is rebuilt as ``{"type": "object", "properties": ...,
    "required": ...}``, which drops any other top-level schema keys such as
    ``additionalProperties``. ``default`` is stripped at every nesting level.
    When a function has no properties, ``parameters`` is omitted entirely,
    because an OBJECT schema with an empty ``properties`` map is not accepted
    by the Gemini API.

    Args:
        tools: OpenAI-format tool list (``None`` / empty allowed).

    Returns:
        A one-element list holding the ``functionDeclarations``, or ``[]`` when
        there is nothing to declare.
    """
    if not tools:
        return []

    function_declarations = []
    for tool in tools:
        if tool.get("type") != "function":
            continue
        func = tool.get("function") or {}
        declaration: Dict[str, Any] = {
            "name": func.get("name"),
            "description": func.get("description", ""),
        }

        # "parameters" may be missing or explicitly None.
        parameters = func.get("parameters") or {}
        properties = parameters.get("properties") if isinstance(parameters, dict) else None
        if isinstance(properties, dict) and properties:
            cleaned_parameters: Dict[str, Any] = {
                "type": "object",
                "properties": {name: _clean_schema(prop) for name, prop in properties.items()},
            }
            if "required" in parameters:
                cleaned_parameters["required"] = parameters["required"]
            declaration["parameters"] = cleaned_parameters

        function_declarations.append(declaration)

    return [{"functionDeclarations": function_declarations}] if function_declarations else []
def _build_generation_config(options: Dict[str, Any]) -> Dict[str, Any]:
    """Translate OpenAI-style sampling options into a Gemini ``generationConfig``.

    Recognised keys are skipped when absent or ``None``:
      * ``temperature`` -> ``temperature``
      * ``top_p`` -> ``topP``
      * ``max_tokens`` / ``max_completion_tokens`` -> ``maxOutputTokens``
        (the latter wins if both are given)
      * ``stop`` -> ``stopSequences`` (a single string is wrapped in a list)
    All other keys are ignored.

    NOTE(review): for thinking models ``maxOutputTokens`` may also have to
    cover thought tokens, unlike OpenAI's ``max_tokens`` — confirm before
    relying on small limits.
    """
    key_map = (
        ("temperature", "temperature"),
        ("top_p", "topP"),
        ("max_tokens", "maxOutputTokens"),
        ("max_completion_tokens", "maxOutputTokens"),
    )
    config: Dict[str, Any] = {}
    for openai_key, gemini_key in key_map:
        value = options.get(openai_key)
        if value is not None:
            config[gemini_key] = value
    stop = options.get("stop")
    if stop:
        config["stopSequences"] = [stop] if isinstance(stop, str) else list(stop)
    return config


def _parse_gemini_response(gemini_resp: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a ``generateContent`` response body into the provider's result dict.

    Only the first candidate is read. Its text parts, excluding parts flagged
    ``"thought": true``, are concatenated into ``content``. Each
    ``functionCall`` part becomes an OpenAI-style tool call with a unique
    ``call_<hex>`` id and JSON-encoded ``arguments``.

    ``completion_tokens`` is ``candidatesTokenCount + thoughtsTokenCount``, so
    that, like OpenAI's ``completion_tokens``, thinking tokens count as output.
    Missing fields default to empty / zero, and ``cost`` is always ``0.0``.
    """
    import uuid  # local import: only needed to mint tool-call ids

    text_chunks: List[str] = []
    tool_calls: Optional[List[Dict[str, Any]]] = None

    candidates = gemini_resp.get("candidates") or []
    if not candidates:
        # Report the prompt block reason when Gemini supplies one, instead of
        # silently returning an empty answer.
        block_reason = (gemini_resp.get("promptFeedback") or {}).get("blockReason")
        if block_reason:
            print(f"[Gemini HTTP] Prompt blocked: {block_reason}")
    else:
        candidate = candidates[0]
        finish_reason = candidate.get("finishReason")
        if finish_reason and finish_reason != "STOP":
            print(f"[Gemini HTTP] Generation finished with reason: {finish_reason}")

        for part in (candidate.get("content") or {}).get("parts") or []:
            # Thought summaries are model reasoning, not part of the answer.
            if "text" in part and not part.get("thought"):
                text_chunks.append(part.get("text") or "")
            if "functionCall" in part:
                fc = part["functionCall"]
                if tool_calls is None:
                    tool_calls = []
                tool_calls.append({
                    # Unique per call so ids never collide across turns.
                    "id": f"call_{uuid.uuid4().hex[:24]}",
                    "type": "function",
                    "function": {
                        "name": fc.get("name", ""),
                        "arguments": json.dumps(fc.get("args") or {}, ensure_ascii=False),
                    },
                })

    usage_meta = gemini_resp.get("usageMetadata") or {}
    prompt_tokens = usage_meta.get("promptTokenCount", 0)
    completion_tokens = usage_meta.get("candidatesTokenCount", 0) + usage_meta.get("thoughtsTokenCount", 0)

    return {
        "content": "".join(text_chunks),
        "tool_calls": tool_calls,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "cost": 0.0,
    }


def create_gemini_llm_call(
    base_url: Optional[str] = None,
    api_key: Optional[str] = None
):
    """Create an async Gemini chat-completion function backed by the REST API.

    Args:
        base_url: Gemini API base URL. Defaults to Google's public ``v1beta``
            endpoint; a trailing ``/`` is ignored.
        api_key: API key. Defaults to the ``GEMINI_API_KEY`` environment
            variable.

    Returns:
        The coroutine function ``gemini_llm_call(messages, model=...,
        tools=None, **kwargs)``. It also exposes the shared
        ``httpx.AsyncClient`` as ``.client`` and that client's ``aclose`` as
        ``.aclose``, so callers can release the connection pool when done.

    Raises:
        ValueError: if no API key is given or found in the environment.
    """
    base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
    api_key = api_key or os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY not found")

    # One client is shared by every call made through the returned function.
    # The key goes in a header rather than the URL so it is not logged with request URLs.
    client = httpx.AsyncClient(
        headers={"x-goog-api-key": api_key},
        timeout=httpx.Timeout(120.0, connect=10.0)
    )

    async def gemini_llm_call(
        messages: List[Dict[str, Any]],
        model: str = "gemini-2.5-pro",
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Send one ``generateContent`` request and return an OpenAI-like result.

        Args:
            messages: OpenAI-format chat messages.
            model: Gemini model id; a leading ``models/`` prefix is accepted.
            tools: OpenAI-format tool definitions.
            **kwargs: ``temperature``, ``top_p``, ``max_tokens`` /
                ``max_completion_tokens`` and ``stop`` are forwarded as
                ``generationConfig``. Other keys are ignored.

        Returns:
            ``{"content": str, "tool_calls": list | None, "prompt_tokens": int,
            "completion_tokens": int, "cost": float}``.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response (the body is printed first).
            Exception: any other request or decoding error is printed and re-raised.
        """
        contents, system_instruction = _convert_messages_to_gemini(messages)
        print(f"\n[Gemini HTTP] Converted {len(contents)} messages: {[c['role'] for c in contents]}")

        model_id = model.removeprefix("models/")
        endpoint = f"{base_url}/models/{model_id}:generateContent"
        payload: Dict[str, Any] = {"contents": contents}

        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}

        if tools:
            gemini_tools = _convert_tools_to_gemini(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools

        generation_config = _build_generation_config(kwargs)
        if generation_config:
            payload["generationConfig"] = generation_config

        try:
            response = await client.post(endpoint, json=payload)
            response.raise_for_status()
            gemini_resp = response.json()
        except httpx.HTTPStatusError as e:
            error_body = e.response.text
            print(f"[Gemini HTTP] Error {e.response.status_code}: {error_body}")
            raise
        except Exception as e:
            print(f"[Gemini HTTP] Request failed: {e}")
            raise

        return _parse_gemini_response(gemini_resp)

    # Expose the client so its connection pool can be closed (previously leaked).
    gemini_llm_call.client = client
    gemini_llm_call.aclose = client.aclose
    return gemini_llm_call
|