| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524 |
- """测试 Router 核心接口
- 用法:
- uv run python tests/test_router_api.py # 只跑 health check
- uv run python tests/test_router_api.py --search # 搜索工具列表
- uv run python tests/test_router_api.py --search image # 关键词搜索
- uv run python tests/test_router_api.py --status # 工具运行状态
- uv run python tests/test_router_api.py --select image_stitcher # POST /run_tool 调用工具
- uv run python tests/test_router_api.py --stitch # 测试图片拼接
- uv run python tests/test_router_api.py --nano_banana # 测试 nano_banana
- uv run python tests/test_router_api.py --create # 默认任务
- uv run python tests/test_router_api.py --create image_stitcher # 指定任务文件
- uv run python tests/test_router_api.py --launch-env # 创建 RunComfy 启动环境工具
- uv run python tests/test_router_api.py --run-only # 创建 RunComfy 任务执行工具
- uv run python tests/test_router_api.py --stop-env # 创建 RunComfy 环境销毁工具
- """
- import argparse
- import base64
- import json
- import os
- import re
- import sys
- import time
- from pathlib import Path
- from typing import Any
- import httpx
# Router under test; override with TOOL_AGENT_ROUTER_URL to target another instance.
BASE_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")

# All fixture/output paths are resolved relative to this test file's directory.
_TESTS_ROOT = Path(__file__).parent
TASKS_DIR = _TESTS_ROOT / "tasks"  # JSON task specs for --create
TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"  # sample PNGs for --stitch
OUTPUT_DIR = _TESTS_ROOT / "output"  # where generated images are written
def check_connection():
    """Exit with status 1 and a start-up hint if the router at BASE_URL is unreachable.

    Any HTTP response (even a non-200 one) counts as reachable; the health
    check that runs next reports the details. Only transport-level failures
    trigger the exit.
    """
    try:
        httpx.get(f"{BASE_URL}/health", timeout=3)
    except httpx.TransportError:
        # TransportError covers ConnectError *and* ConnectTimeout/ReadTimeout.
        # The timeouts derive from TimeoutException, not ConnectError, so catching
        # ConnectError alone let an unreachable host escape as a raw traceback.
        print(f"ERROR: Cannot connect to {BASE_URL}")
        print("Please start the service first:")
        print(" uv run python -m tool_agent")
        sys.exit(1)
def test_health():
    """GET /health, print the status and body, and assert HTTP 200.

    The body is pretty-printed as JSON when it parses. A non-JSON body (for
    example a reverse-proxy error page) is printed raw, so the failure is
    reported by the status assertion rather than by a JSONDecodeError.
    """
    print("=== Health Check ===")
    resp = httpx.get(f"{BASE_URL}/health")
    print(f" Status : {resp.status_code}")
    try:
        body = json.dumps(resp.json(), ensure_ascii=False, indent=4)
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        body = resp.text
    print(f" Body : {body}")
    assert resp.status_code == 200, f"health check returned HTTP {resp.status_code}"
    print(" [PASS]")
def _format_param_line(param: dict[str, Any]) -> str:
    """Render one tool parameter as an aligned line; ``*`` marks required params."""
    mark = "*" if param["required"] else " "
    default = param.get("default")
    default_str = "" if default is None else f" default={default}"
    enum_str = f" enum={param['enum']}" if param.get("enum") else ""
    return (
        f" {mark} {param['name']:<25} {param['type']:<12} "
        f"{param.get('description', '')}{default_str}{enum_str}"
    )


