|
@@ -5,8 +5,9 @@
|
|
|
uv run python tests/test_router_api.py --search # 搜索工具列表
|
|
uv run python tests/test_router_api.py --search # 搜索工具列表
|
|
|
uv run python tests/test_router_api.py --search image # 关键词搜索
|
|
uv run python tests/test_router_api.py --search image # 关键词搜索
|
|
|
uv run python tests/test_router_api.py --status # 工具运行状态
|
|
uv run python tests/test_router_api.py --status # 工具运行状态
|
|
|
- uv run python tests/test_router_api.py --select image_stitcher # 调用指定工具
|
|
|
|
|
|
|
+ uv run python tests/test_router_api.py --select image_stitcher # POST /run_tool 调用工具
|
|
|
uv run python tests/test_router_api.py --stitch # 测试图片拼接
|
|
uv run python tests/test_router_api.py --stitch # 测试图片拼接
|
|
|
|
|
+ uv run python tests/test_router_api.py --nano_banana # 测试 nano_banana
|
|
|
uv run python tests/test_router_api.py --create # 默认任务
|
|
uv run python tests/test_router_api.py --create # 默认任务
|
|
|
uv run python tests/test_router_api.py --create image_stitcher # 指定任务文件
|
|
uv run python tests/test_router_api.py --create image_stitcher # 指定任务文件
|
|
|
uv run python tests/test_router_api.py --launch-env # 创建 RunComfy 启动环境工具
|
|
uv run python tests/test_router_api.py --launch-env # 创建 RunComfy 启动环境工具
|
|
@@ -17,13 +18,16 @@
|
|
|
import argparse
|
|
import argparse
|
|
|
import base64
|
|
import base64
|
|
|
import json
|
|
import json
|
|
|
|
|
+import os
|
|
|
|
|
+import re
|
|
|
import sys
|
|
import sys
|
|
|
import time
|
|
import time
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
+from typing import Any
|
|
|
|
|
|
|
|
import httpx
|
|
import httpx
|
|
|
|
|
|
|
|
-BASE_URL = "http://127.0.0.1:8001"
|
|
|
|
|
|
|
+BASE_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
|
|
|
TASKS_DIR = Path(__file__).parent / "tasks"
|
|
TASKS_DIR = Path(__file__).parent / "tasks"
|
|
|
TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"
|
|
TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"
|
|
|
OUTPUT_DIR = Path(__file__).parent / "output"
|
|
OUTPUT_DIR = Path(__file__).parent / "output"
|
|
@@ -99,21 +103,40 @@ def test_tools_status():
|
|
|
print(" [PASS]")
|
|
print(" [PASS]")
|
|
|
|
|
|
|
|
|
|
|
|
|
-def test_select_tool(tool_id: str):
|
|
|
|
|
- print(f"=== Select Tool (tool_id={tool_id!r}) ===")
|
|
|
|
|
- resp = httpx.post(f"{BASE_URL}/select_tool", json={
|
|
|
|
|
- "tool_id": tool_id,
|
|
|
|
|
- "params": {}
|
|
|
|
|
- }, timeout=30)
|
|
|
|
|
|
|
+def _run_tool(
|
|
|
|
|
+ tool_id: str, params: dict[str, Any], timeout: float = 120.0
|
|
|
|
|
+) -> tuple[bool, str | None, Any]:
|
|
|
|
|
+ """POST /run_tool。成功返回 (True, None, result);失败 (False, message, None)。"""
|
|
|
|
|
+ resp = httpx.post(
|
|
|
|
|
+ f"{BASE_URL}/run_tool",
|
|
|
|
|
+ json={"tool_id": tool_id, "params": params},
|
|
|
|
|
+ timeout=timeout,
|
|
|
|
|
+ )
|
|
|
print(f" Status : {resp.status_code}")
|
|
print(f" Status : {resp.status_code}")
|
|
|
- data = resp.json()
|
|
|
|
|
|
|
+ if resp.status_code != 200:
|
|
|
|
|
+ return False, f"HTTP {resp.status_code}: {resp.text[:300]}", None
|
|
|
|
|
+ try:
|
|
|
|
|
+ data = resp.json()
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ return False, f"Invalid JSON: {e}", None
|
|
|
|
|
+ if data.get("status") != "success":
|
|
|
|
|
+ return False, data.get("error") or str(data), None
|
|
|
|
|
+ result = data.get("result")
|
|
|
|
|
+ if isinstance(result, dict) and result.get("status") == "error":
|
|
|
|
|
+ return False, str(result.get("error", result)), None
|
|
|
|
|
+ return True, None, result
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def test_select_tool(tool_id: str):
|
|
|
|
|
+ print(f"=== Run Tool (tool_id={tool_id!r}) ===")
|
|
|
|
|
+ ok, err, result = _run_tool(tool_id, {}, timeout=30)
|
|
|
print(f" Result :")
|
|
print(f" Result :")
|
|
|
- print(f" status: {data.get('status')}")
|
|
|
|
|
- if data.get("error"):
|
|
|
|
|
- print(f" error : {data['error']}")
|
|
|
|
|
- else:
|
|
|
|
|
- result_str = json.dumps(data.get("result"), ensure_ascii=False, indent=6)
|
|
|
|
|
- print(f" result: {result_str[:500]}")
|
|
|
|
|
|
|
+ if not ok:
|
|
|
|
|
+ print(f" error : {err}")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
|
|
+ return
|
|
|
|
|
+ result_str = json.dumps(result, ensure_ascii=False, indent=6)
|
|
|
|
|
+ print(f" body: {result_str[:500]}")
|
|
|
print(" [PASS]")
|
|
print(" [PASS]")
|
|
|
|
|
|
|
|
|
|
|
|
@@ -139,32 +162,213 @@ def test_stitch_images():
|
|
|
|
|
|
|
|
print(f" Calling image_stitcher (grid, 2 columns)...")
|
|
print(f" Calling image_stitcher (grid, 2 columns)...")
|
|
|
try:
|
|
try:
|
|
|
- resp = httpx.post(f"{BASE_URL}/select_tool", json={
|
|
|
|
|
- "tool_id": "image_stitcher",
|
|
|
|
|
- "params": {
|
|
|
|
|
|
|
+ ok, err, result = _run_tool(
|
|
|
|
|
+ "image_stitcher",
|
|
|
|
|
+ {
|
|
|
"images": images_b64,
|
|
"images": images_b64,
|
|
|
"direction": "grid",
|
|
"direction": "grid",
|
|
|
"columns": 2,
|
|
"columns": 2,
|
|
|
"spacing": 10,
|
|
"spacing": 10,
|
|
|
"background_color": "#FFFFFF",
|
|
"background_color": "#FFFFFF",
|
|
|
- }
|
|
|
|
|
- }, timeout=60)
|
|
|
|
|
- print(f" Status : {resp.status_code}")
|
|
|
|
|
- data = resp.json()
|
|
|
|
|
- if data["status"] == "success":
|
|
|
|
|
- result = data["result"]
|
|
|
|
|
- print(f" Result :")
|
|
|
|
|
- print(f" width : {result['width']}")
|
|
|
|
|
- print(f" height: {result['height']}")
|
|
|
|
|
- OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
- output_path = OUTPUT_DIR / "stitched_result.png"
|
|
|
|
|
- with open(output_path, "wb") as f:
|
|
|
|
|
- f.write(base64.b64decode(result["image"]))
|
|
|
|
|
- print(f" saved : {output_path}")
|
|
|
|
|
- print(" [PASS]")
|
|
|
|
|
- else:
|
|
|
|
|
- print(f" ERROR : {data.get('error', 'unknown')}")
|
|
|
|
|
|
|
+ },
|
|
|
|
|
+ timeout=120.0,
|
|
|
|
|
+ )
|
|
|
|
|
+ if not ok:
|
|
|
|
|
+ print(f" ERROR : {err}")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
|
|
+ return
|
|
|
|
|
+ if not isinstance(result, dict) or "image" not in result:
|
|
|
|
|
+ print(f" ERROR : 缺少 image 字段: {result!r}")
|
|
|
print(" [FAIL]")
|
|
print(" [FAIL]")
|
|
|
|
|
+ return
|
|
|
|
|
+ print(f" Result :")
|
|
|
|
|
+ print(f" width : {result.get('width')}")
|
|
|
|
|
+ print(f" height: {result.get('height')}")
|
|
|
|
|
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ output_path = OUTPUT_DIR / "stitched_result.png"
|
|
|
|
|
+ with open(output_path, "wb") as f:
|
|
|
|
|
+ f.write(base64.b64decode(result["image"]))
|
|
|
|
|
+ print(f" saved : {output_path}")
|
|
|
|
|
+ print(" [PASS]")
|
|
|
|
|
+ except httpx.TimeoutException:
|
|
|
|
|
+ print(" ERROR : Request timeout")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f" ERROR : {e}")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _nano_has_image(data: dict[str, Any]) -> bool:
|
|
|
|
|
+ if data.get("images"):
|
|
|
|
|
+ return True
|
|
|
|
|
+ img = data.get("image")
|
|
|
|
|
+ if isinstance(img, str) and len(img) > 100:
|
|
|
|
|
+ return True
|
|
|
|
|
+ if data.get("image_base64"):
|
|
|
|
|
+ return True
|
|
|
|
|
+ cands = data.get("candidates")
|
|
|
|
|
+ if isinstance(cands, list) and cands:
|
|
|
|
|
+ parts = cands[0].get("content", {}).get("parts", [])
|
|
|
|
|
+ for p in parts:
|
|
|
|
|
+ if isinstance(p, dict) and (p.get("inlineData") or p.get("inline_data")):
|
|
|
|
|
+ return True
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+_NANO_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.I | re.S)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _nano_mime_to_ext(mime: str) -> str:
|
|
|
|
|
+ base = mime.lower().split(";")[0].strip()
|
|
|
|
|
+ if base == "image/png":
|
|
|
|
|
+ return "png"
|
|
|
|
|
+ if base in ("image/jpeg", "image/jpg"):
|
|
|
|
|
+ return "jpg"
|
|
|
|
|
+ if base == "image/webp":
|
|
|
|
|
+ return "webp"
|
|
|
|
|
+ return "png"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _nano_collect_image_bytes(result: dict[str, Any]) -> list[tuple[bytes, str]]:
|
|
|
|
|
+ """从 nano_banana 常见返回结构解析出 (raw_bytes, ext) 列表。"""
|
|
|
|
|
+ out: list[tuple[bytes, str]] = []
|
|
|
|
|
+ imgs = result.get("images")
|
|
|
|
|
+ if isinstance(imgs, list):
|
|
|
|
|
+ for item in imgs:
|
|
|
|
|
+ if not isinstance(item, str) or not item.strip():
|
|
|
|
|
+ continue
|
|
|
|
|
+ s = item.strip()
|
|
|
|
|
+ m = _NANO_DATA_URL_RE.match(s)
|
|
|
|
|
+ if m:
|
|
|
|
|
+ mime, b64 = m.group(1), m.group(2)
|
|
|
|
|
+ try:
|
|
|
|
|
+ out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ continue
|
|
|
|
|
+ else:
|
|
|
|
|
+ try:
|
|
|
|
|
+ out.append((base64.b64decode(s), "png"))
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ continue
|
|
|
|
|
+ img_one = result.get("image")
|
|
|
|
|
+ if not out and isinstance(img_one, str) and len(img_one) > 100:
|
|
|
|
|
+ try:
|
|
|
|
|
+ out.append((base64.b64decode(img_one), "png"))
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ pass
|
|
|
|
|
+ b64_field = result.get("image_base64")
|
|
|
|
|
+ if not out and isinstance(b64_field, str) and b64_field.strip():
|
|
|
|
|
+ try:
|
|
|
|
|
+ out.append((base64.b64decode(b64_field.strip()), "png"))
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ pass
|
|
|
|
|
+ cands = result.get("candidates")
|
|
|
|
|
+ if not out and isinstance(cands, list) and cands:
|
|
|
|
|
+ cand0 = cands[0]
|
|
|
|
|
+ if isinstance(cand0, dict):
|
|
|
|
|
+ for p in cand0.get("content", {}).get("parts", []) or []:
|
|
|
|
|
+ if not isinstance(p, dict):
|
|
|
|
|
+ continue
|
|
|
|
|
+ inline = p.get("inlineData") or p.get("inline_data")
|
|
|
|
|
+ if not isinstance(inline, dict):
|
|
|
|
|
+ continue
|
|
|
|
|
+ b64 = inline.get("data")
|
|
|
|
|
+ if not b64:
|
|
|
|
|
+ continue
|
|
|
|
|
+ mime = str(
|
|
|
|
|
+ inline.get("mimeType") or inline.get("mime_type") or "image/png"
|
|
|
|
|
+ )
|
|
|
|
|
+ try:
|
|
|
|
|
+ out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ continue
|
|
|
|
|
+ break
|
|
|
|
|
+ return out
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _nano_save_images_default(result: dict[str, Any]) -> list[Path]:
|
|
|
|
|
+ """默认写入 tests/output/nano_banana_result[_{n}].{ext},返回已写入路径。"""
|
|
|
|
|
+ blobs = _nano_collect_image_bytes(result)
|
|
|
|
|
+ if not blobs:
|
|
|
|
|
+ return []
|
|
|
|
|
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ paths: list[Path] = []
|
|
|
|
|
+ if len(blobs) == 1:
|
|
|
|
|
+ ext = blobs[0][1]
|
|
|
|
|
+ p = OUTPUT_DIR / f"nano_banana_result.{ext}"
|
|
|
|
|
+ p.write_bytes(blobs[0][0])
|
|
|
|
|
+ paths.append(p)
|
|
|
|
|
+ else:
|
|
|
|
|
+ for i, (raw, ext) in enumerate(blobs):
|
|
|
|
|
+ p = OUTPUT_DIR / f"nano_banana_result_{i}.{ext}"
|
|
|
|
|
+ p.write_bytes(raw)
|
|
|
|
|
+ paths.append(p)
|
|
|
|
|
+ return paths
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def test_nano_banana():
|
|
|
|
|
+ """POST /run_tool → nano_banana;依赖 tools/local/nano_banana/.env 中 GEMINI_API_KEY。"""
|
|
|
|
|
+ print("=== Test nano_banana (Gemini 图模) ===")
|
|
|
|
|
+ print(" 需: tools/local/nano_banana/.env → GEMINI_API_KEY")
|
|
|
|
|
+ print(" 可选环境变量: NANO_BANANA_TEST_PROMPT, NANO_BANANA_MODEL")
|
|
|
|
|
+ print(" 通过时默认保存图片到 tests/output/nano_banana_result*.png(或多张时带序号)")
|
|
|
|
|
+
|
|
|
|
|
+ tid = os.environ.get("NANO_BANANA_TOOL_ID", "nano_banana")
|
|
|
|
|
+ try:
|
|
|
|
|
+ tr = httpx.get(f"{BASE_URL}/tools", timeout=30)
|
|
|
|
|
+ tr.raise_for_status()
|
|
|
|
|
+ ids = {t["tool_id"] for t in tr.json().get("tools", [])}
|
|
|
|
|
+ if tid not in ids:
|
|
|
|
|
+ print(f" ERROR : 注册表中无 {tid!r},请先检查 data/registry.json")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
|
|
+ return
|
|
|
|
|
+ print(f" tool_id: {tid} (已注册)")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f" ERROR : GET /tools 失败: {e}")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ prompt = os.environ.get(
|
|
|
|
|
+ "NANO_BANANA_TEST_PROMPT",
|
|
|
|
|
+ "A minimal flat yellow banana icon on white background, no text",
|
|
|
|
|
+ )
|
|
|
|
|
+ params: dict[str, Any] = {"prompt": prompt}
|
|
|
|
|
+ model = os.environ.get("NANO_BANANA_MODEL", "").strip()
|
|
|
|
|
+ if model:
|
|
|
|
|
+ params["model"] = model
|
|
|
|
|
+ print(f" model: {model}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(" model: (使用工具默认 / GEMINI_IMAGE_MODEL)")
|
|
|
|
|
+
|
|
|
|
|
+ print(f" calling {tid} ...")
|
|
|
|
|
+ try:
|
|
|
|
|
+ ok, err, result = _run_tool(tid, params, timeout=180.0)
|
|
|
|
|
+ if not ok:
|
|
|
|
|
+ print(f" ERROR : {err}")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
|
|
+ return
|
|
|
|
|
+ if not isinstance(result, dict):
|
|
|
|
|
+ print(f" ERROR : 非 dict 结果: {type(result)}")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
|
|
+ return
|
|
|
|
|
+ if _nano_has_image(result):
|
|
|
|
|
+ n = len(result["images"]) if isinstance(result.get("images"), list) else 0
|
|
|
|
|
+ print(f" Result : 含图片字段 (images 条数≈{n})")
|
|
|
|
|
+ if result.get("model"):
|
|
|
|
|
+ print(f" model: {result['model']}")
|
|
|
|
|
+ saved = _nano_save_images_default(result)
|
|
|
|
|
+ if saved:
|
|
|
|
|
+ for sp in saved:
|
|
|
|
|
+ print(f" saved : {sp}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(
|
|
|
|
|
+ " WARN : 未能从响应解析出图片字节(字段存在但无法 base64 解码)"
|
|
|
|
|
+ )
|
|
|
|
|
+ print(" [PASS]")
|
|
|
|
|
+ return
|
|
|
|
|
+ print(f" ERROR : 未识别到图片字段,keys={list(result.keys())}")
|
|
|
|
|
+ print(f" 截断: {str(result)[:400]}...")
|
|
|
|
|
+ print(" [FAIL]")
|
|
|
except httpx.TimeoutException:
|
|
except httpx.TimeoutException:
|
|
|
print(" ERROR : Request timeout")
|
|
print(" ERROR : Request timeout")
|
|
|
print(" [FAIL]")
|
|
print(" [FAIL]")
|
|
@@ -243,6 +447,8 @@ def main():
|
|
|
help="call a tool by tool_id")
|
|
help="call a tool by tool_id")
|
|
|
parser.add_argument("--stitch", action="store_true",
|
|
parser.add_argument("--stitch", action="store_true",
|
|
|
help="test image stitcher with sample images")
|
|
help="test image stitcher with sample images")
|
|
|
|
|
+ parser.add_argument("--nano_banana", action="store_true",
|
|
|
|
|
+ help="test nano_banana (Gemini); need GEMINI_API_KEY in tools/local/nano_banana/.env")
|
|
|
parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME",
|
|
parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME",
|
|
|
help="create tool, optional task file name")
|
|
help="create tool, optional task file name")
|
|
|
parser.add_argument("--launch-env", action="store_true",
|
|
parser.add_argument("--launch-env", action="store_true",
|
|
@@ -281,6 +487,11 @@ def main():
|
|
|
test_stitch_images()
|
|
test_stitch_images()
|
|
|
ran_any = True
|
|
ran_any = True
|
|
|
|
|
|
|
|
|
|
+ if args.nano_banana:
|
|
|
|
|
+ print()
|
|
|
|
|
+ test_nano_banana()
|
|
|
|
|
+ ran_any = True
|
|
|
|
|
+
|
|
|
if args.create is not None:
|
|
if args.create is not None:
|
|
|
print()
|
|
print()
|
|
|
test_create_tool(args.create or None)
|
|
test_create_tool(args.create or None)
|