瀏覽代碼

新增nano_banana测试

kevin.yang 4 天之前
父節點
當前提交
d00f8573b6
共有 3 個文件被更改,包括 196 次插入1 次删除
  1. 二進制
      tests/output/nano_banana_result.png
  2. 190 1
      tests/test_router_api.py
  3. 6 0
      tools/local/nano_banana/tests/server.log

二進制
tests/output/nano_banana_result.png


+ 190 - 1
tests/test_router_api.py

@@ -7,6 +7,7 @@
     uv run python tests/test_router_api.py --status                      # 工具运行状态
     uv run python tests/test_router_api.py --select image_stitcher       # POST /run_tool 调用工具
     uv run python tests/test_router_api.py --stitch                      # 测试图片拼接
+    uv run python tests/test_router_api.py --nano_banana                 # 测试 nano_banana
     uv run python tests/test_router_api.py --create                      # 默认任务
     uv run python tests/test_router_api.py --create image_stitcher       # 指定任务文件
     uv run python tests/test_router_api.py --launch-env                  # 创建 RunComfy 启动环境工具
@@ -17,6 +18,8 @@
 import argparse
 import base64
 import json
+import os
+import re
 import sys
 import time
 from pathlib import Path
@@ -24,7 +27,7 @@ from typing import Any
 
 import httpx
 
-BASE_URL = "http://127.0.0.1:8001"
+BASE_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
 TASKS_DIR = Path(__file__).parent / "tasks"
 TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"
 OUTPUT_DIR = Path(__file__).parent / "output"
@@ -195,6 +198,185 @@ def test_stitch_images():
         print("  [FAIL]")
 
 