def _print_tool_summary(tool: dict[str, Any]) -> None:
    """Print one /search_tools entry: metadata, parameter table and output schema."""
    print(f"\n [{tool['tool_id']}]")
    print(f" name : {tool['name']}")
    print(f" category : {tool.get('category', '')}")
    print(f" state : {tool['state']}")
    print(f" runtime : {tool.get('runtime_type', '')} host_dir={tool.get('host_dir', '')}")
    print(f" endpoint : {tool.get('http_method', '')} {tool.get('endpoint_path', '')} port={tool.get('port')}")
    print(f" stream_support: {tool.get('stream_support', False)}")
    print(f" description : {tool.get('description', '')}")
    params = tool.get("params", [])
    print(f" params ({len(params)}):")
    for param in params:
        print(_format_param_line(param))
    schema = tool.get("output_schema")
    if schema:
        props = schema.get("properties", {})
        print(f" output ({len(props)}):")
        for prop_name, prop_def in props.items():
            print(f" {prop_name:<25} {prop_def.get('type', ''):<12} {prop_def.get('description', '')}")


def test_search_tools(keyword: str = None):
    """POST /search_tools (optionally filtered by *keyword*) and pretty-print each tool.

    A non-200 response prints the raw body and [FAIL]. Otherwise the total and
    a per-tool summary are printed, followed by [PASS].
    """
    label = f" (keyword={keyword!r})" if keyword else ""
    print(f"=== Search Tools{label} ===")
    resp = httpx.post(
        f"{BASE_URL}/search_tools",
        json={"keyword": keyword} if keyword else {},
    )
    print(f" Status : {resp.status_code}")
    if resp.status_code != 200:
        print(f" Body : {resp.text}")
        print(" [FAIL]")
        return
    listing = resp.json()
    print(f" Total : {listing['total']}")
    for tool in listing["tools"]:
        _print_tool_summary(tool)
    print("\n [PASS]")
def test_tools_status():
    """GET /tools/status and print each tool's runtime state (port, pid, sources, errors).

    A non-200 response is reported as [FAIL] with the raw body, instead of
    being parsed as if it were a status listing. This matches test_search_tools.
    """
    print("=== Tools Status ===")
    resp = httpx.get(f"{BASE_URL}/tools/status")
    print(f" Status : {resp.status_code}")
    if resp.status_code != 200:
        print(f" Body : {resp.text}")
        print(" [FAIL]")
        return
    data = resp.json()
    print(f" Total : {len(data['tools'])}")
    for t in data["tools"]:
        print(f" - {t['tool_id']}")
        print(f" state : {t['state']}")
        print(f" port : {t.get('port')}")
        print(f" pid : {t.get('pid')}")
        print(f" sources: {[s['type'] for s in t.get('sources', [])]}")
        if t.get("last_error"):
            print(f" error : {t['last_error']}")
    print(" [PASS]")
def _run_tool(
    tool_id: str, params: dict[str, Any], timeout: float = 120.0
) -> tuple[bool, str | None, Any]:
    """POST /run_tool and normalise the outcome into ``(ok, error, result)``.

    Returns ``(True, None, result)`` when the router reports
    ``status == "success"`` and the tool result is not itself an
    ``{"status": "error", ...}`` dict. Otherwise returns
    ``(False, message, None)``, where *message* is always a ``str``.

    Prints the HTTP status code. Transport errors (timeouts, refused
    connections) are not caught here; callers handle them.
    """
    resp = httpx.post(
        f"{BASE_URL}/run_tool",
        json={"tool_id": tool_id, "params": params},
        timeout=timeout,
    )
    print(f" Status : {resp.status_code}")
    if resp.status_code != 200:
        return False, f"HTTP {resp.status_code}: {resp.text[:300]}", None
    try:
        data = resp.json()
    except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
        return False, f"Invalid JSON: {e}", None
    if not isinstance(data, dict):
        # A bare list/string body would otherwise crash on .get() below.
        return False, f"Unexpected response: {str(data)[:300]}", None
    if data.get("status") != "success":
        # "error" may be a structured object; stringify so callers always get str.
        return False, str(data.get("error") or data), None
    result = data.get("result")
    if isinstance(result, dict) and result.get("status") == "error":
        return False, str(result.get("error", result)), None
    return True, None, result
