"""Exercise the core Router API endpoints against a running tool_agent service.

Usage:
    uv run python tests/test_router_api.py                          # health check only
    uv run python tests/test_router_api.py --search                 # list all tools
    uv run python tests/test_router_api.py --search image           # keyword search
    uv run python tests/test_router_api.py --status                 # tool runtime status
    uv run python tests/test_router_api.py --select image_stitcher  # call a tool via POST /run_tool
    uv run python tests/test_router_api.py --stitch                 # test image stitching
    uv run python tests/test_router_api.py --nano_banana            # test nano_banana
    uv run python tests/test_router_api.py --create                 # default task
    uv run python tests/test_router_api.py --create image_stitcher  # named task file
    uv run python tests/test_router_api.py --launch-env             # create the RunComfy launch-env tool
    uv run python tests/test_router_api.py --run-only               # create the RunComfy run-only tool
    uv run python tests/test_router_api.py --stop-env               # create the RunComfy stop-env tool

The Router base URL is read from the TOOL_AGENT_ROUTER_URL environment
variable (default: http://127.0.0.1:8001). A health check always runs before
any selected test.
"""

import argparse
import base64
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import Any

import httpx

BASE_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
TASKS_DIR = Path(__file__).parent / "tasks"
TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"
OUTPUT_DIR = Path(__file__).parent / "output"


def check_connection():
    """Probe ``GET /health`` once and exit(1) if the Router cannot be reached.

    Catches ``httpx.RequestError`` (not only ``ConnectError``). A connect or
    read timeout against a hung service is not a ``ConnectError`` subclass,
    so it previously escaped as a raw traceback instead of showing the hint.
    """
    try:
        httpx.get(f"{BASE_URL}/health", timeout=3)
    except httpx.RequestError as e:
        print(f"ERROR: Cannot connect to {BASE_URL}")
        print(f" reason: {type(e).__name__}: {e}")
        print("Please start the service first:")
        print(" uv run python -m tool_agent")
        sys.exit(1)


def test_health():
    """GET /health, print the body, and assert HTTP 200."""
    print("=== Health Check ===")
    resp = httpx.get(f"{BASE_URL}/health")
    print(f" Status : {resp.status_code}")
    try:
        body = json.dumps(resp.json(), ensure_ascii=False, indent=4)
    except ValueError:
        # A non-JSON error page should surface as the assertion below,
        # not as a JSONDecodeError.
        body = resp.text
    print(f" Body : {body}")
    assert resp.status_code == 200, f"health check returned HTTP {resp.status_code}"
    print(" [PASS]")


def test_search_tools(keyword: str | None = None):
    """POST /search_tools (optionally filtered by *keyword*) and pretty-print each tool.

    Prints metadata, parameters (``*`` marks required ones) and, when present,
    the output schema's properties.
    """
    print(f"=== Search Tools{f' (keyword={keyword!r})' if keyword else ''} ===")
    payload = {"keyword": keyword} if keyword else {}
    resp = httpx.post(f"{BASE_URL}/search_tools", json=payload)
    print(f" Status : {resp.status_code}")
    if resp.status_code != 200:
        print(f" Body : {resp.text}")
        print(" [FAIL]")
        return
    data = resp.json()
    print(f" Total : {data['total']}")
    for t in data["tools"]:
        params = t.get("params") or []
        print(f"\n [{t['tool_id']}]")
        print(f" name : {t['name']}")
        print(f" category : {t.get('category', '')}")
        print(f" state : {t['state']}")
        print(f" runtime : {t.get('runtime_type', '')} host_dir={t.get('host_dir', '')}")
        print(f" endpoint : {t.get('http_method', '')} {t.get('endpoint_path', '')} port={t.get('port')}")
        print(f" stream_support: {t.get('stream_support', False)}")
        print(f" description : {t.get('description', '')}")
        print(f" params ({len(params)}):")
        for p in params:
            req_mark = "*" if p.get("required") else " "
            default_str = f" default={p['default']}" if p.get("default") is not None else ""
            enum_str = f" enum={p['enum']}" if p.get("enum") else ""
            print(f" {req_mark} {p['name']:<25} {p['type']:<12} {p.get('description', '')}{default_str}{enum_str}")
        if t.get("output_schema"):
            out_props = t["output_schema"].get("properties", {})
            print(f" output ({len(out_props)}):")
            for oname, odef in out_props.items():
                print(f" {oname:<25} {odef.get('type', ''):<12} {odef.get('description', '')}")
    print("\n [PASS]")


