nanobanana.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. """
  2. NanoBanana Tool - 图像特征提取与图像生成
  3. 该工具可以提取图片中的特征,也可以根据描述生成图片。
  4. 支持通过 OpenRouter 调用多模态模型,提取结构化的图像特征并保存为 JSON,
  5. 或基于输入图像生成新的图像。
  6. """
  7. import base64
  8. import json
  9. import mimetypes
  10. import os
  11. import re
  12. from pathlib import Path
  13. from typing import Optional, Dict, Any, List, Tuple
  14. import httpx
  15. from dotenv import load_dotenv
  16. from agent.tools import tool, ToolResult
  17. OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
  18. DEFAULT_TIMEOUT = 120.0
  19. DEFAULT_EXTRACTION_PROMPT = (
  20. "请从这张图像中提取跨场景相对稳定、可复用的视觉不变特征。"
  21. "输出严格 JSON,字段包含:identity_features、pose_features、appearance_features、"
  22. "material_features、style_features、uncertainty、notes。"
  23. "每个字段给出简洁要点,避免臆测。"
  24. )
  25. DEFAULT_IMAGE_PROMPT = (
  26. "基于输入图像生成一张保留主体身份与关键视觉特征的新图像。"
  27. "保持人物核心特征一致,同时提升清晰度与可用性。"
  28. )
  29. DEFAULT_IMAGE_MODEL_CANDIDATES = [
  30. "google/gemini-2.5-flash-image",
  31. # "google/gemini-3-pro-image-preview",
  32. # "black-forest-labs/flux.2-flex",
  33. # "black-forest-labs/flux.2-pro",
  34. ]
  35. def _resolve_api_key() -> Optional[str]:
  36. """优先读取环境变量,缺失时尝试从 .env 加载。"""
  37. api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")
  38. if api_key:
  39. return api_key
  40. load_dotenv()
  41. return os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")
  42. def _image_to_data_url(image_path: Path) -> str:
  43. """将图片文件编码为 data URL。"""
  44. mime_type = mimetypes.guess_type(str(image_path))[0] or "application/octet-stream"
  45. raw = image_path.read_bytes()
  46. b64 = base64.b64encode(raw).decode("utf-8")
  47. return f"data:{mime_type};base64,{b64}"
  48. def _safe_json_parse(content: str) -> Dict[str, Any]:
  49. """尽量从模型文本中提取 JSON。"""
  50. try:
  51. return json.loads(content)
  52. except json.JSONDecodeError:
  53. start = content.find("{")
  54. end = content.rfind("}")
  55. if start != -1 and end != -1 and end > start:
  56. candidate = content[start:end + 1]
  57. return json.loads(candidate)
  58. raise
  59. def _extract_data_url_images(message: Dict[str, Any]) -> List[Tuple[str, str]]:
  60. """
  61. 从 OpenRouter 响应消息中提取 data URL 图片。
  62. Returns:
  63. List[(mime_type, base64_data)]
  64. """
  65. extracted: List[Tuple[str, str]] = []
  66. # 官方文档中的主要位置:message.images[]
  67. for img in message.get("images", []) or []:
  68. if not isinstance(img, dict):
  69. continue
  70. if img.get("type") != "image_url":
  71. continue
  72. data_url = ((img.get("image_url") or {}).get("url") or "").strip()
  73. if not data_url.startswith("data:"):
  74. continue
  75. m = re.match(r"^data:([^;]+);base64,(.+)$", data_url, flags=re.DOTALL)
  76. if not m:
  77. continue
  78. extracted.append((m.group(1), m.group(2)))
  79. # 兼容某些模型可能把 image_url 放在 content 数组中
  80. content = message.get("content")
  81. if isinstance(content, list):
  82. for part in content:
  83. if not isinstance(part, dict):
  84. continue
  85. if part.get("type") != "image_url":
  86. continue
  87. data_url = ((part.get("image_url") or {}).get("url") or "").strip()
  88. if not data_url.startswith("data:"):
  89. continue
  90. m = re.match(r"^data:([^;]+);base64,(.+)$", data_url, flags=re.DOTALL)
  91. if not m:
  92. continue
  93. extracted.append((m.group(1), m.group(2)))
  94. return extracted
  95. def _extract_image_refs(choice: Dict[str, Any], message: Dict[str, Any]) -> List[Dict[str, str]]:
  96. """
  97. 尝试从不同响应格式中提取图片引用。
  98. 返回格式:
  99. - {"kind": "data_url", "value": "data:image/png;base64,..."}
  100. - {"kind": "base64", "value": "...", "mime_type": "image/png"}
  101. - {"kind": "url", "value": "https://..."}
  102. """
  103. refs: List[Dict[str, str]] = []
  104. # 1) 标准 message.images
  105. for img in message.get("images", []) or []:
  106. if not isinstance(img, dict):
  107. continue
  108. # image_url 结构
  109. data_url = ((img.get("image_url") or {}).get("url") or "").strip()
  110. if data_url.startswith("data:"):
  111. refs.append({"kind": "data_url", "value": data_url})
  112. continue
  113. if data_url.startswith("http"):
  114. refs.append({"kind": "url", "value": data_url})
  115. continue
  116. # 兼容 base64 字段
  117. b64 = (img.get("b64_json") or img.get("base64") or "").strip()
  118. if b64:
  119. refs.append({"kind": "base64", "value": b64, "mime_type": img.get("mime_type", "image/png")})
  120. # 2) 某些格式可能在 choice.images
  121. for img in choice.get("images", []) or []:
  122. if not isinstance(img, dict):
  123. continue
  124. data_url = ((img.get("image_url") or {}).get("url") or "").strip()
  125. if data_url.startswith("data:"):
  126. refs.append({"kind": "data_url", "value": data_url})
  127. continue
  128. if data_url.startswith("http"):
  129. refs.append({"kind": "url", "value": data_url})
  130. continue
  131. b64 = (img.get("b64_json") or img.get("base64") or "").strip()
  132. if b64:
  133. refs.append({"kind": "base64", "value": b64, "mime_type": img.get("mime_type", "image/png")})
  134. # 3) content 数组里的 image_url
  135. content = message.get("content")
  136. if isinstance(content, list):
  137. for part in content:
  138. if not isinstance(part, dict):
  139. continue
  140. if part.get("type") != "image_url":
  141. continue
  142. url = ((part.get("image_url") or {}).get("url") or "").strip()
  143. if url.startswith("data:"):
  144. refs.append({"kind": "data_url", "value": url})
  145. elif url.startswith("http"):
  146. refs.append({"kind": "url", "value": url})
  147. # 4) 极端兼容:文本中可能出现 data:image 或 http 图片 URL
  148. if isinstance(content, str):
  149. # data URL
  150. for m in re.finditer(r"(data:image\/[a-zA-Z0-9.+-]+;base64,[A-Za-z0-9+/=]+)", content):
  151. refs.append({"kind": "data_url", "value": m.group(1)})
  152. # http(s) 图片链接
  153. for m in re.finditer(r"(https?://\S+\.(?:png|jpg|jpeg|webp))", content, flags=re.IGNORECASE):
  154. refs.append({"kind": "url", "value": m.group(1)})
  155. return refs
  156. def _mime_to_ext(mime_type: str) -> str:
  157. """MIME 类型映射到扩展名。"""
  158. mapping = {
  159. "image/png": ".png",
  160. "image/jpeg": ".jpg",
  161. "image/webp": ".webp",
  162. }
  163. return mapping.get(mime_type.lower(), ".png")
  164. def _normalize_model_id(model_id: str) -> str:
  165. """
  166. 规范化常见误写模型 ID,减少无效重试。
  167. """
  168. if not model_id:
  169. return model_id
  170. m = model_id.strip()
  171. # 常见误写:gemini/gemini-xxx -> google/gemini-xxx
  172. if m.startswith("gemini/"):
  173. m = "google/" + m.split("/", 1)[1]
  174. # 常见顺序误写:preview-image -> image
  175. if "gemini-2.5-flash-preview-image" in m:
  176. m = m.replace("gemini-2.5-flash-preview-image", "gemini-2.5-flash-image")
  177. # 兼容旧 ID 到当前可用 ID
  178. if "gemini-2.5-flash-image-preview" in m:
  179. m = m.replace("gemini-2.5-flash-image-preview", "gemini-2.5-flash-image")
  180. return m
  181. @tool(description="可以提取图片中的特征,也可以根据描述生成图片")
  182. async def nanobanana(
  183. image_path: str = "",
  184. image_paths: Optional[List[str]] = None,
  185. output_file: Optional[str] = None,
  186. prompt: Optional[str] = None,
  187. model: Optional[str] = None,
  188. max_tokens: int = 1200,
  189. generate_image: bool = False,
  190. image_output_path: Optional[str] = None,
  191. ) -> ToolResult:
  192. """
  193. 可以提取图片中的特征,也可以根据描述生成图片。
  194. Args:
  195. image_path: 输入图片路径(单图模式,可选)
  196. image_paths: 输入图片路径列表(多图整体模式,可选)
  197. output_file: 输出 JSON 文件路径(可选,用于特征提取模式)
  198. prompt: 自定义提取指令或生成描述(可选)
  199. model: OpenRouter 模型名(可选,默认读取 NANOBANANA_MODEL 或使用 Gemini 视觉模型)
  200. max_tokens: 最大输出 token
  201. generate_image: 是否生成图片(False=提取特征,True=生成图片)
  202. image_output_path: 生成图片保存路径(generate_image=True 时可选)
  203. Returns:
  204. ToolResult: 包含结构化特征和输出文件路径,或生成的图片路径
  205. """
  206. raw_paths: List[str] = []
  207. if image_paths:
  208. raw_paths.extend(image_paths)
  209. if image_path:
  210. raw_paths.append(image_path)
  211. if not raw_paths:
  212. return ToolResult(
  213. title="NanoBanana 提取失败",
  214. output="",
  215. error="未提供输入图片,请传入 image_path 或 image_paths",
  216. )
  217. # 去重并检查路径
  218. unique_raw: List[str] = []
  219. seen = set()
  220. for p in raw_paths:
  221. if p and p not in seen:
  222. unique_raw.append(p)
  223. seen.add(p)
  224. input_paths: List[Path] = [Path(p) for p in unique_raw]
  225. invalid = [str(p) for p in input_paths if (not p.exists() or not p.is_file())]
  226. if invalid:
  227. return ToolResult(
  228. title="NanoBanana 提取失败",
  229. output="",
  230. error=f"以下图片不存在或不可读: {invalid}",
  231. )
  232. api_key = _resolve_api_key()
  233. if not api_key:
  234. return ToolResult(
  235. title="NanoBanana 提取失败",
  236. output="",
  237. error="未找到 OpenRouter API Key,请设置 OPENROUTER_API_KEY 或 OPEN_ROUTER_API_KEY",
  238. )
  239. if generate_image:
  240. user_prompt = prompt or DEFAULT_IMAGE_PROMPT
  241. else:
  242. chosen_model = model or os.getenv("NANOBANANA_MODEL") or "google/gemini-2.5-flash"
  243. user_prompt = prompt or DEFAULT_EXTRACTION_PROMPT
  244. try:
  245. image_data_urls = [_image_to_data_url(p) for p in input_paths]
  246. except Exception as e:
  247. return ToolResult(
  248. title="NanoBanana 提取失败",
  249. output="",
  250. error=f"图片编码失败: {e}",
  251. )
  252. user_content: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
  253. for u in image_data_urls:
  254. user_content.append({"type": "image_url", "image_url": {"url": u}})
  255. payload: Dict[str, Any] = {
  256. "messages": [
  257. {
  258. "role": "system",
  259. "content": (
  260. "你是视觉助手。"
  261. "当任务为特征提取时输出 JSON 对象,不要输出 markdown。"
  262. "当任务为图像生成时请返回图像。"
  263. ),
  264. },
  265. {
  266. "role": "user",
  267. "content": user_content,
  268. },
  269. ],
  270. "temperature": 0.2,
  271. "max_tokens": max_tokens,
  272. }
  273. if generate_image:
  274. payload["modalities"] = ["image", "text"]
  275. headers = {
  276. "Authorization": f"Bearer {api_key}",
  277. "Content-Type": "application/json",
  278. "HTTP-Referer": "https://local-agent",
  279. "X-Title": "Agent NanoBanana Tool",
  280. }
  281. endpoint = f"{OPENROUTER_BASE_URL}/chat/completions"
  282. # 图像生成模式:自动尝试多个可用模型,减少 404/invalid model 影响
  283. if generate_image:
  284. candidates: List[str] = []
  285. if model:
  286. candidates.append(_normalize_model_id(model))
  287. if env_model := os.getenv("NANOBANANA_IMAGE_MODEL"):
  288. candidates.append(_normalize_model_id(env_model))
  289. candidates.extend([_normalize_model_id(x) for x in DEFAULT_IMAGE_MODEL_CANDIDATES])
  290. # 去重并保持顺序
  291. dedup: List[str] = []
  292. seen = set()
  293. for m in candidates:
  294. if m and m not in seen:
  295. dedup.append(m)
  296. seen.add(m)
  297. candidates = dedup
  298. else:
  299. candidates = [chosen_model]
  300. data: Optional[Dict[str, Any]] = None
  301. used_model: Optional[str] = None
  302. errors: List[Dict[str, Any]] = []
  303. for cand in candidates:
  304. modality_attempts: List[Optional[List[str]]] = [None]
  305. if generate_image:
  306. modality_attempts = [["image", "text"], ["image"], None]
  307. for mods in modality_attempts:
  308. trial_payload = dict(payload)
  309. trial_payload["model"] = cand
  310. if mods is None:
  311. trial_payload.pop("modalities", None)
  312. else:
  313. trial_payload["modalities"] = mods
  314. try:
  315. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  316. resp = await client.post(endpoint, json=trial_payload, headers=headers)
  317. resp.raise_for_status()
  318. data = resp.json()
  319. used_model = cand
  320. break
  321. except httpx.HTTPStatusError as e:
  322. errors.append({
  323. "model": cand,
  324. "modalities": mods,
  325. "status_code": e.response.status_code,
  326. "body": e.response.text[:600],
  327. })
  328. continue
  329. except Exception as e:
  330. errors.append({
  331. "model": cand,
  332. "modalities": mods,
  333. "status_code": None,
  334. "body": str(e)[:600],
  335. })
  336. continue
  337. if data is not None:
  338. break
  339. if data is None:
  340. title = "NanoBanana 生成失败" if generate_image else "NanoBanana 提取失败"
  341. return ToolResult(
  342. title=title,
  343. output=json.dumps({"attempted_models": candidates, "errors": errors}, ensure_ascii=False, indent=2),
  344. long_term_memory="All candidate models failed for this request",
  345. metadata={"attempted_models": candidates, "errors": errors},
  346. )
  347. chosen_model = used_model or candidates[0]
  348. choices = data.get("choices") or []
  349. message = choices[0].get("message", {}) if choices else {}
  350. # 图像生成分支
  351. if generate_image:
  352. refs = _extract_image_refs(choices[0] if choices else {}, message)
  353. if not refs:
  354. content = message.get("content")
  355. preview = ""
  356. if isinstance(content, str):
  357. preview = content[:500]
  358. elif isinstance(content, list):
  359. preview = json.dumps(content[:3], ensure_ascii=False)[:500]
  360. return ToolResult(
  361. title="NanoBanana 生成失败",
  362. output=json.dumps(data, ensure_ascii=False, indent=2),
  363. error="模型未返回可解析图片(未在 message.images/choice.images/content 中发现图片)",
  364. metadata={
  365. "model": chosen_model,
  366. "choice_keys": list((choices[0] if choices else {}).keys()),
  367. "message_keys": list(message.keys()) if isinstance(message, dict) else [],
  368. "content_preview": preview,
  369. },
  370. )
  371. output_paths: List[str] = []
  372. if image_output_path:
  373. base_path = Path(image_output_path)
  374. else:
  375. if len(input_paths) > 1:
  376. base_path = input_paths[0].parent / "set_generated.png"
  377. else:
  378. base_path = input_paths[0].parent / f"{input_paths[0].stem}_generated.png"
  379. base_path.parent.mkdir(parents=True, exist_ok=True)
  380. for idx, ref in enumerate(refs):
  381. kind = ref.get("kind", "")
  382. mime_type = "image/png"
  383. raw_bytes: Optional[bytes] = None
  384. if kind == "data_url":
  385. m = re.match(r"^data:([^;]+);base64,(.+)$", ref.get("value", ""), flags=re.DOTALL)
  386. if not m:
  387. continue
  388. mime_type = m.group(1)
  389. raw_bytes = base64.b64decode(m.group(2))
  390. elif kind == "base64":
  391. mime_type = ref.get("mime_type", "image/png")
  392. raw_bytes = base64.b64decode(ref.get("value", ""))
  393. elif kind == "url":
  394. url = ref.get("value", "")
  395. try:
  396. with httpx.Client(timeout=DEFAULT_TIMEOUT) as client:
  397. r = client.get(url)
  398. r.raise_for_status()
  399. raw_bytes = r.content
  400. mime_type = r.headers.get("content-type", "image/png").split(";")[0]
  401. except Exception:
  402. continue
  403. else:
  404. continue
  405. if not raw_bytes:
  406. continue
  407. ext = _mime_to_ext(mime_type)
  408. if len(refs) == 1:
  409. target = base_path
  410. if target.suffix.lower() not in [".png", ".jpg", ".jpeg", ".webp"]:
  411. target = target.with_suffix(ext)
  412. else:
  413. stem = base_path.stem
  414. target = base_path.with_name(f"{stem}_{idx+1}{ext}")
  415. try:
  416. target.write_bytes(raw_bytes)
  417. output_paths.append(str(target))
  418. except Exception as e:
  419. return ToolResult(
  420. title="NanoBanana 生成失败",
  421. output="",
  422. error=f"写入生成图片失败: {e}",
  423. metadata={"model": chosen_model},
  424. )
  425. if not output_paths:
  426. return ToolResult(
  427. title="NanoBanana 生成失败",
  428. output=json.dumps(data, ensure_ascii=False, indent=2),
  429. error="检测到图片引用但写入失败(可能是无效 base64 或 URL 不可访问)",
  430. metadata={"model": chosen_model, "ref_count": len(refs)},
  431. )
  432. usage = data.get("usage", {})
  433. prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens", 0)
  434. completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens", 0)
  435. summary = {
  436. "model": chosen_model,
  437. "input_images": [str(p) for p in input_paths],
  438. "input_count": len(input_paths),
  439. "generated_images": output_paths,
  440. "prompt_tokens": prompt_tokens,
  441. "completion_tokens": completion_tokens,
  442. }
  443. return ToolResult(
  444. title="NanoBanana 图片生成完成",
  445. output=json.dumps({"summary": summary}, ensure_ascii=False, indent=2),
  446. long_term_memory=f"Generated {len(output_paths)} image(s) from {len(input_paths)} input image(s) using {chosen_model}",
  447. attachments=output_paths,
  448. metadata=summary,
  449. )
  450. content = message.get("content") or ""
  451. if not content:
  452. return ToolResult(
  453. title="NanoBanana 提取失败",
  454. output=json.dumps(data, ensure_ascii=False, indent=2),
  455. error="模型未返回内容",
  456. )
  457. try:
  458. parsed = _safe_json_parse(content)
  459. except Exception as e:
  460. return ToolResult(
  461. title="NanoBanana 提取失败",
  462. output=content,
  463. error=f"模型返回非 JSON 内容,解析失败: {e}",
  464. metadata={"model": chosen_model},
  465. )
  466. if output_file:
  467. out_path = Path(output_file)
  468. else:
  469. if len(input_paths) > 1:
  470. out_path = input_paths[0].parent / "set_invariant_features.json"
  471. else:
  472. out_path = input_paths[0].parent / f"{input_paths[0].stem}_invariant_features.json"
  473. out_path.parent.mkdir(parents=True, exist_ok=True)
  474. out_path.write_text(json.dumps(parsed, ensure_ascii=False, indent=2), encoding="utf-8")
  475. usage = data.get("usage", {})
  476. prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens", 0)
  477. completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens", 0)
  478. summary = {
  479. "model": chosen_model,
  480. "input_images": [str(p) for p in input_paths],
  481. "input_count": len(input_paths),
  482. "output_file": str(out_path),
  483. "prompt_tokens": prompt_tokens,
  484. "completion_tokens": completion_tokens,
  485. }
  486. return ToolResult(
  487. title="NanoBanana 不变特征提取完成",
  488. output=json.dumps(
  489. {
  490. "summary": summary,
  491. "features": parsed,
  492. },
  493. ensure_ascii=False,
  494. indent=2,
  495. ),
  496. long_term_memory=f"Extracted invariant features from {len(input_paths)} input image(s) using {chosen_model}",
  497. attachments=[str(out_path)],
  498. metadata=summary,
  499. )