def test_select_tool(tool_id: str):
    """Call *tool_id* via POST /run_tool with empty params and print a truncated result.

    Timeouts and other httpx transport errors are reported as [FAIL] instead
    of aborting the whole script. This matches the other run_tool tests.
    """
    print(f"=== Run Tool (tool_id={tool_id!r}) ===")
    try:
        ok, err, result = _run_tool(tool_id, {}, timeout=30)
    except httpx.TimeoutException:
        print(" ERROR : Request timeout")
        print(" [FAIL]")
        return
    except httpx.HTTPError as e:
        print(f" ERROR : {e}")
        print(" [FAIL]")
        return
    print(" Result :")
    if not ok:
        print(f" error : {err}")
        print(" [FAIL]")
        return
    result_str = json.dumps(result, ensure_ascii=False, indent=6)
    print(f" body: {result_str[:500]}")
    print(" [PASS]")
def test_stitch_images():
    """Stitch up to six sample PNGs with image_stitcher (2-column grid) and save the output.

    The test is skipped when TEST_IMAGES_DIR is missing or holds fewer than
    two PNGs. On success the decoded image is written to
    OUTPUT_DIR/stitched_result.png.
    """
    print("=== Test Image Stitcher ===")
    if not TEST_IMAGES_DIR.exists():
        print(f" ERROR: Test images directory not found: {TEST_IMAGES_DIR}")
        print(" [SKIP]")
        return
    pngs = sorted(TEST_IMAGES_DIR.glob("*.png"))
    if len(pngs) < 2:
        print(f" ERROR: Need at least 2 images, found {len(pngs)}")
        print(" [SKIP]")
        return
    print(f" Images : {len(pngs)} found")
    encoded: list[str] = []
    for png in pngs[:6]:  # the request carries at most six images
        encoded.append(base64.b64encode(png.read_bytes()).decode())
        print(f" - {png.name}")
    print(" Calling image_stitcher (grid, 2 columns)...")
    request = {
        "images": encoded,
        "direction": "grid",
        "columns": 2,
        "spacing": 10,
        "background_color": "#FFFFFF",
    }
    try:
        ok, err, result = _run_tool("image_stitcher", request, timeout=120.0)
        if not ok:
            print(f" ERROR : {err}")
            print(" [FAIL]")
            return
        if not (isinstance(result, dict) and "image" in result):
            print(f" ERROR : 缺少 image 字段: {result!r}")
            print(" [FAIL]")
            return
        print(" Result :")
        print(f" width : {result.get('width')}")
        print(f" height: {result.get('height')}")
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        output_path = OUTPUT_DIR / "stitched_result.png"
        # Open before decoding, as before, so the file exists even if decoding fails.
        with output_path.open("wb") as fh:
            fh.write(base64.b64decode(result["image"]))
        print(f" saved : {output_path}")
        print(" [PASS]")
    except httpx.TimeoutException:
        print(" ERROR : Request timeout")
        print(" [FAIL]")
    except Exception as e:
        print(f" ERROR : {e}")
        print(" [FAIL]")
