elksmmx 1 неделя назад
Родитель
Сommit
0910d419af
2 измененных файлов с 573 добавлено и 0 удалено
  1. 2 0
      agent/tools/builtin/__init__.py
  2. 571 0
      agent/tools/builtin/nanobanana.py

+ 2 - 0
agent/tools/builtin/__init__.py

@@ -16,6 +16,7 @@ from agent.tools.builtin.bash import bash_command
 from agent.tools.builtin.skill import skill, list_skills
 from agent.tools.builtin.subagent import agent, evaluate
 from agent.tools.builtin.search import search_posts, get_search_suggestions
+from agent.tools.builtin.nanobanana import nanobanana_extract_features
 from agent.tools.builtin.sandbox import (sandbox_create_environment, sandbox_run_shell,
                                          sandbox_rebuild_with_ports,sandbox_destroy_environment)
 
@@ -39,6 +40,7 @@ __all__ = [
     "evaluate",
     "search_posts",
     "get_search_suggestions",
+    "nanobanana_extract_features",
     "sandbox_create_environment",
     "sandbox_run_shell",
     "sandbox_rebuild_with_ports",

+ 571 - 0
agent/tools/builtin/nanobanana.py

@@ -0,0 +1,571 @@
+"""
+NanoBanana Tool - 通过 OpenRouter 提取图像不变特征
+
+该工具将单张图像发送给多模态模型(可配置为 NanoBanana 对应模型),
+返回结构化的不变特征,并将结果保存为 JSON 文件。
+"""
+
+import base64
+import json
+import mimetypes
+import os
+import re
+from pathlib import Path
+from typing import Optional, Dict, Any, List, Tuple
+
+import httpx
+from dotenv import load_dotenv
+
+from agent.tools import tool, ToolResult
+
+OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
+DEFAULT_TIMEOUT = 120.0
+
+DEFAULT_EXTRACTION_PROMPT = (
+    "请从这张图像中提取跨场景相对稳定、可复用的视觉不变特征。"
+    "输出严格 JSON,字段包含:identity_features、pose_features、appearance_features、"
+    "material_features、style_features、uncertainty、notes。"
+    "每个字段给出简洁要点,避免臆测。"
+)
+
+DEFAULT_IMAGE_PROMPT = (
+    "基于输入图像生成一张保留主体身份与关键视觉特征的新图像。"
+    "保持人物核心特征一致,同时提升清晰度与可用性。"
+)
+
+DEFAULT_IMAGE_MODEL_CANDIDATES = [
+    "google/gemini-2.5-flash-image",
+    "google/gemini-3-pro-image-preview",
+    "black-forest-labs/flux.2-flex",
+    "black-forest-labs/flux.2-pro",
+]
+
+
+def _resolve_api_key() -> Optional[str]:
+    """优先读取环境变量,缺失时尝试从 .env 加载。"""
+    api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")
+    if api_key:
+        return api_key
+
+    load_dotenv()
+    return os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")
+
+
+def _image_to_data_url(image_path: Path) -> str:
+    """将图片文件编码为 data URL。"""
+    mime_type = mimetypes.guess_type(str(image_path))[0] or "application/octet-stream"
+    raw = image_path.read_bytes()
+    b64 = base64.b64encode(raw).decode("utf-8")
+    return f"data:{mime_type};base64,{b64}"
+
+
+def _safe_json_parse(content: str) -> Dict[str, Any]:
+    """尽量从模型文本中提取 JSON。"""
+    try:
+        return json.loads(content)
+    except json.JSONDecodeError:
+        start = content.find("{")
+        end = content.rfind("}")
+        if start != -1 and end != -1 and end > start:
+            candidate = content[start:end + 1]
+            return json.loads(candidate)
+        raise
+
+
+def _extract_data_url_images(message: Dict[str, Any]) -> List[Tuple[str, str]]:
+    """
+    从 OpenRouter 响应消息中提取 data URL 图片。
+
+    Returns:
+        List[(mime_type, base64_data)]
+    """
+    extracted: List[Tuple[str, str]] = []
+
+    # 官方文档中的主要位置:message.images[]
+    for img in message.get("images", []) or []:
+        if not isinstance(img, dict):
+            continue
+        if img.get("type") != "image_url":
+            continue
+        data_url = ((img.get("image_url") or {}).get("url") or "").strip()
+        if not data_url.startswith("data:"):
+            continue
+        m = re.match(r"^data:([^;]+);base64,(.+)$", data_url, flags=re.DOTALL)
+        if not m:
+            continue
+        extracted.append((m.group(1), m.group(2)))
+
+    # 兼容某些模型可能把 image_url 放在 content 数组中
+    content = message.get("content")
+    if isinstance(content, list):
+        for part in content:
+            if not isinstance(part, dict):
+                continue
+            if part.get("type") != "image_url":
+                continue
+            data_url = ((part.get("image_url") or {}).get("url") or "").strip()
+            if not data_url.startswith("data:"):
+                continue
+            m = re.match(r"^data:([^;]+);base64,(.+)$", data_url, flags=re.DOTALL)
+            if not m:
+                continue
+            extracted.append((m.group(1), m.group(2)))
+
+    return extracted
+
+
+def _extract_image_refs(choice: Dict[str, Any], message: Dict[str, Any]) -> List[Dict[str, str]]:
+    """
+    尝试从不同响应格式中提取图片引用。
+
+    返回格式:
+    - {"kind": "data_url", "value": "data:image/png;base64,..."}
+    - {"kind": "base64", "value": "...", "mime_type": "image/png"}
+    - {"kind": "url", "value": "https://..."}
+    """
+    refs: List[Dict[str, str]] = []
+
+    # 1) 标准 message.images
+    for img in message.get("images", []) or []:
+        if not isinstance(img, dict):
+            continue
+        # image_url 结构
+        data_url = ((img.get("image_url") or {}).get("url") or "").strip()
+        if data_url.startswith("data:"):
+            refs.append({"kind": "data_url", "value": data_url})
+            continue
+        if data_url.startswith("http"):
+            refs.append({"kind": "url", "value": data_url})
+            continue
+
+        # 兼容 base64 字段
+        b64 = (img.get("b64_json") or img.get("base64") or "").strip()
+        if b64:
+            refs.append({"kind": "base64", "value": b64, "mime_type": img.get("mime_type", "image/png")})
+
+    # 2) 某些格式可能在 choice.images
+    for img in choice.get("images", []) or []:
+        if not isinstance(img, dict):
+            continue
+        data_url = ((img.get("image_url") or {}).get("url") or "").strip()
+        if data_url.startswith("data:"):
+            refs.append({"kind": "data_url", "value": data_url})
+            continue
+        if data_url.startswith("http"):
+            refs.append({"kind": "url", "value": data_url})
+            continue
+        b64 = (img.get("b64_json") or img.get("base64") or "").strip()
+        if b64:
+            refs.append({"kind": "base64", "value": b64, "mime_type": img.get("mime_type", "image/png")})
+
+    # 3) content 数组里的 image_url
+    content = message.get("content")
+    if isinstance(content, list):
+        for part in content:
+            if not isinstance(part, dict):
+                continue
+            if part.get("type") != "image_url":
+                continue
+            url = ((part.get("image_url") or {}).get("url") or "").strip()
+            if url.startswith("data:"):
+                refs.append({"kind": "data_url", "value": url})
+            elif url.startswith("http"):
+                refs.append({"kind": "url", "value": url})
+
+    # 4) 极端兼容:文本中可能出现 data:image 或 http 图片 URL
+    if isinstance(content, str):
+        # data URL
+        for m in re.finditer(r"(data:image\/[a-zA-Z0-9.+-]+;base64,[A-Za-z0-9+/=]+)", content):
+            refs.append({"kind": "data_url", "value": m.group(1)})
+        # http(s) 图片链接
+        for m in re.finditer(r"(https?://\S+\.(?:png|jpg|jpeg|webp))", content, flags=re.IGNORECASE):
+            refs.append({"kind": "url", "value": m.group(1)})
+
+    return refs
+
+
+def _mime_to_ext(mime_type: str) -> str:
+    """MIME 类型映射到扩展名。"""
+    mapping = {
+        "image/png": ".png",
+        "image/jpeg": ".jpg",
+        "image/webp": ".webp",
+    }
+    return mapping.get(mime_type.lower(), ".png")
+
+
+def _normalize_model_id(model_id: str) -> str:
+    """
+    规范化常见误写模型 ID,减少无效重试。
+    """
+    if not model_id:
+        return model_id
+    m = model_id.strip()
+    # 常见误写:gemini/gemini-xxx -> google/gemini-xxx
+    if m.startswith("gemini/"):
+        m = "google/" + m.split("/", 1)[1]
+    # 常见顺序误写:preview-image -> image
+    if "gemini-2.5-flash-preview-image" in m:
+        m = m.replace("gemini-2.5-flash-preview-image", "gemini-2.5-flash-image")
+    # 兼容旧 ID 到当前可用 ID
+    if "gemini-2.5-flash-image-preview" in m:
+        m = m.replace("gemini-2.5-flash-image-preview", "gemini-2.5-flash-image")
+    return m
+
+
+@tool(description="使用 OpenRouter 多模态模型(NanoBanana 兼容)提取图像不变特征")
+async def nanobanana_extract_features(
+    image_path: str = "",
+    image_paths: Optional[List[str]] = None,
+    output_file: Optional[str] = None,
+    prompt: Optional[str] = None,
+    model: Optional[str] = None,
+    max_tokens: int = 1200,
+    generate_image: bool = False,
+    image_output_path: Optional[str] = None,
+) -> ToolResult:
+    """
+    提取图像中的不变特征并保存为 JSON。
+
+    Args:
+        image_path: 输入图片路径(单图模式,可选)
+        image_paths: 输入图片路径列表(多图整体模式,可选)
+        output_file: 输出 JSON 文件路径(可选)
+        prompt: 自定义提取指令(可选)
+        model: OpenRouter 模型名(可选,默认读取 NANOBANANA_MODEL 或使用 Gemini 视觉模型)
+        max_tokens: 最大输出 token
+        generate_image: 是否生成图片(False=提取JSON,True=生成图片)
+        image_output_path: 生成图片保存路径(generate_image=True 时可选)
+
+    Returns:
+        ToolResult: 包含结构化特征和输出文件路径
+    """
+    raw_paths: List[str] = []
+    if image_paths:
+        raw_paths.extend(image_paths)
+    if image_path:
+        raw_paths.append(image_path)
+    if not raw_paths:
+        return ToolResult(
+            title="NanoBanana 提取失败",
+            output="",
+            error="未提供输入图片,请传入 image_path 或 image_paths",
+        )
+
+    # 去重并检查路径
+    unique_raw: List[str] = []
+    seen = set()
+    for p in raw_paths:
+        if p and p not in seen:
+            unique_raw.append(p)
+            seen.add(p)
+
+    input_paths: List[Path] = [Path(p) for p in unique_raw]
+    invalid = [str(p) for p in input_paths if (not p.exists() or not p.is_file())]
+    if invalid:
+        return ToolResult(
+            title="NanoBanana 提取失败",
+            output="",
+            error=f"以下图片不存在或不可读: {invalid}",
+        )
+
+    api_key = _resolve_api_key()
+    if not api_key:
+        return ToolResult(
+            title="NanoBanana 提取失败",
+            output="",
+            error="未找到 OpenRouter API Key,请设置 OPENROUTER_API_KEY 或 OPEN_ROUTER_API_KEY",
+        )
+
+    if generate_image:
+        user_prompt = prompt or DEFAULT_IMAGE_PROMPT
+    else:
+        chosen_model = model or os.getenv("NANOBANANA_MODEL") or "google/gemini-2.5-flash"
+        user_prompt = prompt or DEFAULT_EXTRACTION_PROMPT
+
+    try:
+        image_data_urls = [_image_to_data_url(p) for p in input_paths]
+    except Exception as e:
+        return ToolResult(
+            title="NanoBanana 提取失败",
+            output="",
+            error=f"图片编码失败: {e}",
+        )
+
+    user_content: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
+    for u in image_data_urls:
+        user_content.append({"type": "image_url", "image_url": {"url": u}})
+
+    payload: Dict[str, Any] = {
+        "messages": [
+            {
+                "role": "system",
+                "content": (
+                    "你是视觉助手。"
+                    "当任务为特征提取时输出 JSON 对象,不要输出 markdown。"
+                    "当任务为图像生成时请返回图像。"
+                ),
+            },
+            {
+                "role": "user",
+                "content": user_content,
+            },
+        ],
+        "temperature": 0.2,
+        "max_tokens": max_tokens,
+    }
+    if generate_image:
+        payload["modalities"] = ["image", "text"]
+
+    headers = {
+        "Authorization": f"Bearer {api_key}",
+        "Content-Type": "application/json",
+        "HTTP-Referer": "https://local-agent",
+        "X-Title": "Agent NanoBanana Tool",
+    }
+
+    endpoint = f"{OPENROUTER_BASE_URL}/chat/completions"
+
+    # 图像生成模式:自动尝试多个可用模型,减少 404/invalid model 影响
+    if generate_image:
+        candidates: List[str] = []
+        if model:
+            candidates.append(_normalize_model_id(model))
+        if env_model := os.getenv("NANOBANANA_IMAGE_MODEL"):
+            candidates.append(_normalize_model_id(env_model))
+        candidates.extend([_normalize_model_id(x) for x in DEFAULT_IMAGE_MODEL_CANDIDATES])
+        # 去重并保持顺序
+        dedup: List[str] = []
+        seen = set()
+        for m in candidates:
+            if m and m not in seen:
+                dedup.append(m)
+                seen.add(m)
+        candidates = dedup
+    else:
+        candidates = [chosen_model]
+
+    data: Optional[Dict[str, Any]] = None
+    used_model: Optional[str] = None
+    errors: List[Dict[str, Any]] = []
+
+    for cand in candidates:
+        modality_attempts: List[Optional[List[str]]] = [None]
+        if generate_image:
+            modality_attempts = [["image", "text"], ["image"], None]
+
+        for mods in modality_attempts:
+            trial_payload = dict(payload)
+            trial_payload["model"] = cand
+
+            if mods is None:
+                trial_payload.pop("modalities", None)
+            else:
+                trial_payload["modalities"] = mods
+
+            try:
+                async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
+                    resp = await client.post(endpoint, json=trial_payload, headers=headers)
+                    resp.raise_for_status()
+                    data = resp.json()
+                    used_model = cand
+                    break
+            except httpx.HTTPStatusError as e:
+                errors.append({
+                    "model": cand,
+                    "modalities": mods,
+                    "status_code": e.response.status_code,
+                    "body": e.response.text[:600],
+                })
+                continue
+            except Exception as e:
+                errors.append({
+                    "model": cand,
+                    "modalities": mods,
+                    "status_code": None,
+                    "body": str(e)[:600],
+                })
+                continue
+
+        if data is not None:
+            break
+
+    if data is None:
+        title = "NanoBanana 生成失败" if generate_image else "NanoBanana 提取失败"
+        return ToolResult(
+            title=title,
+            output=json.dumps({"attempted_models": candidates, "errors": errors}, ensure_ascii=False, indent=2),
+            long_term_memory="All candidate models failed for this request",
+            metadata={"attempted_models": candidates, "errors": errors},
+        )
+
+    chosen_model = used_model or candidates[0]
+
+    choices = data.get("choices") or []
+    message = choices[0].get("message", {}) if choices else {}
+
+    # 图像生成分支
+    if generate_image:
+        refs = _extract_image_refs(choices[0] if choices else {}, message)
+        if not refs:
+            content = message.get("content")
+            preview = ""
+            if isinstance(content, str):
+                preview = content[:500]
+            elif isinstance(content, list):
+                preview = json.dumps(content[:3], ensure_ascii=False)[:500]
+
+            return ToolResult(
+                title="NanoBanana 生成失败",
+                output=json.dumps(data, ensure_ascii=False, indent=2),
+                error="模型未返回可解析图片(未在 message.images/choice.images/content 中发现图片)",
+                metadata={
+                    "model": chosen_model,
+                    "choice_keys": list((choices[0] if choices else {}).keys()),
+                    "message_keys": list(message.keys()) if isinstance(message, dict) else [],
+                    "content_preview": preview,
+                },
+            )
+
+        output_paths: List[str] = []
+        if image_output_path:
+            base_path = Path(image_output_path)
+        else:
+            if len(input_paths) > 1:
+                base_path = input_paths[0].parent / "set_generated.png"
+            else:
+                base_path = input_paths[0].parent / f"{input_paths[0].stem}_generated.png"
+        base_path.parent.mkdir(parents=True, exist_ok=True)
+
+        for idx, ref in enumerate(refs):
+            kind = ref.get("kind", "")
+            mime_type = "image/png"
+            raw_bytes: Optional[bytes] = None
+
+            if kind == "data_url":
+                m = re.match(r"^data:([^;]+);base64,(.+)$", ref.get("value", ""), flags=re.DOTALL)
+                if not m:
+                    continue
+                mime_type = m.group(1)
+                raw_bytes = base64.b64decode(m.group(2))
+            elif kind == "base64":
+                mime_type = ref.get("mime_type", "image/png")
+                raw_bytes = base64.b64decode(ref.get("value", ""))
+            elif kind == "url":
+                url = ref.get("value", "")
+                try:
+                    with httpx.Client(timeout=DEFAULT_TIMEOUT) as client:
+                        r = client.get(url)
+                        r.raise_for_status()
+                        raw_bytes = r.content
+                        mime_type = r.headers.get("content-type", "image/png").split(";")[0]
+                except Exception:
+                    continue
+            else:
+                continue
+
+            if not raw_bytes:
+                continue
+
+            ext = _mime_to_ext(mime_type)
+            if len(refs) == 1:
+                target = base_path
+                if target.suffix.lower() not in [".png", ".jpg", ".jpeg", ".webp"]:
+                    target = target.with_suffix(ext)
+            else:
+                stem = base_path.stem
+                target = base_path.with_name(f"{stem}_{idx+1}{ext}")
+            try:
+                target.write_bytes(raw_bytes)
+                output_paths.append(str(target))
+            except Exception as e:
+                return ToolResult(
+                    title="NanoBanana 生成失败",
+                    output="",
+                    error=f"写入生成图片失败: {e}",
+                    metadata={"model": chosen_model},
+                )
+
+        if not output_paths:
+            return ToolResult(
+                title="NanoBanana 生成失败",
+                output=json.dumps(data, ensure_ascii=False, indent=2),
+                error="检测到图片引用但写入失败(可能是无效 base64 或 URL 不可访问)",
+                metadata={"model": chosen_model, "ref_count": len(refs)},
+            )
+
+        usage = data.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens", 0)
+        completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens", 0)
+        summary = {
+            "model": chosen_model,
+            "input_images": [str(p) for p in input_paths],
+            "input_count": len(input_paths),
+            "generated_images": output_paths,
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+        }
+        return ToolResult(
+            title="NanoBanana 图片生成完成",
+            output=json.dumps({"summary": summary}, ensure_ascii=False, indent=2),
+            long_term_memory=f"Generated {len(output_paths)} image(s) from {len(input_paths)} input image(s) using {chosen_model}",
+            attachments=output_paths,
+            metadata=summary,
+        )
+
+    content = message.get("content") or ""
+    if not content:
+        return ToolResult(
+            title="NanoBanana 提取失败",
+            output=json.dumps(data, ensure_ascii=False, indent=2),
+            error="模型未返回内容",
+        )
+
+    try:
+        parsed = _safe_json_parse(content)
+    except Exception as e:
+        return ToolResult(
+            title="NanoBanana 提取失败",
+            output=content,
+            error=f"模型返回非 JSON 内容,解析失败: {e}",
+            metadata={"model": chosen_model},
+        )
+
+    if output_file:
+        out_path = Path(output_file)
+    else:
+        if len(input_paths) > 1:
+            out_path = input_paths[0].parent / "set_invariant_features.json"
+        else:
+            out_path = input_paths[0].parent / f"{input_paths[0].stem}_invariant_features.json"
+
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+    out_path.write_text(json.dumps(parsed, ensure_ascii=False, indent=2), encoding="utf-8")
+
+    usage = data.get("usage", {})
+    prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens", 0)
+    completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens", 0)
+
+    summary = {
+        "model": chosen_model,
+        "input_images": [str(p) for p in input_paths],
+        "input_count": len(input_paths),
+        "output_file": str(out_path),
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+    }
+
+    return ToolResult(
+        title="NanoBanana 不变特征提取完成",
+        output=json.dumps(
+            {
+                "summary": summary,
+                "features": parsed,
+            },
+            ensure_ascii=False,
+            indent=2,
+        ),
+        long_term_memory=f"Extracted invariant features from {len(input_paths)} input image(s) using {chosen_model}",
+        attachments=[str(out_path)],
+        metadata=summary,
+    )