| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572 |
- """
- NanoBanana Tool - 图像特征提取与图像生成
- 该工具可以提取图片中的特征,也可以根据描述生成图片。
- 支持通过 OpenRouter 调用多模态模型,提取结构化的图像特征并保存为 JSON,
- 或基于输入图像生成新的图像。
- """
- import base64
- import json
- import mimetypes
- import os
- import re
- from pathlib import Path
- from typing import Optional, Dict, Any, List, Tuple
- import httpx
- from dotenv import load_dotenv
- from agent.tools import tool, ToolResult
# Root of the OpenRouter REST API; chat completions are posted under it.
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
# HTTP timeout in seconds used for every httpx client in this module.
DEFAULT_TIMEOUT = 120.0

# Instruction used in feature-extraction mode when the caller passes no prompt.
DEFAULT_EXTRACTION_PROMPT = "".join(
    [
        "请从这张图像中提取跨场景相对稳定、可复用的视觉不变特征。",
        "输出严格 JSON,字段包含:identity_features、pose_features、appearance_features、",
        "material_features、style_features、uncertainty、notes。",
        "每个字段给出简洁要点,避免臆测。",
    ]
)

# Instruction used in image-generation mode when the caller passes no prompt.
DEFAULT_IMAGE_PROMPT = "".join(
    [
        "基于输入图像生成一张保留主体身份与关键视觉特征的新图像。",
        "保持人物核心特征一致,同时提升清晰度与可用性。",
    ]
)

# Image-generation models tried in order after any caller/env override.
# Other IDs that have been used here: google/gemini-3-pro-image-preview,
# black-forest-labs/flux.2-flex, black-forest-labs/flux.2-pro.
DEFAULT_IMAGE_MODEL_CANDIDATES = ["google/gemini-2.5-flash-image"]
def _resolve_api_key() -> Optional[str]:
    """Return the OpenRouter API key, loading ``.env`` only when the environment lacks one.

    Both ``OPENROUTER_API_KEY`` and the alternate spelling ``OPEN_ROUTER_API_KEY``
    are accepted; the first non-empty value wins. Returns ``None`` (or an empty
    string) when neither source provides a key.
    """

    def _lookup() -> Optional[str]:
        return os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")

    key = _lookup()
    if not key:
        # Fall back to a .env file, then look again.
        load_dotenv()
        key = _lookup()
    return key
def _image_to_data_url(image_path: Path) -> str:
    """Read *image_path* and return it as a ``data:<mime>;base64,<payload>`` URL.

    The MIME type is guessed from the file name; unknown extensions fall back to
    ``application/octet-stream``.
    """
    guessed, _encoding = mimetypes.guess_type(str(image_path))
    mime_type = guessed if guessed else "application/octet-stream"
    payload = base64.b64encode(image_path.read_bytes()).decode("utf-8")
    return "data:" + mime_type + ";base64," + payload
def _safe_json_parse(content: str) -> Dict[str, Any]:
    """Parse JSON from model output, tolerating surrounding prose or markdown fences.

    Strategies, tried in order:
    1. Parse the whole string.
    2. Parse the span from the first ``{`` to the last ``}``. This handles fenced
       output such as a json code block.
    3. At each ``{`` position, decode the first complete JSON value starting there
       and ignore whatever follows. This handles trailing notes that contain braces
       and stray braces before the real object.

    Raises:
        json.JSONDecodeError: the error from strategy 1 when no strategy succeeds.
    """
    try:
        return json.loads(content)
    except json.JSONDecodeError as exc:
        first_error = exc

    start = content.find("{")
    end = content.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(content[start:end + 1])
        except json.JSONDecodeError:
            pass  # e.g. trailing text with its own braces; try strategy 3

    decoder = json.JSONDecoder()
    pos = start
    while pos != -1:
        try:
            obj, _end = decoder.raw_decode(content, pos)
            return obj
        except json.JSONDecodeError:
            pos = content.find("{", pos + 1)
    raise first_error