- def _nano_has_image(data: dict[str, Any]) -> bool:
- if data.get("images"):
- return True
- img = data.get("image")
- if isinstance(img, str) and len(img) > 100:
- return True
- if data.get("image_base64"):
- return True
- cands = data.get("candidates")
- if isinstance(cands, list) and cands:
- parts = cands[0].get("content", {}).get("parts", [])
- for p in parts:
- if isinstance(p, dict) and (p.get("inlineData") or p.get("inline_data")):
- return True
- return False
- _NANO_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.I | re.S)
- def _nano_mime_to_ext(mime: str) -> str:
- base = mime.lower().split(";")[0].strip()
- if base == "image/png":
- return "png"
- if base in ("image/jpeg", "image/jpg"):
- return "jpg"
- if base == "image/webp":
- return "webp"
- return "png"
- def _nano_collect_image_bytes(result: dict[str, Any]) -> list[tuple[bytes, str]]:
- """从 nano_banana 常见返回结构解析出 (raw_bytes, ext) 列表。"""
- out: list[tuple[bytes, str]] = []
- imgs = result.get("images")
- if isinstance(imgs, list):
- for item in imgs:
- if not isinstance(item, str) or not item.strip():
- continue
- s = item.strip()
- m = _NANO_DATA_URL_RE.match(s)
- if m:
- mime, b64 = m.group(1), m.group(2)
- try:
- out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
- except Exception:
- continue
- else:
- try:
- out.append((base64.b64decode(s), "png"))
- except Exception:
- continue
- img_one = result.get("image")
- if not out and isinstance(img_one, str) and len(img_one) > 100:
- try:
- out.append((base64.b64decode(img_one), "png"))
- except Exception:
- pass
- b64_field = result.get("image_base64")
- if not out and isinstance(b64_field, str) and b64_field.strip():
- try:
- out.append((base64.b64decode(b64_field.strip()), "png"))
- except Exception:
- pass
- cands = result.get("candidates")
- if not out and isinstance(cands, list) and cands:
- cand0 = cands[0]
- if isinstance(cand0, dict):
- for p in cand0.get("content", {}).get("parts", []) or []:
- if not isinstance(p, dict):
- continue
- inline = p.get("inlineData") or p.get("inline_data")
- if not isinstance(inline, dict):
- continue
- b64 = inline.get("data")
- if not b64:
- continue
- mime = str(
- inline.get("mimeType") or inline.get("mime_type") or "image/png"
- )
- try:
- out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
- except Exception:
- continue
- break
- return out
def _nano_save_images_default(result: dict[str, Any]) -> list[Path]:
    """Write every image decoded from *result* into OUTPUT_DIR and return the paths.

    A single image is saved as ``nano_banana_result.<ext>``. Several images
    are saved as ``nano_banana_result_<i>.<ext>``. When nothing decodes,
    returns [] without creating the output directory.
    """
    blobs = _nano_collect_image_bytes(result)
    if not blobs:
        return []
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    numbered = len(blobs) > 1
    written: list[Path] = []
    for idx, (raw, ext) in enumerate(blobs):
        suffix = f"_{idx}" if numbered else ""
        target = OUTPUT_DIR / f"nano_banana_result{suffix}.{ext}"
        target.write_bytes(raw)
        written.append(target)
    return written
def _nano_tool_registered(tid: str) -> bool:
    """Check via GET /tools that *tid* is registered; print ERROR/[FAIL] and return False if not."""
    try:
        listing = httpx.get(f"{BASE_URL}/tools", timeout=30)
        listing.raise_for_status()
        known = {entry["tool_id"] for entry in listing.json().get("tools", [])}
        if tid not in known:
            print(f" ERROR : 注册表中无 {tid!r},请先检查 data/registry.json")
            print(" [FAIL]")
            return False
        print(f" tool_id: {tid} (已注册)")
    except Exception as e:
        print(f" ERROR : GET /tools 失败: {e}")
        print(" [FAIL]")
        return False
    return True


