nanobanana.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. """
  2. NanoBanana Tool - 图像生成
  3. 通用图像生成工具,可以接受自然语言描述和/或图像输入,生成新的图像。
  4. 支持通过 OpenRouter 调用 Gemini 2.5 Flash Image 模型。
  5. """
  6. import base64
  7. import json
  8. import mimetypes
  9. import os
  10. import re
  11. from pathlib import Path
  12. from typing import Optional, Dict, Any, List, Tuple
  13. import httpx
  14. from dotenv import load_dotenv
  15. from agent.tools import tool, ToolResult
  16. OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
  17. DEFAULT_TIMEOUT = 120.0
  18. DEFAULT_IMAGE_PROMPT = "根据输入生成图像。"
  19. DEFAULT_IMAGE_MODEL_CANDIDATES = [
  20. # "google/gemini-2.5-flash-image",
  21. "google/gemini-3-pro-image-preview"
  22. # "google/gemini-3.1-flash-image-preview"
  23. ]
  24. def _resolve_api_key() -> Optional[str]:
  25. """优先读取环境变量,缺失时尝试从 .env 加载。"""
  26. api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")
  27. if api_key:
  28. return api_key
  29. load_dotenv()
  30. return os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")
  31. def _image_to_data_url(image_path: Path) -> str:
  32. """将图片文件编码为 data URL。"""
  33. mime_type = mimetypes.guess_type(str(image_path))[0] or "application/octet-stream"
  34. raw = image_path.read_bytes()
  35. b64 = base64.b64encode(raw).decode("utf-8")
  36. return f"data:{mime_type};base64,{b64}"
  37. def _safe_json_parse(content: str) -> Dict[str, Any]:
  38. """尽量从模型文本中提取 JSON。"""
  39. try:
  40. return json.loads(content)
  41. except json.JSONDecodeError:
  42. start = content.find("{")
  43. end = content.rfind("}")
  44. if start != -1 and end != -1 and end > start:
  45. candidate = content[start:end + 1]
  46. return json.loads(candidate)
  47. raise
  48. def _extract_data_url_images(message: Dict[str, Any]) -> List[Tuple[str, str]]:
  49. """
  50. 从 OpenRouter 响应消息中提取 data URL 图片。
  51. Returns:
  52. List[(mime_type, base64_data)]
  53. """
  54. extracted: List[Tuple[str, str]] = []
  55. # 官方文档中的主要位置:message.images[]
  56. for img in message.get("images", []) or []:
  57. if not isinstance(img, dict):
  58. continue
  59. if img.get("type") != "image_url":
  60. continue
  61. data_url = ((img.get("image_url") or {}).get("url") or "").strip()
  62. if not data_url.startswith("data:"):
  63. continue
  64. m = re.match(r"^data:([^;]+);base64,(.+)$", data_url, flags=re.DOTALL)
  65. if not m:
  66. continue
  67. extracted.append((m.group(1), m.group(2)))
  68. # 兼容某些模型可能把 image_url 放在 content 数组中
  69. content = message.get("content")
  70. if isinstance(content, list):
  71. for part in content:
  72. if not isinstance(part, dict):
  73. continue
  74. if part.get("type") != "image_url":
  75. continue
  76. data_url = ((part.get("image_url") or {}).get("url") or "").strip()
  77. if not data_url.startswith("data:"):
  78. continue
  79. m = re.match(r"^data:([^;]+);base64,(.+)$", data_url, flags=re.DOTALL)
  80. if not m:
  81. continue
  82. extracted.append((m.group(1), m.group(2)))
  83. return extracted
  84. def _extract_image_refs(choice: Dict[str, Any], message: Dict[str, Any]) -> List[Dict[str, str]]:
  85. """
  86. 尝试从不同响应格式中提取图片引用。
  87. 返回格式:
  88. - {"kind": "data_url", "value": "data:image/png;base64,..."}
  89. - {"kind": "base64", "value": "...", "mime_type": "image/png"}
  90. - {"kind": "url", "value": "https://..."}
  91. """
  92. refs: List[Dict[str, str]] = []
  93. # 1) 标准 message.images
  94. for img in message.get("images", []) or []:
  95. if not isinstance(img, dict):
  96. continue
  97. # image_url 结构
  98. data_url = ((img.get("image_url") or {}).get("url") or "").strip()
  99. if data_url.startswith("data:"):
  100. refs.append({"kind": "data_url", "value": data_url})
  101. continue
  102. if data_url.startswith("http"):
  103. refs.append({"kind": "url", "value": data_url})
  104. continue
  105. # 兼容 base64 字段
  106. b64 = (img.get("b64_json") or img.get("base64") or "").strip()
  107. if b64:
  108. refs.append({"kind": "base64", "value": b64, "mime_type": img.get("mime_type", "image/png")})
  109. # 2) 某些格式可能在 choice.images
  110. for img in choice.get("images", []) or []:
  111. if not isinstance(img, dict):
  112. continue
  113. data_url = ((img.get("image_url") or {}).get("url") or "").strip()
  114. if data_url.startswith("data:"):
  115. refs.append({"kind": "data_url", "value": data_url})
  116. continue
  117. if data_url.startswith("http"):
  118. refs.append({"kind": "url", "value": data_url})
  119. continue
  120. b64 = (img.get("b64_json") or img.get("base64") or "").strip()
  121. if b64:
  122. refs.append({"kind": "base64", "value": b64, "mime_type": img.get("mime_type", "image/png")})
  123. # 3) content 数组里的 image_url
  124. content = message.get("content")
  125. if isinstance(content, list):
  126. for part in content:
  127. if not isinstance(part, dict):
  128. continue
  129. if part.get("type") != "image_url":
  130. continue
  131. url = ((part.get("image_url") or {}).get("url") or "").strip()
  132. if url.startswith("data:"):
  133. refs.append({"kind": "data_url", "value": url})
  134. elif url.startswith("http"):
  135. refs.append({"kind": "url", "value": url})
  136. # 4) 极端兼容:文本中可能出现 data:image 或 http 图片 URL
  137. if isinstance(content, str):
  138. # data URL
  139. for m in re.finditer(r"(data:image\/[a-zA-Z0-9.+-]+;base64,[A-Za-z0-9+/=]+)", content):
  140. refs.append({"kind": "data_url", "value": m.group(1)})
  141. # http(s) 图片链接
  142. for m in re.finditer(r"(https?://\S+\.(?:png|jpg|jpeg|webp))", content, flags=re.IGNORECASE):
  143. refs.append({"kind": "url", "value": m.group(1)})
  144. return refs
  145. def _mime_to_ext(mime_type: str) -> str:
  146. """MIME 类型映射到扩展名。"""
  147. mapping = {
  148. "image/png": ".png",
  149. "image/jpeg": ".jpg",
  150. "image/webp": ".webp",
  151. }
  152. return mapping.get(mime_type.lower(), ".png")
  153. def _normalize_model_id(model_id: str) -> str:
  154. """
  155. 规范化常见误写模型 ID,减少无效重试。
  156. """
  157. if not model_id:
  158. return model_id
  159. m = model_id.strip()
  160. # 常见误写:gemini/gemini-xxx -> google/gemini-xxx
  161. if m.startswith("gemini/"):
  162. m = "google/" + m.split("/", 1)[1]
  163. # 常见顺序误写:preview-image -> image
  164. if "gemini-2.5-flash-preview-image" in m:
  165. m = m.replace("gemini-2.5-flash-preview-image", "gemini-2.5-flash-image")
  166. # 兼容旧 ID 到当前可用 ID
  167. if "gemini-2.5-flash-image-preview" in m:
  168. m = m.replace("gemini-2.5-flash-image-preview", "gemini-2.5-flash-image")
  169. return m
  170. @tool(description="通用的图像生成工具,根据文本描述和/或参考图像生成图片。输出格式只有图片,不能输出文字。")
  171. async def nanobanana(
  172. image_path: str = "",
  173. image_paths: Optional[List[str]] = None,
  174. prompt: Optional[str] = None,
  175. model: Optional[str] = None,
  176. max_tokens: int = 1200,
  177. image_output_path: Optional[str] = None,
  178. ) -> ToolResult:
  179. """
  180. 通用图像生成工具,可以接受自然语言描述和/或图像输入,生成新的图像。
  181. Args:
  182. image_path: 输入图片路径(单图模式,可选)
  183. image_paths: 输入图片路径列表(多图模式,可选)
  184. prompt: 自定义生成描述(可选,默认使用通用prompt)
  185. model: OpenRouter 模型名(可选,默认使用 gemini-2.5-flash-image)
  186. max_tokens: 最大输出 token
  187. image_output_path: 生成图片保存路径(可选)
  188. Returns:
  189. ToolResult: 包含生成的图片路径
  190. """
  191. raw_paths: List[str] = []
  192. if image_paths:
  193. raw_paths.extend(image_paths)
  194. if image_path:
  195. raw_paths.append(image_path)
  196. # 图像输入是可选的,但如果提供了就需要验证
  197. input_paths: List[Path] = []
  198. if raw_paths:
  199. # 去重并检查路径
  200. unique_raw: List[str] = []
  201. seen = set()
  202. for p in raw_paths:
  203. if p and p not in seen:
  204. unique_raw.append(p)
  205. seen.add(p)
  206. input_paths = [Path(p) for p in unique_raw]
  207. invalid = [str(p) for p in input_paths if (not p.exists() or not p.is_file())]
  208. if invalid:
  209. return ToolResult(
  210. title="NanoBanana 生成失败",
  211. output="",
  212. error=f"以下图片不存在或不可读: {invalid}",
  213. )
  214. api_key = _resolve_api_key()
  215. if not api_key:
  216. return ToolResult(
  217. title="NanoBanana 生成失败",
  218. output="",
  219. error="未找到 OpenRouter API Key,请设置 OPENROUTER_API_KEY 或 OPEN_ROUTER_API_KEY",
  220. )
  221. user_prompt = prompt or DEFAULT_IMAGE_PROMPT
  222. # 编码图像(如果有)
  223. image_data_urls = []
  224. if input_paths:
  225. try:
  226. image_data_urls = [_image_to_data_url(p) for p in input_paths]
  227. except Exception as e:
  228. return ToolResult(
  229. title="NanoBanana 生成失败",
  230. output="",
  231. error=f"图片编码失败: {e}",
  232. )
  233. user_content: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
  234. for u in image_data_urls:
  235. user_content.append({"type": "image_url", "image_url": {"url": u}})
  236. payload: Dict[str, Any] = {
  237. "messages": [
  238. {
  239. "role": "system",
  240. "content": "你是图像生成助手。请根据用户的描述和/或输入图像生成新的图像。",
  241. },
  242. {
  243. "role": "user",
  244. "content": user_content,
  245. },
  246. ],
  247. "temperature": 0.2,
  248. "max_tokens": max_tokens,
  249. "modalities": ["image", "text"],
  250. }
  251. headers = {
  252. "Authorization": f"Bearer {api_key}",
  253. "Content-Type": "application/json",
  254. "HTTP-Referer": "https://local-agent",
  255. "X-Title": "Agent NanoBanana Tool",
  256. }
  257. endpoint = f"{OPENROUTER_BASE_URL}/chat/completions"
  258. # 自动尝试多个可用模型,减少 404/invalid model 影响
  259. candidates: List[str] = []
  260. if model:
  261. candidates.append(_normalize_model_id(model))
  262. if env_model := os.getenv("NANOBANANA_IMAGE_MODEL"):
  263. candidates.append(_normalize_model_id(env_model))
  264. candidates.extend([_normalize_model_id(x) for x in DEFAULT_IMAGE_MODEL_CANDIDATES])
  265. # 去重并保持顺序
  266. dedup: List[str] = []
  267. seen = set()
  268. for m in candidates:
  269. if m and m not in seen:
  270. dedup.append(m)
  271. seen.add(m)
  272. candidates = dedup
  273. data: Optional[Dict[str, Any]] = None
  274. used_model: Optional[str] = None
  275. errors: List[Dict[str, Any]] = []
  276. for cand in candidates:
  277. modality_attempts: List[Optional[List[str]]] = [["image", "text"], ["image"], None]
  278. for mods in modality_attempts:
  279. trial_payload = dict(payload)
  280. trial_payload["model"] = cand
  281. if mods is None:
  282. trial_payload.pop("modalities", None)
  283. else:
  284. trial_payload["modalities"] = mods
  285. try:
  286. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  287. resp = await client.post(endpoint, json=trial_payload, headers=headers)
  288. resp.raise_for_status()
  289. data = resp.json()
  290. used_model = cand
  291. break
  292. except httpx.HTTPStatusError as e:
  293. errors.append({
  294. "model": cand,
  295. "modalities": mods,
  296. "status_code": e.response.status_code,
  297. "body": e.response.text[:600],
  298. })
  299. continue
  300. except Exception as e:
  301. errors.append({
  302. "model": cand,
  303. "modalities": mods,
  304. "status_code": None,
  305. "body": str(e)[:600],
  306. })
  307. continue
  308. if data is not None:
  309. break
  310. if data is None:
  311. return ToolResult(
  312. title="NanoBanana 生成失败",
  313. output=json.dumps({"attempted_models": candidates, "errors": errors}, ensure_ascii=False, indent=2),
  314. long_term_memory="All candidate models failed for this request",
  315. metadata={"attempted_models": candidates, "errors": errors},
  316. )
  317. chosen_model = used_model or candidates[0]
  318. choices = data.get("choices") or []
  319. message = choices[0].get("message", {}) if choices else {}
  320. # 提取生成的图像
  321. refs = _extract_image_refs(choices[0] if choices else {}, message)
  322. if not refs:
  323. content = message.get("content")
  324. preview = ""
  325. if isinstance(content, str):
  326. preview = content[:500]
  327. elif isinstance(content, list):
  328. preview = json.dumps(content[:3], ensure_ascii=False)[:500]
  329. return ToolResult(
  330. title="NanoBanana 生成失败",
  331. output=json.dumps(data, ensure_ascii=False, indent=2),
  332. error="模型未返回可解析图片(未在 message.images/choice.images/content 中发现图片)",
  333. metadata={
  334. "model": chosen_model,
  335. "choice_keys": list((choices[0] if choices else {}).keys()),
  336. "message_keys": list(message.keys()) if isinstance(message, dict) else [],
  337. "content_preview": preview,
  338. },
  339. )
  340. output_paths: List[str] = []
  341. if image_output_path:
  342. base_path = Path(image_output_path)
  343. else:
  344. if len(input_paths) > 1:
  345. base_path = input_paths[0].parent / "set_generated.png"
  346. else:
  347. base_path = input_paths[0].parent / f"{input_paths[0].stem}_generated.png"
  348. base_path.parent.mkdir(parents=True, exist_ok=True)
  349. for idx, ref in enumerate(refs):
  350. kind = ref.get("kind", "")
  351. mime_type = "image/png"
  352. raw_bytes: Optional[bytes] = None
  353. if kind == "data_url":
  354. m = re.match(r"^data:([^;]+);base64,(.+)$", ref.get("value", ""), flags=re.DOTALL)
  355. if not m:
  356. continue
  357. mime_type = m.group(1)
  358. raw_bytes = base64.b64decode(m.group(2))
  359. elif kind == "base64":
  360. mime_type = ref.get("mime_type", "image/png")
  361. raw_bytes = base64.b64decode(ref.get("value", ""))
  362. elif kind == "url":
  363. url = ref.get("value", "")
  364. try:
  365. with httpx.Client(timeout=DEFAULT_TIMEOUT) as client:
  366. r = client.get(url)
  367. r.raise_for_status()
  368. raw_bytes = r.content
  369. mime_type = r.headers.get("content-type", "image/png").split(";")[0]
  370. except Exception:
  371. continue
  372. else:
  373. continue
  374. if not raw_bytes:
  375. continue
  376. ext = _mime_to_ext(mime_type)
  377. if len(refs) == 1:
  378. target = base_path
  379. if target.suffix.lower() not in [".png", ".jpg", ".jpeg", ".webp"]:
  380. target = target.with_suffix(ext)
  381. else:
  382. stem = base_path.stem
  383. target = base_path.with_name(f"{stem}_{idx+1}{ext}")
  384. try:
  385. target.write_bytes(raw_bytes)
  386. output_paths.append(str(target))
  387. except Exception as e:
  388. return ToolResult(
  389. title="NanoBanana 生成失败",
  390. output="",
  391. error=f"写入生成图片失败: {e}",
  392. metadata={"model": chosen_model},
  393. )
  394. if not output_paths:
  395. return ToolResult(
  396. title="NanoBanana 生成失败",
  397. output=json.dumps(data, ensure_ascii=False, indent=2),
  398. error="检测到图片引用但写入失败(可能是无效 base64 或 URL 不可访问)",
  399. metadata={"model": chosen_model, "ref_count": len(refs)},
  400. )
  401. usage = data.get("usage", {})
  402. prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens", 0)
  403. completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens", 0)
  404. summary = {
  405. "model": chosen_model,
  406. "input_images": [str(p) for p in input_paths],
  407. "input_count": len(input_paths),
  408. "generated_images": output_paths,
  409. "prompt_tokens": prompt_tokens,
  410. "completion_tokens": completion_tokens,
  411. }
  412. return ToolResult(
  413. title="NanoBanana 图片生成完成",
  414. output=json.dumps({"summary": summary}, ensure_ascii=False, indent=2),
  415. long_term_memory=f"Generated {len(output_paths)} image(s) from {len(input_paths)} input image(s) using {chosen_model}",
  416. attachments=output_paths,
  417. metadata=summary,
  418. tool_usage={
  419. "model": chosen_model,
  420. "prompt_tokens": prompt_tokens,
  421. "completion_tokens": completion_tokens,
  422. }
  423. )