def _extract_data_url_images(message: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Collect base64 data-URL images from an OpenRouter response message.

    Candidates are read from ``message["images"]`` first, then from
    ``message["content"]`` when it is a list of parts. Only dict entries with
    ``type == "image_url"`` whose URL has the form ``data:<mime>;base64,<payload>``
    are kept.

    Returns:
        A list of ``(mime_type, base64_data)`` tuples in encounter order.
    """
    data_url_pattern = re.compile(r"^data:([^;]+);base64,(.+)$", flags=re.DOTALL)

    entries: List[Any] = list(message.get("images", []) or [])
    content = message.get("content")
    if isinstance(content, list):
        # Some models put image_url parts inside the content array instead.
        entries.extend(content)

    found: List[Tuple[str, str]] = []
    for entry in entries:
        if not isinstance(entry, dict) or entry.get("type") != "image_url":
            continue
        url = ((entry.get("image_url") or {}).get("url") or "").strip()
        match = data_url_pattern.match(url) if url.startswith("data:") else None
        if match:
            found.append((match.group(1), match.group(2)))
    return found
# Patterns for image references embedded in plain-text message content.
_DATA_URL_IN_TEXT_RE = re.compile(r"(data:image\/[a-zA-Z0-9.+-]+;base64,[A-Za-z0-9+/=]+)")
_IMAGE_LINK_IN_TEXT_RE = re.compile(r"(https?://\S+\.(?:png|jpg|jpeg|webp))", flags=re.IGNORECASE)


def _image_url_value(entry: Dict[str, Any]) -> str:
    """Return the stripped URL of an ``image_url`` entry, or ``""`` if there is none.

    Accepts both the object form ``{"image_url": {"url": ...}}`` and the bare
    string form ``{"image_url": "..."}``.
    """
    image_url = entry.get("image_url")
    if isinstance(image_url, dict):
        image_url = image_url.get("url")
    return image_url.strip() if isinstance(image_url, str) else ""


def _url_to_ref(url: str) -> Optional[Dict[str, str]]:
    """Classify *url* as a data-URL or http(s) reference; ``None`` for anything else."""
    if url.startswith("data:"):
        return {"kind": "data_url", "value": url}
    if url.startswith("http"):
        return {"kind": "url", "value": url}
    return None


def _image_entry_to_ref(img: Any) -> Optional[Dict[str, str]]:
    """Convert one ``images[]`` entry (URL form or raw-base64 form) into a reference."""
    if not isinstance(img, dict):
        return None
    ref = _url_to_ref(_image_url_value(img))
    if ref is not None:
        return ref
    b64 = img.get("b64_json") or img.get("base64") or ""
    if isinstance(b64, str) and b64.strip():
        # An explicit "mime_type": null must not leak None downstream.
        return {"kind": "base64", "value": b64.strip(), "mime_type": img.get("mime_type") or "image/png"}
    return None


def _extract_image_refs(choice: Dict[str, Any], message: Dict[str, Any]) -> List[Dict[str, str]]:
    """Collect image references from the various response shapes models use.

    Sources are searched in order:
    1. ``message["images"]``
    2. ``choice["images"]``
    3. ``image_url`` parts when ``message["content"]`` is a list
    4. data-URLs and then http(s) image links found in ``message["content"]``
       when it is a string

    Identical references found in more than one place are returned once, at
    their first position.

    Returns:
        A list of dicts, each in one of these forms:
        - ``{"kind": "data_url", "value": "data:image/png;base64,..."}``
        - ``{"kind": "base64", "value": "...", "mime_type": "image/png"}``
        - ``{"kind": "url", "value": "https://..."}``
    """
    refs: List[Dict[str, str]] = []

    for container in (message, choice):
        for img in container.get("images", []) or []:
            ref = _image_entry_to_ref(img)
            if ref is not None:
                refs.append(ref)

    content = message.get("content")
    if isinstance(content, list):
        for part in content:
            if isinstance(part, dict) and part.get("type") == "image_url":
                ref = _url_to_ref(_image_url_value(part))
                if ref is not None:
                    refs.append(ref)
    elif isinstance(content, str):
        # Last resort: image references written inline in the text.
        refs.extend({"kind": "data_url", "value": m.group(1)} for m in _DATA_URL_IN_TEXT_RE.finditer(content))
        refs.extend({"kind": "url", "value": m.group(1)} for m in _IMAGE_LINK_IN_TEXT_RE.finditer(content))

    unique: List[Dict[str, str]] = []
    seen = set()
    for ref in refs:
        key = (ref["kind"], ref["value"])
        if key not in seen:
            seen.add(key)
            unique.append(ref)
    return unique
def _mime_to_ext(mime_type: Optional[str]) -> str:
    """Map an image MIME type to a file extension, including the leading dot.

    Matching is case-insensitive and ignores parameters such as ``; charset=...``.
    ``None``, empty, or unrecognized types fall back to ``".png"``.
    """
    mapping = {
        "image/png": ".png",
        "image/jpeg": ".jpg",
        "image/jpg": ".jpg",  # non-standard alias some servers send
        "image/webp": ".webp",
        "image/gif": ".gif",
    }
    if not mime_type:
        return ".png"
    normalized = mime_type.split(";", 1)[0].strip().lower()
    return mapping.get(normalized, ".png")
def _normalize_model_id(model_id: str) -> str:
    """Fix common misspellings of model IDs so retries are not wasted on invalid IDs.

    - A ``gemini/<name>`` prefix is rewritten to ``google/<name>``.
    - ``gemini-2.5-flash-preview-image`` (wrong word order) and the legacy
      ``gemini-2.5-flash-image-preview`` both become ``gemini-2.5-flash-image``.

    Falsy input is returned unchanged; otherwise surrounding whitespace is stripped.
    """
    if not model_id:
        return model_id
    normalized = model_id.strip()
    vendor_typo = "gemini/"
    if normalized.startswith(vendor_typo):
        normalized = "google/" + normalized[len(vendor_typo):]
    for legacy, current in (
        ("gemini-2.5-flash-preview-image", "gemini-2.5-flash-image"),
        ("gemini-2.5-flash-image-preview", "gemini-2.5-flash-image"),
    ):
        normalized = normalized.replace(legacy, current)
    return normalized
# Splits a data URL into (mime_type, base64_payload).
_DATA_URL_PARTS_RE = re.compile(r"^data:([^;]+);base64,(.+)$", flags=re.DOTALL)
# Suffixes kept as-is when the caller supplies the path for a single generated image.
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".webp")


def _collect_input_paths(image_path: str, image_paths: Optional[List[str]]) -> List[Path]:
    """Merge ``image_paths`` then ``image_path`` into a de-duplicated list of Paths.

    Empty strings are dropped, and first-seen order is preserved.
    """
    raw_paths: List[str] = list(image_paths or [])
    if image_path:
        raw_paths.append(image_path)
    # dict.fromkeys removes duplicates while keeping insertion order.
    return [Path(p) for p in dict.fromkeys(p for p in raw_paths if p)]


def _candidate_models(model: Optional[str], generate_image: bool) -> List[str]:
    """Return the ordered, de-duplicated model IDs to try.

    Extraction mode uses a single model: *model*, else ``NANOBANANA_MODEL``, else
    ``google/gemini-2.5-flash``. Generation mode tries *model*, then
    ``NANOBANANA_IMAGE_MODEL``, then ``DEFAULT_IMAGE_MODEL_CANDIDATES``. Each of
    these is passed through ``_normalize_model_id``.
    """
    if not generate_image:
        return [model or os.getenv("NANOBANANA_MODEL") or "google/gemini-2.5-flash"]
    candidates: List[str] = []
    if model:
        candidates.append(_normalize_model_id(model))
    if env_model := os.getenv("NANOBANANA_IMAGE_MODEL"):
        candidates.append(_normalize_model_id(env_model))
    candidates.extend(_normalize_model_id(x) for x in DEFAULT_IMAGE_MODEL_CANDIDATES)
    return [m for m in dict.fromkeys(candidates) if m]


async def _post_with_fallback(
    endpoint: str,
    headers: Dict[str, str],
    payload: Dict[str, Any],
    candidates: List[str],
    generate_image: bool,
) -> Tuple[Optional[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]:
    """POST *payload* to *endpoint*, trying each candidate model until one succeeds.

    Returns:
        ``(response_json, model_used, errors)``. The first two are ``None`` when
        every attempt failed; ``errors`` holds one entry per failed attempt.
    """
    errors: List[Dict[str, Any]] = []
    # In generation mode, each model is retried with progressively simpler
    # `modalities` settings, because some models reject some values.
    # None means the key is omitted entirely.
    modality_attempts: List[Optional[List[str]]] = (
        [["image", "text"], ["image"], None] if generate_image else [None]
    )
    # One client is shared by all attempts so connections can be reused.
    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
        for cand in candidates:
            for mods in modality_attempts:
                trial_payload = dict(payload)
                trial_payload["model"] = cand
                if mods is None:
                    trial_payload.pop("modalities", None)
                else:
                    trial_payload["modalities"] = mods
                try:
                    resp = await client.post(endpoint, json=trial_payload, headers=headers)
                    resp.raise_for_status()
                    data = resp.json()
                    if not isinstance(data, dict):
                        raise ValueError(f"unexpected response type: {type(data).__name__}")
                    return data, cand, errors
                except httpx.HTTPStatusError as e:
                    errors.append({
                        "model": cand,
                        "modalities": mods,
                        "status_code": e.response.status_code,
                        "body": e.response.text[:600],
                    })
                except Exception as e:  # network error, bad JSON, ...: record it and try the next attempt
                    errors.append({
                        "model": cand,
                        "modalities": mods,
                        "status_code": None,
                        "body": str(e)[:600],
                    })
    return None, None, errors


async def _decode_image_ref(
    ref: Dict[str, str], client: httpx.AsyncClient
) -> Tuple[Optional[bytes], str]:
    """Turn one image reference into ``(raw_bytes, mime_type)``.

    Returns ``(None, mime_type)`` when the reference cannot be decoded or
    downloaded, so the caller can skip it instead of aborting the whole request.
    """
    kind = ref.get("kind", "")
    if kind == "data_url":
        m = _DATA_URL_PARTS_RE.match(ref.get("value", ""))
        if not m:
            return None, "image/png"
        mime_type, encoded = m.group(1), m.group(2)
    elif kind == "base64":
        mime_type, encoded = ref.get("mime_type") or "image/png", ref.get("value", "")
    elif kind == "url":
        try:
            # Async download, so the event loop is not blocked.
            r = await client.get(ref.get("value", ""))
            r.raise_for_status()
        except Exception:  # best-effort: an unreachable URL only skips this image
            return None, "image/png"
        content_type = r.headers.get("content-type", "image/png").split(";")[0].strip()
        return r.content, content_type or "image/png"
    else:
        return None, "image/png"
    try:
        return base64.b64decode(encoded), mime_type
    except ValueError:  # binascii.Error (a ValueError) for malformed base64
        return None, mime_type


def _default_output_path(input_paths: List[Path], set_name: str, single_suffix: str) -> Path:
    """Return the default output location next to the first input image.

    A multi-image set uses *set_name*; a single image uses ``<stem><single_suffix>``.
    """
    first = input_paths[0]
    if len(input_paths) > 1:
        return first.parent / set_name
    return first.parent / f"{first.stem}{single_suffix}"


def _usage_tokens(data: Dict[str, Any]) -> Tuple[Any, Any]:
    """Read ``(prompt_tokens, completion_tokens)`` from the response's usage block.

    Both prompt/completion-style and input/output-style keys are accepted.
    """
    # `or {}` also covers an explicit "usage": null.
    usage = data.get("usage") or {}
    prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens", 0)
    completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens", 0)
    return prompt_tokens, completion_tokens


def _content_to_text(content: Any) -> str:
    """Return message content as text.

    For list-form content, the ``text`` parts are concatenated; any other
    non-string content yields ``""``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            part.get("text") or ""
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return ""


async def _finish_generation(
    data: Dict[str, Any],
    choice: Dict[str, Any],
    message: Dict[str, Any],
    input_paths: List[Path],
    image_output_path: Optional[str],
    chosen_model: Optional[str],
) -> ToolResult:
    """Save the images returned by an image-generation response and build the result."""
    refs = _extract_image_refs(choice, message)
    if not refs:
        content = message.get("content")
        preview = ""
        if isinstance(content, str):
            preview = content[:500]
        elif isinstance(content, list):
            preview = json.dumps(content[:3], ensure_ascii=False)[:500]
        return ToolResult(
            title="NanoBanana 生成失败",
            output=json.dumps(data, ensure_ascii=False, indent=2),
            error="模型未返回可解析图片(未在 message.images/choice.images/content 中发现图片)",
            metadata={
                "model": chosen_model,
                "choice_keys": list(choice.keys()),
                "message_keys": list(message.keys()) if isinstance(message, dict) else [],
                "content_preview": preview,
            },
        )

    if image_output_path:
        base_path = Path(image_output_path)
    else:
        base_path = _default_output_path(input_paths, "set_generated.png", "_generated.png")

    output_paths: List[str] = []
    try:
        base_path.parent.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        return ToolResult(
            title="NanoBanana 生成失败",
            output="",
            error=f"写入生成图片失败: {e}",
            metadata={"model": chosen_model},
        )

    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
        for idx, ref in enumerate(refs):
            raw_bytes, mime_type = await _decode_image_ref(ref, client)
            if not raw_bytes:
                continue
            ext = _mime_to_ext(mime_type or "image/png")
            if len(refs) == 1:
                # A single image goes to the requested path; only an unusable
                # suffix is replaced with one matching the image type.
                target = base_path
                if target.suffix.lower() not in _IMAGE_SUFFIXES:
                    target = target.with_suffix(ext)
            else:
                target = base_path.with_name(f"{base_path.stem}_{idx + 1}{ext}")
            try:
                target.write_bytes(raw_bytes)
            except OSError as e:
                return ToolResult(
                    title="NanoBanana 生成失败",
                    output="",
                    error=f"写入生成图片失败: {e}",
                    metadata={"model": chosen_model},
                )
            output_paths.append(str(target))

    if not output_paths:
        return ToolResult(
            title="NanoBanana 生成失败",
            output=json.dumps(data, ensure_ascii=False, indent=2),
            error="检测到图片引用但写入失败(可能是无效 base64 或 URL 不可访问)",
            metadata={"model": chosen_model, "ref_count": len(refs)},
        )

    prompt_tokens, completion_tokens = _usage_tokens(data)
    summary = {
        "model": chosen_model,
        "input_images": [str(p) for p in input_paths],
        "input_count": len(input_paths),
        "generated_images": output_paths,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
    }
    return ToolResult(
        title="NanoBanana 图片生成完成",
        output=json.dumps({"summary": summary}, ensure_ascii=False, indent=2),
        long_term_memory=f"Generated {len(output_paths)} image(s) from {len(input_paths)} input image(s) using {chosen_model}",
        attachments=output_paths,
        metadata=summary,
    )


def _finish_extraction(
    data: Dict[str, Any],
    message: Dict[str, Any],
    input_paths: List[Path],
    output_file: Optional[str],
    chosen_model: Optional[str],
) -> ToolResult:
    """Parse the extracted-features JSON, write it to disk, and build the result."""
    content = _content_to_text(message.get("content"))
    if not content:
        return ToolResult(
            title="NanoBanana 提取失败",
            output=json.dumps(data, ensure_ascii=False, indent=2),
            error="模型未返回内容",
        )
    try:
        parsed = _safe_json_parse(content)
    except ValueError as e:  # json.JSONDecodeError is a ValueError
        return ToolResult(
            title="NanoBanana 提取失败",
            output=content,
            error=f"模型返回非 JSON 内容,解析失败: {e}",
            metadata={"model": chosen_model},
        )

    if output_file:
        out_path = Path(output_file)
    else:
        out_path = _default_output_path(input_paths, "set_invariant_features.json", "_invariant_features.json")
    features_json = json.dumps(parsed, ensure_ascii=False, indent=2)
    try:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(features_json, encoding="utf-8")
    except OSError as e:
        # Still return the parsed features so the work is not lost.
        return ToolResult(
            title="NanoBanana 提取失败",
            output=features_json,
            error=f"写入特征文件失败: {e}",
            metadata={"model": chosen_model},
        )

    prompt_tokens, completion_tokens = _usage_tokens(data)
    summary = {
        "model": chosen_model,
        "input_images": [str(p) for p in input_paths],
        "input_count": len(input_paths),
        "output_file": str(out_path),
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
    }
    return ToolResult(
        title="NanoBanana 不变特征提取完成",
        output=json.dumps(
            {
                "summary": summary,
                "features": parsed,
            },
            ensure_ascii=False,
            indent=2,
        ),
        long_term_memory=f"Extracted invariant features from {len(input_paths)} input image(s) using {chosen_model}",
        attachments=[str(out_path)],
        metadata=summary,
    )


@tool(description="可以提取图片中的特征,也可以根据描述生成图片")
async def nanobanana(
    image_path: str = "",
    image_paths: Optional[List[str]] = None,
    output_file: Optional[str] = None,
    prompt: Optional[str] = None,
    model: Optional[str] = None,
    max_tokens: int = 1200,
    generate_image: bool = False,
    image_output_path: Optional[str] = None,
) -> ToolResult:
    """Extract image features, or generate a new image from the input images.

    Args:
        image_path: Path of a single input image (optional).
        image_paths: Input image paths processed together as one set (optional).
        output_file: Path of the output JSON file (optional; feature-extraction mode).
        prompt: Custom extraction instruction or generation description (optional).
        model: OpenRouter model ID (optional). Extraction mode otherwise reads
            NANOBANANA_MODEL and then uses google/gemini-2.5-flash. Generation mode
            also tries NANOBANANA_IMAGE_MODEL and DEFAULT_IMAGE_MODEL_CANDIDATES.
        max_tokens: Maximum number of output tokens.
        generate_image: False = extract features, True = generate an image.
        image_output_path: Where to save the generated image (optional; generation mode).

    Returns:
        ToolResult: Structured features and the output file path, or the paths of
        the generated images. Failures set ``error``.
    """
    input_paths = _collect_input_paths(image_path, image_paths)
    if not input_paths:
        return ToolResult(
            title="NanoBanana 提取失败",
            output="",
            error="未提供输入图片,请传入 image_path 或 image_paths",
        )
    invalid = [str(p) for p in input_paths if not p.is_file()]
    if invalid:
        return ToolResult(
            title="NanoBanana 提取失败",
            output="",
            error=f"以下图片不存在或不可读: {invalid}",
        )

    api_key = _resolve_api_key()
    if not api_key:
        return ToolResult(
            title="NanoBanana 提取失败",
            output="",
            error="未找到 OpenRouter API Key,请设置 OPENROUTER_API_KEY 或 OPEN_ROUTER_API_KEY",
        )

    user_prompt = prompt or (DEFAULT_IMAGE_PROMPT if generate_image else DEFAULT_EXTRACTION_PROMPT)
    try:
        image_data_urls = [_image_to_data_url(p) for p in input_paths]
    except OSError as e:
        return ToolResult(
            title="NanoBanana 提取失败",
            output="",
            error=f"图片编码失败: {e}",
        )

    user_content: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
    user_content.extend({"type": "image_url", "image_url": {"url": u}} for u in image_data_urls)
    payload: Dict[str, Any] = {
        "messages": [
            {
                "role": "system",
                "content": (
                    "你是视觉助手。"
                    "当任务为特征提取时输出 JSON 对象,不要输出 markdown。"
                    "当任务为图像生成时请返回图像。"
                ),
            },
            {
                "role": "user",
                "content": user_content,
            },
        ],
        "temperature": 0.2,
        "max_tokens": max_tokens,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://local-agent",
        "X-Title": "Agent NanoBanana Tool",
    }
    endpoint = f"{OPENROUTER_BASE_URL}/chat/completions"

    # Generation mode falls back across several models to soften 404/invalid-model errors.
    candidates = _candidate_models(model, generate_image)
    data, chosen_model, errors = await _post_with_fallback(
        endpoint, headers, payload, candidates, generate_image
    )
    if data is None:
        title = "NanoBanana 生成失败" if generate_image else "NanoBanana 提取失败"
        return ToolResult(
            title=title,
            output=json.dumps({"attempted_models": candidates, "errors": errors}, ensure_ascii=False, indent=2),
            error=f"所有候选模型请求均失败: {candidates}",
            long_term_memory="All candidate models failed for this request",
            metadata={"attempted_models": candidates, "errors": errors},
        )

    choices = data.get("choices") or []
    choice = choices[0] if choices and isinstance(choices[0], dict) else {}
    message = choice.get("message") or {}
    if not isinstance(message, dict):
        message = {}

    if generate_image:
        return await _finish_generation(data, choice, message, input_paths, image_output_path, chosen_model)
    return _finish_extraction(data, message, input_paths, output_file, chosen_model)
|