- """
- Gemini Provider (HTTP API)
- 使用 httpx 直接调用 Gemini REST API,避免 google-generativeai SDK 的兼容性问题
- 参考:Resonote/llm/providers/gemini.py
- """
- import os
- import json
- import sys
- import httpx
- from typing import List, Dict, Any, Optional
def _dump_llm_request(endpoint: str, payload: Dict[str, Any], model: str):
    """Dump the full LLM request to stderr for debugging (requires AGENT_DEBUG=1).

    Base64 image payloads are replaced with a short preview plus size info so
    the dump stays readable; tool schemas are printed in full.  Output goes to
    stderr so it never pollutes normal stdout output.
    """
    if not os.getenv("AGENT_DEBUG"):
        return

    def _redact(node):
        # Recursively walk the payload, summarising inline base64 image data.
        if isinstance(node, list):
            return [_redact(entry) for entry in node]
        if not isinstance(node, dict):
            return node
        redacted = {}
        for key, value in node.items():
            # Base64 images live under "inline_data" in Gemini payloads.
            if key == "inline_data" and isinstance(value, dict):
                mime_type = value.get("mime_type", "unknown")
                data = value.get("data", "")
                data_size_kb = len(data) / 1024 if data else 0
                redacted[key] = {
                    "mime_type": mime_type,
                    "data": f"<BASE64_IMAGE: {data_size_kb:.1f}KB, preview: {data[:50]}...>"
                }
            else:
                redacted[key] = _redact(value)
        return redacted

    banner = "=" * 80
    print("\n" + banner, file=sys.stderr)
    print("[AGENT_DEBUG] LLM Request Dump", file=sys.stderr)
    print(banner, file=sys.stderr)
    print(
        json.dumps(
            {"endpoint": endpoint, "model": model, "payload": _redact(payload)},
            indent=2,
            ensure_ascii=False,
        ),
        file=sys.stderr,
    )
    print(banner + "\n", file=sys.stderr)
def _convert_messages_to_gemini(messages: List[Dict]) -> tuple[List[Dict], Optional[str]]:
    """
    Convert OpenAI-format chat messages to the Gemini REST format.

    Handles: system messages (hoisted into a system instruction), tool result
    messages (batched into functionResponse parts), assistant tool calls
    (functionCall parts), multimodal content arrays (text + data-URL images),
    and plain text.  Finally merges consecutive "user" turns because Gemini
    requires strictly alternating roles.

    Returns:
        (gemini_contents, system_instruction) — system_instruction is None
        when no system message was seen; if several system messages appear,
        the last one wins.
    """
    contents = []
    system_instruction = None
    # Consecutive "tool" messages are buffered and flushed as ONE user turn.
    tool_parts_buffer = []

    def flush_tool_buffer():
        """Merge any buffered tool messages into a single user message."""
        if tool_parts_buffer:
            contents.append({
                "role": "user",
                "parts": tool_parts_buffer.copy()
            })
            tool_parts_buffer.clear()

    for msg in messages:
        role = msg.get("role")

        # System message -> system_instruction (returned separately).
        if role == "system":
            system_instruction = msg.get("content", "")
            continue

        # Tool message -> functionResponse part (buffered, flushed later).
        if role == "tool":
            tool_name = msg.get("name")
            content_text = msg.get("content", "")
            if not tool_name:
                # Gemini needs the function name to pair the response; skip.
                print(f"[WARNING] Tool message missing 'name' field, skipping")
                continue
            # Prefer structured JSON; bare lists/strings are wrapped so the
            # response is always a JSON object as Gemini expects.
            try:
                parsed = json.loads(content_text) if content_text else {}
                if isinstance(parsed, list):
                    response_data = {"result": parsed}
                else:
                    response_data = parsed
            except (json.JSONDecodeError, ValueError):
                response_data = {"result": content_text}
            tool_parts_buffer.append({
                "functionResponse": {
                    "name": tool_name,
                    "response": response_data
                }
            })
            continue

        # Any non-tool message ends the current tool batch.
        flush_tool_buffer()

        content = msg.get("content", "")
        tool_calls = msg.get("tool_calls")

        # Assistant message carrying tool_calls -> model turn with
        # optional leading text plus one functionCall part per call.
        if role == "assistant" and tool_calls:
            parts = []
            if content and (isinstance(content, str) and content.strip()):
                parts.append({"text": content})
            for tc in tool_calls:
                func = tc.get("function", {})
                func_name = func.get("name", "")
                func_args_str = func.get("arguments", "{}")
                # Arguments may already be a dict or a JSON string.
                try:
                    func_args = json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
                except json.JSONDecodeError:
                    func_args = {}
                parts.append({
                    "functionCall": {
                        "name": func_name,
                        "args": func_args
                    }
                })
            if parts:
                contents.append({
                    "role": "model",
                    "parts": parts
                })
            continue

        # Multimodal message: content is a list of typed items.
        if isinstance(content, list):
            parts = []
            for item in content:
                item_type = item.get("type")
                if item_type == "text":
                    text = item.get("text", "")
                    if text.strip():
                        parts.append({"text": text})
                # OpenAI image_url -> Gemini inline_data.
                elif item_type == "image_url":
                    image_url = item.get("image_url", {})
                    url = image_url.get("url", "")
                    # Only data URLs are supported here; remote http(s)
                    # URLs are silently dropped.
                    if url.startswith("data:"):
                        # Format: data:image/png;base64,<base64_data>
                        try:
                            header, base64_data = url.split(",", 1)
                            mime_type = header.split(";")[0].replace("data:", "")
                            parts.append({
                                "inline_data": {
                                    "mime_type": mime_type,
                                    "data": base64_data
                                }
                            })
                        except Exception as e:
                            print(f"[WARNING] Failed to parse image data URL: {e}")
            if parts:
                gemini_role = "model" if role == "assistant" else "user"
                contents.append({
                    "role": gemini_role,
                    "parts": parts
                })
            continue

        # Plain string content; empty/whitespace-only messages are dropped.
        if isinstance(content, str):
            if not content.strip():
                continue
            gemini_role = "model" if role == "assistant" else "user"
            contents.append({
                "role": gemini_role,
                "parts": [{"text": content}]
            })
            continue

    # Flush tool messages that appeared at the very end of the transcript.
    flush_tool_buffer()

    # Merge consecutive "user" turns (Gemini requires strict alternation).
    merged_contents = []
    i = 0
    while i < len(contents):
        current = contents[i]
        if current["role"] == "user":
            merged_parts = current["parts"].copy()
            j = i + 1
            while j < len(contents) and contents[j]["role"] == "user":
                merged_parts.extend(contents[j]["parts"])
                j += 1
            merged_contents.append({
                "role": "user",
                "parts": merged_parts
            })
            i = j
        else:
            merged_contents.append(current)
            i += 1

    return merged_contents, system_instruction