def test_tools_status():
    """GET /tools/status and print each tool's state, port, pid, sources and last error."""
    print("=== Tools Status ===")
    resp = httpx.get(f"{BASE_URL}/tools/status")
    print(f" Status : {resp.status_code}")
    if resp.status_code != 200:
        # Previously this fell through to resp.json() and always reported PASS.
        print(f" Body : {resp.text}")
        print(" [FAIL]")
        return
    data = resp.json()
    print(f" Total : {len(data['tools'])}")
    for t in data["tools"]:
        print(f" - {t['tool_id']}")
        print(f" state : {t['state']}")
        print(f" port : {t.get('port')}")
        print(f" pid : {t.get('pid')}")
        print(f" sources: {[s['type'] for s in t.get('sources', [])]}")
        if t.get("last_error"):
            print(f" error : {t['last_error']}")
    print(" [PASS]")


def _run_tool(
    tool_id: str, params: dict[str, Any], timeout: float = 120.0
) -> tuple[bool, str | None, Any]:
    """Call ``POST /run_tool`` and normalise the outcome.

    Returns ``(True, None, result)`` on success. Returns
    ``(False, message, None)`` when any of the following holds: the HTTP
    status is not 200, the body is not a JSON object, the envelope
    ``status`` is not ``"success"``, or the tool's own ``result`` dict
    reports ``status == "error"``. httpx transport errors (timeouts,
    refused connections) propagate to the caller.
    """
    resp = httpx.post(
        f"{BASE_URL}/run_tool",
        json={"tool_id": tool_id, "params": params},
        timeout=timeout,
    )
    print(f" Status : {resp.status_code}")
    if resp.status_code != 200:
        return False, f"HTTP {resp.status_code}: {resp.text[:300]}", None
    try:
        data = resp.json()
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        return False, f"Invalid JSON: {e}", None
    if not isinstance(data, dict):
        # A bare list/string body would otherwise crash on .get() below.
        return False, f"Unexpected response body: {str(data)[:300]}", None
    if data.get("status") != "success":
        return False, str(data.get("error") or data), None
    result = data.get("result")
    if isinstance(result, dict) and result.get("status") == "error":
        return False, str(result.get("error", result)), None
    return True, None, result


def test_select_tool(tool_id: str):
    """Run *tool_id* with empty params and print the first 500 chars of the result."""
    print(f"=== Run Tool (tool_id={tool_id!r}) ===")
    try:
        ok, err, result = _run_tool(tool_id, {}, timeout=30)
    except httpx.TimeoutException:
        print(" ERROR : Request timeout")
        print(" [FAIL]")
        return
    except httpx.RequestError as e:
        print(f" ERROR : {e}")
        print(" [FAIL]")
        return
    print(" Result :")
    if not ok:
        print(f" error : {err}")
        print(" [FAIL]")
        return
    result_str = json.dumps(result, ensure_ascii=False, indent=6)
    print(f" body: {result_str[:500]}")
    print(" [PASS]")


def test_stitch_images():
    """Send up to six sample PNGs to ``image_stitcher`` (2-column grid) and save the output."""
    print("=== Test Image Stitcher ===")
    if not TEST_IMAGES_DIR.exists():
        print(f" ERROR: Test images directory not found: {TEST_IMAGES_DIR}")
        print(" [SKIP]")
        return
    image_files = sorted(TEST_IMAGES_DIR.glob("*.png"))
    if len(image_files) < 2:
        print(f" ERROR: Need at least 2 images, found {len(image_files)}")
        print(" [SKIP]")
        return
    print(f" Images : {len(image_files)} found")
    images_b64 = []
    for img_path in image_files[:6]:
        images_b64.append(base64.b64encode(img_path.read_bytes()).decode())
        print(f" - {img_path.name}")
    print(" Calling image_stitcher (grid, 2 columns)...")
    try:
        ok, err, result = _run_tool(
            "image_stitcher",
            {
                "images": images_b64,
                "direction": "grid",
                "columns": 2,
                "spacing": 10,
                "background_color": "#FFFFFF",
            },
            timeout=120.0,
        )
        if not ok:
            print(f" ERROR : {err}")
            print(" [FAIL]")
            return
        if not isinstance(result, dict) or "image" not in result:
            print(f" ERROR : 缺少 image 字段: {result!r}")
            print(" [FAIL]")
            return
        print(" Result :")
        print(f" width : {result.get('width')}")
        print(f" height: {result.get('height')}")
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        output_path = OUTPUT_DIR / "stitched_result.png"
        output_path.write_bytes(base64.b64decode(result["image"]))
        print(f" saved : {output_path}")
        print(" [PASS]")
    except httpx.TimeoutException:
        print(" ERROR : Request timeout")
        print(" [FAIL]")
    except Exception as e:  # top-level test boundary: report, don't abort other tests
        print(f" ERROR : {e}")
        print(" [FAIL]")