def test_nano_banana():
    """Generate an image with the nano_banana tool via POST /run_tool and save it.

    Requires GEMINI_API_KEY in tools/local/nano_banana/.env. Optional environment
    variables: NANO_BANANA_TOOL_ID (default "nano_banana"),
    NANO_BANANA_TEST_PROMPT and NANO_BANANA_MODEL. Decoded images are written to
    tests/output/nano_banana_result*.<ext>.
    """
    print("=== Test nano_banana (Gemini 图模) ===")
    for hint in (
        " 需: tools/local/nano_banana/.env → GEMINI_API_KEY",
        " 可选环境变量: NANO_BANANA_TEST_PROMPT, NANO_BANANA_MODEL",
        " 通过时默认保存图片到 tests/output/nano_banana_result*.png(或多张时带序号)",
    ):
        print(hint)
    tid = os.environ.get("NANO_BANANA_TOOL_ID", "nano_banana")
    if not _nano_tool_registered(tid):
        return

    params: dict[str, Any] = {
        "prompt": os.environ.get(
            "NANO_BANANA_TEST_PROMPT",
            "A minimal flat yellow banana icon on white background, no text",
        )
    }
    model = os.environ.get("NANO_BANANA_MODEL", "").strip()
    if model:
        params["model"] = model
        print(f" model: {model}")
    else:
        print(" model: (使用工具默认 / GEMINI_IMAGE_MODEL)")
    print(f" calling {tid} ...")

    try:
        ok, err, result = _run_tool(tid, params, timeout=180.0)
        if not ok:
            print(f" ERROR : {err}")
            print(" [FAIL]")
            return
        if not isinstance(result, dict):
            print(f" ERROR : 非 dict 结果: {type(result)}")
            print(" [FAIL]")
            return
        if not _nano_has_image(result):
            print(f" ERROR : 未识别到图片字段,keys={list(result.keys())}")
            print(f" 截断: {str(result)[:400]}...")
            print(" [FAIL]")
            return
        images = result.get("images")
        count = len(images) if isinstance(images, list) else 0
        print(f" Result : 含图片字段 (images 条数≈{count})")
        if result.get("model"):
            print(f" model: {result['model']}")
        saved = _nano_save_images_default(result)
        for path in saved:
            print(f" saved : {path}")
        if not saved:
            print(" WARN : 未能从响应解析出图片字节(字段存在但无法 base64 解码)")
        print(" [PASS]")
    except httpx.TimeoutException:
        print(" ERROR : Request timeout")
        print(" [FAIL]")
    except Exception as e:
        print(f" ERROR : {e}")
        print(" [FAIL]")
def load_task_spec(task_name: str) -> dict:
    """Load ``TASKS_DIR/<task_name>.json`` as UTF-8 JSON.

    If the file does not exist, prints the available task names and
    terminates the process with exit code 1.
    """
    spec_path = TASKS_DIR / f"{task_name}.json"
    if spec_path.exists():
        return json.loads(spec_path.read_text(encoding="utf-8"))
    print(f" ERROR: Task file not found: {spec_path}")
    print(" Available tasks:")
    candidates = TASKS_DIR.glob("*.json") if TASKS_DIR.exists() else ()
    for candidate in candidates:
        print(f" - {candidate.stem}")
    sys.exit(1)