def _convert_tools_to_gemini(tools: List[Dict]) -> List[Dict]:
    """
    Convert OpenAI tool definitions to the Gemini REST API format.

    OpenAI:  [{"type": "function", "function": {"name": ..., "parameters": {...}}}]
    Gemini:  [{"functionDeclarations": [{"name": ..., "parameters": {...}}]}]

    The Gemini schema validator rejects the JSON-Schema ``default`` keyword,
    so it is stripped from the parameter schema.  (The previous version only
    removed ``default`` from top-level properties; nested object/array
    schemas kept it and could still be rejected — now the strip is recursive.)

    Args:
        tools: OpenAI-format tool list (may be None or empty).

    Returns:
        A single-element Gemini tools list, or [] when there is nothing to
        declare.
    """
    if not tools:
        return []

    def _strip_defaults(schema):
        # Recursively drop "default" keys anywhere in the schema tree.
        if isinstance(schema, dict):
            return {k: _strip_defaults(v) for k, v in schema.items() if k != "default"}
        if isinstance(schema, list):
            return [_strip_defaults(item) for item in schema]
        return schema

    function_declarations = []
    for tool in tools:
        # Only "function"-type tools are representable in Gemini.
        if tool.get("type") != "function":
            continue
        func = tool.get("function", {})
        parameters = func.get("parameters", {})
        if "properties" in parameters:
            # Gemini wants a complete object schema: explicit type +
            # cleaned properties (+ required, when present).
            cleaned_parameters = {
                "type": "object",
                "properties": _strip_defaults(parameters["properties"])
            }
            if "required" in parameters:
                cleaned_parameters["required"] = parameters["required"]
            parameters = cleaned_parameters
        function_declarations.append({
            "name": func.get("name"),
            "description": func.get("description", ""),
            "parameters": parameters
        })
    return [{"functionDeclarations": function_declarations}] if function_declarations else []