def _nano_candidate_parts(data: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the dict-valued ``parts`` of the first Gemini-style candidate, or ``[]``.

    Guards every level (candidates list, candidate dict, ``content`` dict,
    ``parts`` list), so odd response shapes yield ``[]`` instead of raising.
    """
    cands = data.get("candidates")
    if not isinstance(cands, list) or not cands or not isinstance(cands[0], dict):
        return []
    content = cands[0].get("content")
    if not isinstance(content, dict):
        return []
    parts = content.get("parts")
    if not isinstance(parts, list):
        return []
    return [p for p in parts if isinstance(p, dict)]


def _nano_has_image(data: dict[str, Any]) -> bool:
    """Return True if *data* carries any image field that the collector below knows about."""
    if data.get("images"):
        return True
    img = data.get("image")
    if isinstance(img, str) and len(img) > 100:
        return True
    if data.get("image_base64"):
        return True
    return any(
        p.get("inlineData") or p.get("inline_data") for p in _nano_candidate_parts(data)
    )


_NANO_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.I | re.S)


def _nano_mime_to_ext(mime: str) -> str:
    """Map an image MIME type to a file extension, defaulting to ``png``."""
    base = mime.lower().split(";")[0].strip()
    if base == "image/png":
        return "png"
    if base in ("image/jpeg", "image/jpg"):
        return "jpg"
    if base == "image/webp":
        return "webp"
    return "png"


def _nano_decode_b64_image(value: str, default_ext: str = "png") -> tuple[bytes, str] | None:
    """Decode a bare base64 string or a ``data:<mime>;base64,<payload>`` URL.

    Returns ``(raw_bytes, ext)``, or ``None`` if base64 decoding fails. A
    data-URL prefix must be stripped before decoding: the lenient decoder
    would otherwise fold ``dataimage/png...`` into the payload and yield
    garbage bytes.
    """
    s = value.strip()
    m = _NANO_DATA_URL_RE.match(s)
    if m:
        payload, ext = m.group(2), _nano_mime_to_ext(m.group(1))
    else:
        payload, ext = s, default_ext
    try:
        return base64.b64decode(payload), ext
    except ValueError:  # binascii.Error subclasses ValueError
        return None


def _nano_collect_image_bytes(result: dict[str, Any]) -> list[tuple[bytes, str]]:
    """Extract ``(raw_bytes, ext)`` pairs from the common nano_banana response shapes.

    Sources are tried in order, and the first one that yields any image wins:
    every entry of ``images``, then ``image`` (only if longer than 100 chars),
    then ``image_base64``, then the first decodable ``inlineData`` part of
    the first candidate.
    """
    imgs = result.get("images")
    if isinstance(imgs, list):
        out = [
            decoded
            for item in imgs
            if isinstance(item, str) and item.strip()
            and (decoded := _nano_decode_b64_image(item)) is not None
        ]
        if out:
            return out

    img_one = result.get("image")
    if isinstance(img_one, str) and len(img_one) > 100:
        decoded = _nano_decode_b64_image(img_one)
        if decoded is not None:
            return [decoded]

    b64_field = result.get("image_base64")
    if isinstance(b64_field, str) and b64_field.strip():
        decoded = _nano_decode_b64_image(b64_field)
        if decoded is not None:
            return [decoded]

    for part in _nano_candidate_parts(result):
        inline = part.get("inlineData") or part.get("inline_data")
        if not isinstance(inline, dict):
            continue
        b64 = inline.get("data")
        if not b64:
            continue
        mime = str(inline.get("mimeType") or inline.get("mime_type") or "image/png")
        try:
            # Only the first decodable inline image is kept.
            return [(base64.b64decode(b64), _nano_mime_to_ext(mime))]
        except (ValueError, TypeError):
            continue
    return []


def _nano_save_images_default(result: dict[str, Any]) -> list[Path]:
    """Write decoded images to ``OUTPUT_DIR/nano_banana_result[_{n}].{ext}``.

    The ``_{n}`` index suffix is used only when there is more than one image.
    Returns the written paths, or ``[]`` when nothing could be decoded.
    """
    blobs = _nano_collect_image_bytes(result)
    if not blobs:
        return []
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    paths: list[Path] = []
    for i, (raw, ext) in enumerate(blobs):
        suffix = "" if len(blobs) == 1 else f"_{i}"
        p = OUTPUT_DIR / f"nano_banana_result{suffix}.{ext}"
        p.write_bytes(raw)
        paths.append(p)
    return paths


def test_nano_banana():
    """Run the nano_banana tool via POST /run_tool and save any returned images.

    Needs GEMINI_API_KEY in tools/local/nano_banana/.env on the service side.
    Optional environment variables: NANO_BANANA_TOOL_ID (default
    ``nano_banana``), NANO_BANANA_TEST_PROMPT and NANO_BANANA_MODEL.
    """
    print("=== Test nano_banana (Gemini 图模) ===")
    print(" 需: tools/local/nano_banana/.env → GEMINI_API_KEY")
    print(" 可选环境变量: NANO_BANANA_TEST_PROMPT, NANO_BANANA_MODEL")
    print(" 通过时默认保存图片到 tests/output/nano_banana_result*.png(或多张时带序号)")
    tid = os.environ.get("NANO_BANANA_TOOL_ID", "nano_banana")
    try:
        tr = httpx.get(f"{BASE_URL}/tools", timeout=30)
        tr.raise_for_status()
        ids = {t["tool_id"] for t in tr.json().get("tools", [])}
        if tid not in ids:
            print(f" ERROR : 注册表中无 {tid!r},请先检查 data/registry.json")
            print(" [FAIL]")
            return
        print(f" tool_id: {tid} (已注册)")
    except Exception as e:
        print(f" ERROR : GET /tools 失败: {e}")
        print(" [FAIL]")
        return

    prompt = os.environ.get(
        "NANO_BANANA_TEST_PROMPT",
        "A minimal flat yellow banana icon on white background, no text",
    )
    params: dict[str, Any] = {"prompt": prompt}
    model = os.environ.get("NANO_BANANA_MODEL", "").strip()
    if model:
        params["model"] = model
        print(f" model: {model}")
    else:
        print(" model: (使用工具默认 / GEMINI_IMAGE_MODEL)")

    print(f" calling {tid} ...")
    try:
        ok, err, result = _run_tool(tid, params, timeout=180.0)
        if not ok:
            print(f" ERROR : {err}")
            print(" [FAIL]")
            return
        if not isinstance(result, dict):
            print(f" ERROR : 非 dict 结果: {type(result)}")
            print(" [FAIL]")
            return
        if _nano_has_image(result):
            n = len(result["images"]) if isinstance(result.get("images"), list) else 0
            print(f" Result : 含图片字段 (images 条数≈{n})")
            if result.get("model"):
                print(f" model: {result['model']}")
            saved = _nano_save_images_default(result)
            if saved:
                for sp in saved:
                    print(f" saved : {sp}")
            else:
                # An image field exists but nothing decoded: warn but still pass.
                print(" WARN : 未能从响应解析出图片字节(字段存在但无法 base64 解码)")
            print(" [PASS]")
            return
        print(f" ERROR : 未识别到图片字段,keys={list(result.keys())}")
        print(f" 截断: {str(result)[:400]}...")
        print(" [FAIL]")
    except httpx.TimeoutException:
        print(" ERROR : Request timeout")
        print(" [FAIL]")
    except Exception as e:  # top-level test boundary
        print(f" ERROR : {e}")
        print(" [FAIL]")


def load_task_spec(task_name: str) -> dict:
    """Load ``TASKS_DIR/<task_name>.json``; list available tasks and exit(1) if it is missing."""
    task_file = TASKS_DIR / f"{task_name}.json"
    if not task_file.exists():
        print(f" ERROR: Task file not found: {task_file}")
        print(" Available tasks:")
        if TASKS_DIR.exists():
            for f in TASKS_DIR.glob("*.json"):
                print(f" - {f.stem}")
        sys.exit(1)
    with open(task_file, "r", encoding="utf-8") as f:
        return json.load(f)


def test_create_tool(
    task_name: str | None = None,
    *,
    poll_interval: float = 5.0,
    poll_timeout: float = 600.0,
):
    """Submit POST /create_tool and poll GET /tasks/<id> until completed/failed/timeout.

    Args:
        task_name: stem of a JSON file under ``TASKS_DIR``. When omitted, a
            built-in text-counter description is used.
        poll_interval: seconds to sleep between polls.
        poll_timeout: total wall-clock seconds to wait before reporting TIMEOUT.
    """
    print(f"=== Create Tool{f' (task={task_name!r})' if task_name else ''} ===")
    if task_name:
        task_data = load_task_spec(task_name)
        print(f" File : tests/tasks/{task_name}.json")
        print(f" Description: {str(task_data.get('description', ''))[:80]}")
    else:
        task_data = {"description": "创建一个简单的文本计数工具,输入文本,返回字数和字符数"}
        print(f" Description: {task_data['description']}")

    # Explicit timeout: the httpx default (5s) is too tight for a submission endpoint.
    resp = httpx.post(f"{BASE_URL}/create_tool", json=task_data, timeout=30)
    if not 200 <= resp.status_code < 300:
        print(f" Status : {resp.status_code}")
        print(f" Body : {resp.text[:300]}")
        print(" [FAIL]")
        return
    data = resp.json()
    task_id = data["task_id"]
    print(f" Task ID : {task_id}")
    print(f" Status : {data['status']}")
    assert data["status"] == "pending"
    print(" [SUBMITTED]")

    print(f"\n Polling task {task_id} (timeout {poll_timeout / 60:g}min)...")
    start = time.monotonic()
    polls = 0
    # Measure real elapsed time: each GET may itself take up to 30s.
    while time.monotonic() - start < poll_timeout:
        time.sleep(poll_interval)
        elapsed = time.monotonic() - start
        try:
            resp = httpx.get(f"{BASE_URL}/tasks/{task_id}", timeout=30)
        except httpx.RequestError as e:
            # One slow/dropped poll should not abort a multi-minute wait.
            print(f" [{elapsed:.0f}s] poll error: {e}")
            continue
        if not 200 <= resp.status_code < 300:
            print(f"\n Poll failed: HTTP {resp.status_code}: {resp.text[:300]}")
            print(" [FAIL]")
            return
        task = resp.json()
        status = task.get("status")
        if polls % 6 == 0:
            print(f" [{elapsed:.0f}s] status={status}")
        polls += 1
        if status == "completed":
            print("\n Completed!")
            print(f" Result : {str(task.get('result', ''))[:300]}")
            resp2 = httpx.post(f"{BASE_URL}/search_tools", json={}, timeout=30)
            tools = resp2.json().get("tools", [])
            print(f" Registered : {[t['tool_id'] for t in tools]}")
            print(" [PASS]")
            return
        if status == "failed":
            print("\n Failed!")
            print(f" Error : {task.get('error', 'unknown')}")
            print(" [FAIL]")
            return
    print(f"\n Timeout after {poll_timeout:g}s")
    print(" [TIMEOUT]")


def main():
    """Parse CLI flags, run the health check, then every selected test in a fixed order."""
    parser = argparse.ArgumentParser(description="Router API Test")
    parser.add_argument("--search", nargs="?", const="", metavar="KEYWORD", help="search tools, optional keyword")
    parser.add_argument("--status", action="store_true", help="show tools status")
    parser.add_argument("--select", metavar="TOOL_ID", help="call a tool by tool_id")
    parser.add_argument("--stitch", action="store_true", help="test image stitcher with sample images")
    parser.add_argument("--nano_banana", action="store_true", help="test nano_banana (Gemini); need GEMINI_API_KEY in tools/local/nano_banana/.env")
    parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME", help="create tool, optional task file name")
    parser.add_argument("--launch-env", action="store_true", help="create RunComfy launch env tool (runcomfy_launch_env)")
    parser.add_argument("--run-only", action="store_true", help="create RunComfy run only tool (runcomfy_run_only)")
    parser.add_argument("--stop-env", action="store_true", help="create RunComfy stop env tool (runcomfy_stop_env)")
    args = parser.parse_args()

    print(f"Target: {BASE_URL}\n")
    check_connection()
    # The health check always runs, whatever else was selected.
    test_health()

    # (enabled?, test function, positional args), executed in this order.
    selected = [
        (args.search is not None, test_search_tools, (args.search or None,)),
        (args.status, test_tools_status, ()),
        (bool(args.select), test_select_tool, (args.select,)),
        (args.stitch, test_stitch_images, ()),
        (args.nano_banana, test_nano_banana, ()),
        (args.create is not None, test_create_tool, (args.create or None,)),
        (args.launch_env, test_create_tool, ("runcomfy_launch_env",)),
        (args.run_only, test_create_tool, ("runcomfy_run_only",)),
        (args.stop_env, test_create_tool, ("runcomfy_stop_env",)),
    ]
    ran_any = False
    for enabled, test_fn, fn_args in selected:
        if enabled:
            print()
            test_fn(*fn_args)
            ran_any = True

    if not ran_any:
        print()
        print("No test specified. Available options:")
        parser.print_help()

    print("\n=== DONE ===")


if __name__ == "__main__":
    main()