def test_create_tool(
    task_name: str | None = None,
    *,
    poll_interval: float = 5.0,
    max_wait: float = 600.0,
):
    """Submit POST /create_tool, then poll GET /tasks/{id} until the task finishes.

    Args:
        task_name: Name of a JSON spec in TASKS_DIR (without ``.json``). When
            falsy, a built-in text-counter description is submitted instead.
        poll_interval: Seconds to sleep between polls; must be positive.
        max_wait: Stop polling after roughly this many seconds.

    Prints [PASS] when the task completes, [FAIL] when it fails or the
    submission returns a non-200 status, and [TIMEOUT] when *max_wait*
    elapses. A failed individual poll is logged and retried.

    Raises:
        ValueError: if *poll_interval* is not positive.
        AssertionError: if the server accepts the task with a status other
            than "pending".
    """
    if poll_interval <= 0:
        raise ValueError("poll_interval must be positive")
    print(f"=== Create Tool{f' (task={task_name!r})' if task_name else ''} ===")
    if task_name:
        task_data = load_task_spec(task_name)
        print(f" File : tests/tasks/{task_name}.json")
        print(f" Description: {str(task_data.get('description', ''))[:80]}")
    else:
        task_data = {"description": "创建一个简单的文本计数工具,输入文本,返回字数和字符数"}
        print(f" Description: {task_data['description']}")

    resp = httpx.post(f"{BASE_URL}/create_tool", json=task_data, timeout=30)
    if resp.status_code != 200:
        print(f" Status : {resp.status_code}")
        print(f" Body : {resp.text[:500]}")
        print(" [FAIL]")
        return
    data = resp.json()
    task_id = data["task_id"]
    print(f" Task ID : {task_id}")
    print(f" Status : {data['status']}")
    assert data["status"] == "pending"
    print(" [SUBMITTED]")

    print(f"\n Polling task {task_id} (timeout {max_wait / 60:g}min)...")
    max_polls = max(1, int(max_wait // poll_interval))
    for i in range(max_polls):
        time.sleep(poll_interval)
        # Time actually elapsed after this sleep (polls happen *after* sleeping).
        elapsed = (i + 1) * poll_interval
        try:
            resp = httpx.get(f"{BASE_URL}/tasks/{task_id}", timeout=30)
            resp.raise_for_status()
            task = resp.json()
        except (httpx.HTTPError, ValueError) as e:
            # One failed poll (router restart, proxy hiccup) should not abort a long wait.
            print(f" [{elapsed:g}s] poll error: {e}")
            continue
        status = task.get("status") if isinstance(task, dict) else None
        if i % 6 == 0:
            print(f" [{elapsed:g}s] status={status}")
        if status == "completed":
            print("\n Completed!")
            print(f" Result : {str(task.get('result', ''))[:300]}")
            resp2 = httpx.post(f"{BASE_URL}/search_tools", json={}, timeout=30)
            tools = resp2.json()["tools"]
            print(f" Registered : {[t['tool_id'] for t in tools]}")
            print(" [PASS]")
            return
        if status == "failed":
            print("\n Failed!")
            print(f" Error : {task.get('error', 'unknown')}")
            print(" [FAIL]")
            return
    print(f"\n Timeout after {max_wait:g}s")
    print(" [TIMEOUT]")
def main():
    """Parse CLI flags, verify the router is reachable, then run the selected checks.

    The health check always runs. Every other check runs only when its flag is
    given, in the fixed order listed below. With no flags, the help text is
    printed.
    """
    parser = argparse.ArgumentParser(description="Router API Test")
    parser.add_argument("--search", nargs="?", const="", metavar="KEYWORD",
                        help="search tools, optional keyword")
    parser.add_argument("--status", action="store_true",
                        help="show tools status")
    parser.add_argument("--select", metavar="TOOL_ID",
                        help="call a tool by tool_id")
    parser.add_argument("--stitch", action="store_true",
                        help="test image stitcher with sample images")
    parser.add_argument("--nano_banana", action="store_true",
                        help="test nano_banana (Gemini); need GEMINI_API_KEY in tools/local/nano_banana/.env")
    parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME",
                        help="create tool, optional task file name")
    # The three RunComfy shortcuts share one help template.
    for flag, label, task in (
        ("--launch-env", "launch env", "runcomfy_launch_env"),
        ("--run-only", "run only", "runcomfy_run_only"),
        ("--stop-env", "stop env", "runcomfy_stop_env"),
    ):
        parser.add_argument(flag, action="store_true",
                            help=f"create RunComfy {label} tool ({task})")
    args = parser.parse_args()

    print(f"Target: {BASE_URL}\n")
    check_connection()
    # The health check always runs, whatever flags were given.
    test_health()

    # (enabled?, action) pairs in execution order.
    schedule = [
        (args.search is not None, lambda: test_search_tools(args.search or None)),
        (args.status, test_tools_status),
        (bool(args.select), lambda: test_select_tool(args.select)),
        (args.stitch, test_stitch_images),
        (args.nano_banana, test_nano_banana),
        (args.create is not None, lambda: test_create_tool(args.create or None)),
        (args.launch_env, lambda: test_create_tool("runcomfy_launch_env")),
        (args.run_only, lambda: test_create_tool("runcomfy_run_only")),
        (args.stop_env, lambda: test_create_tool("runcomfy_stop_env")),
    ]
    selected = [action for enabled, action in schedule if enabled]
    for action in selected:
        print()
        action()

    if not selected:
        print()
        print("No test specified. Available options:")
        parser.print_help()
    print("\n=== DONE ===")


if __name__ == "__main__":
    main()
|