def create_gemini_llm_call(
    base_url: Optional[str] = None,
    api_key: Optional[str] = None
):
    """
    Create an async Gemini LLM call function backed by the HTTP API.

    Args:
        base_url: Gemini API base URL (defaults to the official Google endpoint).
        api_key: API key (defaults to the GEMINI_API_KEY environment variable).

    Returns:
        An async function ``(messages, model, tools, **kwargs) -> dict``.
        The function also exposes an ``aclose`` attribute (awaitable) so
        callers can shut down the underlying HTTP connection pool; the
        previous version gave callers no way to close the client.

    Raises:
        ValueError: when no API key is supplied or found in the environment.
    """
    base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
    api_key = api_key or os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY not found")

    # One shared client for connection pooling across calls.
    client = httpx.AsyncClient(
        headers={"x-goog-api-key": api_key},
        timeout=httpx.Timeout(120.0, connect=10.0)
    )

    async def gemini_llm_call(
        messages: List[Dict[str, Any]],
        model: str = "gemini-2.5-pro",
        tools: Optional[List[Dict]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Call the Gemini REST generateContent API.

        Args:
            messages: messages in OpenAI chat format.
            model: Gemini model name.
            tools: tool definitions in OpenAI format.
            **kwargs: accepted for interface compatibility; currently ignored.

        Returns:
            {
                "content": str,
                "tool_calls": List[Dict] | None,
                "prompt_tokens": int,
                "completion_tokens": int,
                "cost": float   # always 0.0 — pricing is not computed here
            }

        Raises:
            httpx.HTTPStatusError: on non-2xx API responses (after logging).
        """
        # Convert OpenAI messages to Gemini contents + system instruction.
        contents, system_instruction = _convert_messages_to_gemini(messages)
        print(f"\n[Gemini HTTP] Converted {len(contents)} messages: {[c['role'] for c in contents]}")

        # Build the request payload.
        endpoint = f"{base_url}/models/{model}:generateContent"
        payload = {"contents": contents}
        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        if tools:
            gemini_tools = _convert_tools_to_gemini(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools

        # Full request dump for debugging (only when AGENT_DEBUG=1).
        _dump_llm_request(endpoint, payload, model)

        # Call the API; log errors before re-raising so failures are visible.
        try:
            response = await client.post(endpoint, json=payload)
            response.raise_for_status()
            gemini_resp = response.json()
        except httpx.HTTPStatusError as e:
            error_body = e.response.text
            print(f"[Gemini HTTP] Error {e.response.status_code}: {error_body}")
            raise
        except Exception as e:
            print(f"[Gemini HTTP] Request failed: {e}")
            raise

        # Raw response dump (truncated to 2000 chars) when debugging.
        if os.getenv("AGENT_DEBUG"):
            print("\n[AGENT_DEBUG] Gemini Response:", file=sys.stderr)
            print(json.dumps(gemini_resp, ensure_ascii=False, indent=2)[:2000], file=sys.stderr)
            print("\n", file=sys.stderr)

        # Parse the first candidate into OpenAI-style content/tool_calls.
        content = ""
        tool_calls = None
        candidates = gemini_resp.get("candidates", [])
        if candidates:
            candidate = candidates[0]
            finish_reason = candidate.get("finishReason")
            if finish_reason == "MALFORMED_FUNCTION_CALL":
                # Gemini emitted a syntactically invalid function call;
                # surface its finishMessage as plain content instead of failing.
                finish_message = candidate.get("finishMessage", "")
                print(f"[Gemini HTTP] Warning: MALFORMED_FUNCTION_CALL\n{finish_message}")
                content = f"[模型尝试调用工具但格式错误]\n\n{finish_message}"
            else:
                parts = candidate.get("content", {}).get("parts", [])
                # Concatenate all text parts.
                for part in parts:
                    if "text" in part:
                        content += part.get("text", "")
                # Convert functionCall parts to OpenAI tool_calls; the id is
                # derived from the part index (Gemini itself has no call ids).
                for i, part in enumerate(parts):
                    if "functionCall" in part:
                        if tool_calls is None:
                            tool_calls = []
                        fc = part["functionCall"]
                        name = fc.get("name", "")
                        args = fc.get("args", {})
                        tool_calls.append({
                            "id": f"call_{i}",
                            "type": "function",
                            "function": {
                                "name": name,
                                "arguments": json.dumps(args, ensure_ascii=False)
                            }
                        })

        # Token usage (best effort; missing fields default to 0).
        usage_meta = gemini_resp.get("usageMetadata", {})
        prompt_tokens = usage_meta.get("promptTokenCount", 0)
        completion_tokens = usage_meta.get("candidatesTokenCount", 0)
        return {
            "content": content,
            "tool_calls": tool_calls,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "cost": 0.0
        }

    # Let callers release the connection pool: `await fn.aclose()`.
    gemini_llm_call.aclose = client.aclose
    return gemini_llm_call