|
|
@@ -7,6 +7,7 @@
|
|
|
uv run python tests/test_router_api.py --status # 工具运行状态
|
|
|
uv run python tests/test_router_api.py --select image_stitcher # POST /run_tool 调用工具
|
|
|
uv run python tests/test_router_api.py --stitch # 测试图片拼接
|
|
|
+ uv run python tests/test_router_api.py --nano_banana # 测试 nano_banana
|
|
|
uv run python tests/test_router_api.py --create # 默认任务
|
|
|
uv run python tests/test_router_api.py --create image_stitcher # 指定任务文件
|
|
|
uv run python tests/test_router_api.py --launch-env # 创建 RunComfy 启动环境工具
|
|
|
@@ -17,6 +18,8 @@
|
|
|
import argparse
|
|
|
import base64
|
|
|
import json
|
|
|
+import os
|
|
|
+import re
|
|
|
import sys
|
|
|
import time
|
|
|
from pathlib import Path
|
|
|
@@ -24,7 +27,7 @@ from typing import Any
|
|
|
|
|
|
import httpx
|
|
|
|
|
|
-BASE_URL = "http://127.0.0.1:8001"
|
|
|
+BASE_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
|
|
|
TASKS_DIR = Path(__file__).parent / "tasks"
|
|
|
TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"
|
|
|
OUTPUT_DIR = Path(__file__).parent / "output"
|
|
|
@@ -195,6 +198,185 @@ def test_stitch_images():
|
|
|
print(" [FAIL]")
|
|
|
|
|
|
|
|
|
+def _nano_has_image(data: dict[str, Any]) -> bool:
|
|
|
+ if data.get("images"):
|
|
|
+ return True
|
|
|
+ img = data.get("image")
|
|
|
+ if isinstance(img, str) and len(img) > 100:
|
|
|
+ return True
|
|
|
+ if data.get("image_base64"):
|
|
|
+ return True
|
|
|
+ cands = data.get("candidates")
|
|
|
+ if isinstance(cands, list) and cands:
|
|
|
+ parts = cands[0].get("content", {}).get("parts", [])
|
|
|
+ for p in parts:
|
|
|
+ if isinstance(p, dict) and (p.get("inlineData") or p.get("inline_data")):
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
+_NANO_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.I | re.S)
|
|
|
+
|
|
|
+
|
|
|
+def _nano_mime_to_ext(mime: str) -> str:
|
|
|
+ base = mime.lower().split(";")[0].strip()
|
|
|
+ if base == "image/png":
|
|
|
+ return "png"
|
|
|
+ if base in ("image/jpeg", "image/jpg"):
|
|
|
+ return "jpg"
|
|
|
+ if base == "image/webp":
|
|
|
+ return "webp"
|
|
|
+ return "png"
|
|
|
+
|
|
|
+
|
|
|
+def _nano_collect_image_bytes(result: dict[str, Any]) -> list[tuple[bytes, str]]:
|
|
|
+ """从 nano_banana 常见返回结构解析出 (raw_bytes, ext) 列表。"""
|
|
|
+ out: list[tuple[bytes, str]] = []
|
|
|
+ imgs = result.get("images")
|
|
|
+ if isinstance(imgs, list):
|
|
|
+ for item in imgs:
|
|
|
+ if not isinstance(item, str) or not item.strip():
|
|
|
+ continue
|
|
|
+ s = item.strip()
|
|
|
+ m = _NANO_DATA_URL_RE.match(s)
|
|
|
+ if m:
|
|
|
+ mime, b64 = m.group(1), m.group(2)
|
|
|
+ try:
|
|
|
+ out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
|
|
|
+ except Exception:
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ try:
|
|
|
+ out.append((base64.b64decode(s), "png"))
|
|
|
+ except Exception:
|
|
|
+ continue
|
|
|
+ img_one = result.get("image")
|
|
|
+ if not out and isinstance(img_one, str) and len(img_one) > 100:
|
|
|
+ try:
|
|
|
+ out.append((base64.b64decode(img_one), "png"))
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ b64_field = result.get("image_base64")
|
|
|
+ if not out and isinstance(b64_field, str) and b64_field.strip():
|
|
|
+ try:
|
|
|
+ out.append((base64.b64decode(b64_field.strip()), "png"))
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ cands = result.get("candidates")
|
|
|
+ if not out and isinstance(cands, list) and cands:
|
|
|
+ cand0 = cands[0]
|
|
|
+ if isinstance(cand0, dict):
|
|
|
+ for p in cand0.get("content", {}).get("parts", []) or []:
|
|
|
+ if not isinstance(p, dict):
|
|
|
+ continue
|
|
|
+ inline = p.get("inlineData") or p.get("inline_data")
|
|
|
+ if not isinstance(inline, dict):
|
|
|
+ continue
|
|
|
+ b64 = inline.get("data")
|
|
|
+ if not b64:
|
|
|
+ continue
|
|
|
+ mime = str(
|
|
|
+ inline.get("mimeType") or inline.get("mime_type") or "image/png"
|
|
|
+ )
|
|
|
+ try:
|
|
|
+ out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
|
|
|
+ except Exception:
|
|
|
+ continue
|
|
|
+ break
|
|
|
+ return out
|
|
|
+
|
|
|
+
|
|
|
+def _nano_save_images_default(result: dict[str, Any]) -> list[Path]:
|
|
|
+ """默认写入 tests/output/nano_banana_result[_{n}].{ext},返回已写入路径。"""
|
|
|
+ blobs = _nano_collect_image_bytes(result)
|
|
|
+ if not blobs:
|
|
|
+ return []
|
|
|
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
+ paths: list[Path] = []
|
|
|
+ if len(blobs) == 1:
|
|
|
+ ext = blobs[0][1]
|
|
|
+ p = OUTPUT_DIR / f"nano_banana_result.{ext}"
|
|
|
+ p.write_bytes(blobs[0][0])
|
|
|
+ paths.append(p)
|
|
|
+ else:
|
|
|
+ for i, (raw, ext) in enumerate(blobs):
|
|
|
+ p = OUTPUT_DIR / f"nano_banana_result_{i}.{ext}"
|
|
|
+ p.write_bytes(raw)
|
|
|
+ paths.append(p)
|
|
|
+ return paths
|
|
|
+
|
|
|
+
|
|
|
+def test_nano_banana():
|
|
|
+ """POST /run_tool → nano_banana;依赖 tools/local/nano_banana/.env 中 GEMINI_API_KEY。"""
|
|
|
+ print("=== Test nano_banana (Gemini 图模) ===")
|
|
|
+ print(" 需: tools/local/nano_banana/.env → GEMINI_API_KEY")
|
|
|
+ print(" 可选环境变量: NANO_BANANA_TEST_PROMPT, NANO_BANANA_MODEL")
|
|
|
+ print(" 通过时默认保存图片到 tests/output/nano_banana_result*.png(或多张时带序号)")
|
|
|
+
|
|
|
+ tid = os.environ.get("NANO_BANANA_TOOL_ID", "nano_banana")
|
|
|
+ try:
|
|
|
+ tr = httpx.get(f"{BASE_URL}/tools", timeout=30)
|
|
|
+ tr.raise_for_status()
|
|
|
+ ids = {t["tool_id"] for t in tr.json().get("tools", [])}
|
|
|
+ if tid not in ids:
|
|
|
+ print(f" ERROR : 注册表中无 {tid!r},请先检查 data/registry.json")
|
|
|
+ print(" [FAIL]")
|
|
|
+ return
|
|
|
+ print(f" tool_id: {tid} (已注册)")
|
|
|
+ except Exception as e:
|
|
|
+ print(f" ERROR : GET /tools 失败: {e}")
|
|
|
+ print(" [FAIL]")
|
|
|
+ return
|
|
|
+
|
|
|
+ prompt = os.environ.get(
|
|
|
+ "NANO_BANANA_TEST_PROMPT",
|
|
|
+ "A minimal flat yellow banana icon on white background, no text",
|
|
|
+ )
|
|
|
+ params: dict[str, Any] = {"prompt": prompt}
|
|
|
+ model = os.environ.get("NANO_BANANA_MODEL", "").strip()
|
|
|
+ if model:
|
|
|
+ params["model"] = model
|
|
|
+ print(f" model: {model}")
|
|
|
+ else:
|
|
|
+ print(" model: (使用工具默认 / GEMINI_IMAGE_MODEL)")
|
|
|
+
|
|
|
+ print(f" calling {tid} ...")
|
|
|
+ try:
|
|
|
+ ok, err, result = _run_tool(tid, params, timeout=180.0)
|
|
|
+ if not ok:
|
|
|
+ print(f" ERROR : {err}")
|
|
|
+ print(" [FAIL]")
|
|
|
+ return
|
|
|
+ if not isinstance(result, dict):
|
|
|
+ print(f" ERROR : 非 dict 结果: {type(result)}")
|
|
|
+ print(" [FAIL]")
|
|
|
+ return
|
|
|
+ if _nano_has_image(result):
|
|
|
+ n = len(result["images"]) if isinstance(result.get("images"), list) else 0
|
|
|
+ print(f" Result : 含图片字段 (images 条数≈{n})")
|
|
|
+ if result.get("model"):
|
|
|
+ print(f" model: {result['model']}")
|
|
|
+ saved = _nano_save_images_default(result)
|
|
|
+ if saved:
|
|
|
+ for sp in saved:
|
|
|
+ print(f" saved : {sp}")
|
|
|
+ else:
|
|
|
+ print(
|
|
|
+ " WARN : 未能从响应解析出图片字节(字段存在但无法 base64 解码)"
|
|
|
+ )
|
|
|
+ print(" [PASS]")
|
|
|
+ return
|
|
|
+ print(f" ERROR : 未识别到图片字段,keys={list(result.keys())}")
|
|
|
+ print(f" 截断: {str(result)[:400]}...")
|
|
|
+ print(" [FAIL]")
|
|
|
+ except httpx.TimeoutException:
|
|
|
+ print(" ERROR : Request timeout")
|
|
|
+ print(" [FAIL]")
|
|
|
+ except Exception as e:
|
|
|
+ print(f" ERROR : {e}")
|
|
|
+ print(" [FAIL]")
|
|
|
+
|
|
|
+
|
|
|
def load_task_spec(task_name: str) -> dict:
|
|
|
task_file = TASKS_DIR / f"{task_name}.json"
|
|
|
if not task_file.exists():
|
|
|
@@ -265,6 +447,8 @@ def main():
|
|
|
help="call a tool by tool_id")
|
|
|
parser.add_argument("--stitch", action="store_true",
|
|
|
help="test image stitcher with sample images")
|
|
|
+ parser.add_argument("--nano_banana", action="store_true",
|
|
|
+ help="test nano_banana (Gemini); need GEMINI_API_KEY in tools/local/nano_banana/.env")
|
|
|
parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME",
|
|
|
help="create tool, optional task file name")
|
|
|
parser.add_argument("--launch-env", action="store_true",
|
|
|
@@ -303,6 +487,11 @@ def main():
|
|
|
test_stitch_images()
|
|
|
ran_any = True
|
|
|
|
|
|
+ if args.nano_banana:
|
|
|
+ print()
|
|
|
+ test_nano_banana()
|
|
|
+ ran_any = True
|
|
|
+
|
|
|
if args.create is not None:
|
|
|
print()
|
|
|
test_create_tool(args.create or None)
|