+def _nano_has_image(data: dict[str, Any]) -> bool:
+    if data.get("images"):
+        return True
+    img = data.get("image")
+    if isinstance(img, str) and len(img) > 100:
+        return True
+    if data.get("image_base64"):
+        return True
+    cands = data.get("candidates")
+    if isinstance(cands, list) and cands:
+        parts = cands[0].get("content", {}).get("parts", [])
+        for p in parts:
+            if isinstance(p, dict) and (p.get("inlineData") or p.get("inline_data")):
+                return True
+    return False
+
+
+_NANO_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.I | re.S)
+
+
+def _nano_mime_to_ext(mime: str) -> str:
+    base = mime.lower().split(";")[0].strip()
+    if base == "image/png":
+        return "png"
+    if base in ("image/jpeg", "image/jpg"):
+        return "jpg"
+    if base == "image/webp":
+        return "webp"
+    return "png"
+
+
+def _nano_collect_image_bytes(result: dict[str, Any]) -> list[tuple[bytes, str]]:
+    """从 nano_banana 常见返回结构解析出 (raw_bytes, ext) 列表。"""
+    out: list[tuple[bytes, str]] = []
+    imgs = result.get("images")
+    if isinstance(imgs, list):
+        for item in imgs:
+            if not isinstance(item, str) or not item.strip():
+                continue
+            s = item.strip()
+            m = _NANO_DATA_URL_RE.match(s)
+            if m:
+                mime, b64 = m.group(1), m.group(2)
+                try:
+                    out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
+                except Exception:
+                    continue
+            else:
+                try:
+                    out.append((base64.b64decode(s), "png"))
+                except Exception:
+                    continue
+    img_one = result.get("image")
+    if not out and isinstance(img_one, str) and len(img_one) > 100:
+        try:
+            out.append((base64.b64decode(img_one), "png"))
+        except Exception:
+            pass
+    b64_field = result.get("image_base64")
+    if not out and isinstance(b64_field, str) and b64_field.strip():
+        try:
+            out.append((base64.b64decode(b64_field.strip()), "png"))
+        except Exception:
+            pass
+    cands = result.get("candidates")
+    if not out and isinstance(cands, list) and cands:
+        cand0 = cands[0]
+        if isinstance(cand0, dict):
+            for p in cand0.get("content", {}).get("parts", []) or []:
+                if not isinstance(p, dict):
+                    continue
+                inline = p.get("inlineData") or p.get("inline_data")
+                if not isinstance(inline, dict):
+                    continue
+                b64 = inline.get("data")
+                if not b64:
+                    continue
+                mime = str(
+                    inline.get("mimeType") or inline.get("mime_type") or "image/png"
+                )
+                try:
+                    out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
+                except Exception:
+                    continue
+                break
+    return out
+
+
+def _nano_save_images_default(result: dict[str, Any]) -> list[Path]:
+    """默认写入 tests/output/nano_banana_result[_{n}].{ext},返回已写入路径。"""
+    blobs = _nano_collect_image_bytes(result)
+    if not blobs:
+        return []
+    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+    paths: list[Path] = []
+    if len(blobs) == 1:
+        ext = blobs[0][1]
+        p = OUTPUT_DIR / f"nano_banana_result.{ext}"
+        p.write_bytes(blobs[0][0])
+        paths.append(p)
+    else:
+        for i, (raw, ext) in enumerate(blobs):
+            p = OUTPUT_DIR / f"nano_banana_result_{i}.{ext}"
+            p.write_bytes(raw)
+            paths.append(p)
+    return paths
+
+
+def test_nano_banana():
+    """POST /run_tool → nano_banana;依赖 tools/local/nano_banana/.env 中 GEMINI_API_KEY。"""
+    print("=== Test nano_banana (Gemini 图模) ===")
+    print("  需: tools/local/nano_banana/.env → GEMINI_API_KEY")
+    print("  可选环境变量: NANO_BANANA_TEST_PROMPT, NANO_BANANA_MODEL")
+    print("  通过时默认保存图片到 tests/output/nano_banana_result*.png(或多张时带序号)")
+
+    tid = os.environ.get("NANO_BANANA_TOOL_ID", "nano_banana")
+    try:
+        tr = httpx.get(f"{BASE_URL}/tools", timeout=30)
+        tr.raise_for_status()
+        ids = {t["tool_id"] for t in tr.json().get("tools", [])}
+        if tid not in ids:
+            print(f"  ERROR : 注册表中无 {tid!r},请先检查 data/registry.json")
+            print("  [FAIL]")
+            return
+        print(f"  tool_id: {tid} (已注册)")
+    except Exception as e:
+        print(f"  ERROR : GET /tools 失败: {e}")
+        print("  [FAIL]")
+        return
+
+    prompt = os.environ.get(
+        "NANO_BANANA_TEST_PROMPT",
+        "A minimal flat yellow banana icon on white background, no text",
+    )
+    params: dict[str, Any] = {"prompt": prompt}
+    model = os.environ.get("NANO_BANANA_MODEL", "").strip()
+    if model:
+        params["model"] = model
+        print(f"  model: {model}")
+    else:
+        print("  model: (使用工具默认 / GEMINI_IMAGE_MODEL)")
+
+    print(f"  calling {tid} ...")
+    try:
+        ok, err, result = _run_tool(tid, params, timeout=180.0)
+        if not ok:
+            print(f"  ERROR : {err}")
+            print("  [FAIL]")
+            return
+        if not isinstance(result, dict):
+            print(f"  ERROR : 非 dict 结果: {type(result)}")
+            print("  [FAIL]")
+            return
+        if _nano_has_image(result):
+            n = len(result["images"]) if isinstance(result.get("images"), list) else 0
+            print(f"  Result : 含图片字段 (images 条数≈{n})")
+            if result.get("model"):
+                print(f"    model: {result['model']}")
+            saved = _nano_save_images_default(result)
+            if saved:
+                for sp in saved:
+                    print(f"    saved : {sp}")
+            else:
+                print(
+                    "    WARN : 未能从响应解析出图片字节(字段存在但无法 base64 解码)"
+                )
+            print("  [PASS]")
+            return
+        print(f"  ERROR : 未识别到图片字段,keys={list(result.keys())}")
+        print(f"  截断: {str(result)[:400]}...")
+        print("  [FAIL]")
+    except httpx.TimeoutException:
+        print("  ERROR : Request timeout")
+        print("  [FAIL]")
+    except Exception as e:
+        print(f"  ERROR : {e}")
+        print("  [FAIL]")
+
+
 def load_task_spec(task_name: str) -> dict:
     task_file = TASKS_DIR / f"{task_name}.json"
     if not task_file.exists():
@@ -265,6 +447,8 @@ def main():
                         help="call a tool by tool_id")
     parser.add_argument("--stitch", action="store_true",
                         help="test image stitcher with sample images")
+    parser.add_argument("--nano_banana", action="store_true",
+                        help="test nano_banana (Gemini); need GEMINI_API_KEY in tools/local/nano_banana/.env")
     parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME",
                         help="create tool, optional task file name")
     parser.add_argument("--launch-env", action="store_true",
@@ -303,6 +487,11 @@ def main():
         test_stitch_images()
         ran_any = True
 
+    if args.nano_banana:
+        print()
+        test_nano_banana()
+        ran_any = True
+
     if args.create is not None:
         print()
         test_create_tool(args.create or None)

+ 6 - 0
tools/local/nano_banana/tests/server.log

@@ -0,0 +1,6 @@
+INFO:     Started server process [44982]
+INFO:     Waiting for application startup.
+INFO:     Application startup complete.
+INFO:     Uvicorn running on http://0.0.0.0:57891 (Press CTRL+C to quit)
+INFO:     127.0.0.1:57895 - "POST /generate HTTP/1.1" 200 OK
+INFO:     127.0.0.1:58843 - "POST /generate HTTP/1.1" 200 OK