# test_router_api.py
"""测试 Router 核心接口

用法:
    uv run python tests/test_router_api.py                          # 只跑 health check
    uv run python tests/test_router_api.py --search                 # 搜索工具列表
    uv run python tests/test_router_api.py --search image           # 关键词搜索
    uv run python tests/test_router_api.py --status                 # 工具运行状态
    uv run python tests/test_router_api.py --select image_stitcher  # POST /run_tool 调用工具
    uv run python tests/test_router_api.py --stitch                 # 测试图片拼接
    uv run python tests/test_router_api.py --nano_banana            # 测试 nano_banana
    uv run python tests/test_router_api.py --create                 # 默认任务
    uv run python tests/test_router_api.py --create image_stitcher  # 指定任务文件
    uv run python tests/test_router_api.py --launch-env             # 创建 RunComfy 启动环境工具
    uv run python tests/test_router_api.py --run-only               # 创建 RunComfy 任务执行工具
    uv run python tests/test_router_api.py --stop-env               # 创建 RunComfy 环境销毁工具
"""
import argparse
import base64
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import Any

import httpx
  26. BASE_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
  27. TASKS_DIR = Path(__file__).parent / "tasks"
  28. TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"
  29. OUTPUT_DIR = Path(__file__).parent / "output"
  30. def check_connection():
  31. try:
  32. httpx.get(f"{BASE_URL}/health", timeout=3)
  33. except httpx.ConnectError:
  34. print(f"ERROR: Cannot connect to {BASE_URL}")
  35. print("Please start the service first:")
  36. print(" uv run python -m tool_agent")
  37. sys.exit(1)
  38. def test_health():
  39. print("=== Health Check ===")
  40. resp = httpx.get(f"{BASE_URL}/health")
  41. print(f" Status : {resp.status_code}")
  42. print(f" Body : {json.dumps(resp.json(), ensure_ascii=False, indent=4)}")
  43. assert resp.status_code == 200
  44. print(" [PASS]")
  45. def test_search_tools(keyword: str = None):
  46. print(f"=== Search Tools{f' (keyword={keyword!r})' if keyword else ''} ===")
  47. payload = {"keyword": keyword} if keyword else {}
  48. resp = httpx.post(f"{BASE_URL}/search_tools", json=payload)
  49. print(f" Status : {resp.status_code}")
  50. if resp.status_code != 200:
  51. print(f" Body : {resp.text}")
  52. print(" [FAIL]")
  53. return
  54. data = resp.json()
  55. print(f" Total : {data['total']}")
  56. for t in data["tools"]:
  57. print(f"\n [{t['tool_id']}]")
  58. print(f" name : {t['name']}")
  59. print(f" category : {t.get('category', '')}")
  60. print(f" state : {t['state']}")
  61. print(f" runtime : {t.get('runtime_type', '')} host_dir={t.get('host_dir', '')}")
  62. print(f" endpoint : {t.get('http_method', '')} {t.get('endpoint_path', '')} port={t.get('port')}")
  63. print(f" stream_support: {t.get('stream_support', False)}")
  64. print(f" description : {t.get('description', '')}")
  65. print(f" params ({len(t.get('params', []))}):")
  66. for p in t.get("params", []):
  67. req_mark = "*" if p["required"] else " "
  68. default_str = f" default={p['default']}" if p.get("default") is not None else ""
  69. enum_str = f" enum={p['enum']}" if p.get("enum") else ""
  70. print(f" {req_mark} {p['name']:<25} {p['type']:<12} {p.get('description', '')}{default_str}{enum_str}")
  71. if t.get("output_schema"):
  72. out_props = t["output_schema"].get("properties", {})
  73. print(f" output ({len(out_props)}):")
  74. for oname, odef in out_props.items():
  75. print(f" {oname:<25} {odef.get('type', ''):<12} {odef.get('description', '')}")
  76. print("\n [PASS]")
  77. def test_tools_status():
  78. print("=== Tools Status ===")
  79. resp = httpx.get(f"{BASE_URL}/tools/status")
  80. print(f" Status : {resp.status_code}")
  81. data = resp.json()
  82. print(f" Total : {len(data['tools'])}")
  83. for t in data["tools"]:
  84. print(f" - {t['tool_id']}")
  85. print(f" state : {t['state']}")
  86. print(f" port : {t.get('port')}")
  87. print(f" pid : {t.get('pid')}")
  88. print(f" sources: {[s['type'] for s in t.get('sources', [])]}")
  89. if t.get("last_error"):
  90. print(f" error : {t['last_error']}")
  91. print(" [PASS]")
  92. def _run_tool(
  93. tool_id: str, params: dict[str, Any], timeout: float = 120.0
  94. ) -> tuple[bool, str | None, Any]:
  95. """POST /run_tool。成功返回 (True, None, result);失败 (False, message, None)。"""
  96. resp = httpx.post(
  97. f"{BASE_URL}/run_tool",
  98. json={"tool_id": tool_id, "params": params},
  99. timeout=timeout,
  100. )
  101. print(f" Status : {resp.status_code}")
  102. if resp.status_code != 200:
  103. return False, f"HTTP {resp.status_code}: {resp.text[:300]}", None
  104. try:
  105. data = resp.json()
  106. except Exception as e:
  107. return False, f"Invalid JSON: {e}", None
  108. if data.get("status") != "success":
  109. return False, data.get("error") or str(data), None
  110. result = data.get("result")
  111. if isinstance(result, dict) and result.get("status") == "error":
  112. return False, str(result.get("error", result)), None
  113. return True, None, result
  114. def test_select_tool(tool_id: str):
  115. print(f"=== Run Tool (tool_id={tool_id!r}) ===")
  116. ok, err, result = _run_tool(tool_id, {}, timeout=30)
  117. print(f" Result :")
  118. if not ok:
  119. print(f" error : {err}")
  120. print(" [FAIL]")
  121. return
  122. result_str = json.dumps(result, ensure_ascii=False, indent=6)
  123. print(f" body: {result_str[:500]}")
  124. print(" [PASS]")
  125. def test_stitch_images():
  126. print("=== Test Image Stitcher ===")
  127. if not TEST_IMAGES_DIR.exists():
  128. print(f" ERROR: Test images directory not found: {TEST_IMAGES_DIR}")
  129. print(" [SKIP]")
  130. return
  131. image_files = sorted(TEST_IMAGES_DIR.glob("*.png"))
  132. if len(image_files) < 2:
  133. print(f" ERROR: Need at least 2 images, found {len(image_files)}")
  134. print(" [SKIP]")
  135. return
  136. print(f" Images : {len(image_files)} found")
  137. images_b64 = []
  138. for img_path in image_files[:6]:
  139. with open(img_path, "rb") as f:
  140. images_b64.append(base64.b64encode(f.read()).decode())
  141. print(f" - {img_path.name}")
  142. print(f" Calling image_stitcher (grid, 2 columns)...")
  143. try:
  144. ok, err, result = _run_tool(
  145. "image_stitcher",
  146. {
  147. "images": images_b64,
  148. "direction": "grid",
  149. "columns": 2,
  150. "spacing": 10,
  151. "background_color": "#FFFFFF",
  152. },
  153. timeout=120.0,
  154. )
  155. if not ok:
  156. print(f" ERROR : {err}")
  157. print(" [FAIL]")
  158. return
  159. if not isinstance(result, dict) or "image" not in result:
  160. print(f" ERROR : 缺少 image 字段: {result!r}")
  161. print(" [FAIL]")
  162. return
  163. print(f" Result :")
  164. print(f" width : {result.get('width')}")
  165. print(f" height: {result.get('height')}")
  166. OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
  167. output_path = OUTPUT_DIR / "stitched_result.png"
  168. with open(output_path, "wb") as f:
  169. f.write(base64.b64decode(result["image"]))
  170. print(f" saved : {output_path}")
  171. print(" [PASS]")
  172. except httpx.TimeoutException:
  173. print(" ERROR : Request timeout")
  174. print(" [FAIL]")
  175. except Exception as e:
  176. print(f" ERROR : {e}")
  177. print(" [FAIL]")
  178. def _nano_has_image(data: dict[str, Any]) -> bool:
  179. if data.get("images"):
  180. return True
  181. img = data.get("image")
  182. if isinstance(img, str) and len(img) > 100:
  183. return True
  184. if data.get("image_base64"):
  185. return True
  186. cands = data.get("candidates")
  187. if isinstance(cands, list) and cands:
  188. parts = cands[0].get("content", {}).get("parts", [])
  189. for p in parts:
  190. if isinstance(p, dict) and (p.get("inlineData") or p.get("inline_data")):
  191. return True
  192. return False
  193. _NANO_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.I | re.S)
  194. def _nano_mime_to_ext(mime: str) -> str:
  195. base = mime.lower().split(";")[0].strip()
  196. if base == "image/png":
  197. return "png"
  198. if base in ("image/jpeg", "image/jpg"):
  199. return "jpg"
  200. if base == "image/webp":
  201. return "webp"
  202. return "png"
  203. def _nano_collect_image_bytes(result: dict[str, Any]) -> list[tuple[bytes, str]]:
  204. """从 nano_banana 常见返回结构解析出 (raw_bytes, ext) 列表。"""
  205. out: list[tuple[bytes, str]] = []
  206. imgs = result.get("images")
  207. if isinstance(imgs, list):
  208. for item in imgs:
  209. if not isinstance(item, str) or not item.strip():
  210. continue
  211. s = item.strip()
  212. m = _NANO_DATA_URL_RE.match(s)
  213. if m:
  214. mime, b64 = m.group(1), m.group(2)
  215. try:
  216. out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
  217. except Exception:
  218. continue
  219. else:
  220. try:
  221. out.append((base64.b64decode(s), "png"))
  222. except Exception:
  223. continue
  224. img_one = result.get("image")
  225. if not out and isinstance(img_one, str) and len(img_one) > 100:
  226. try:
  227. out.append((base64.b64decode(img_one), "png"))
  228. except Exception:
  229. pass
  230. b64_field = result.get("image_base64")
  231. if not out and isinstance(b64_field, str) and b64_field.strip():
  232. try:
  233. out.append((base64.b64decode(b64_field.strip()), "png"))
  234. except Exception:
  235. pass
  236. cands = result.get("candidates")
  237. if not out and isinstance(cands, list) and cands:
  238. cand0 = cands[0]
  239. if isinstance(cand0, dict):
  240. for p in cand0.get("content", {}).get("parts", []) or []:
  241. if not isinstance(p, dict):
  242. continue
  243. inline = p.get("inlineData") or p.get("inline_data")
  244. if not isinstance(inline, dict):
  245. continue
  246. b64 = inline.get("data")
  247. if not b64:
  248. continue
  249. mime = str(
  250. inline.get("mimeType") or inline.get("mime_type") or "image/png"
  251. )
  252. try:
  253. out.append((base64.b64decode(b64), _nano_mime_to_ext(mime)))
  254. except Exception:
  255. continue
  256. break
  257. return out
  258. def _nano_save_images_default(result: dict[str, Any]) -> list[Path]:
  259. """默认写入 tests/output/nano_banana_result[_{n}].{ext},返回已写入路径。"""
  260. blobs = _nano_collect_image_bytes(result)
  261. if not blobs:
  262. return []
  263. OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
  264. paths: list[Path] = []
  265. if len(blobs) == 1:
  266. ext = blobs[0][1]
  267. p = OUTPUT_DIR / f"nano_banana_result.{ext}"
  268. p.write_bytes(blobs[0][0])
  269. paths.append(p)
  270. else:
  271. for i, (raw, ext) in enumerate(blobs):
  272. p = OUTPUT_DIR / f"nano_banana_result_{i}.{ext}"
  273. p.write_bytes(raw)
  274. paths.append(p)
  275. return paths
  276. def test_nano_banana():
  277. """POST /run_tool → nano_banana;依赖 tools/local/nano_banana/.env 中 GEMINI_API_KEY。"""
  278. print("=== Test nano_banana (Gemini 图模) ===")
  279. print(" 需: tools/local/nano_banana/.env → GEMINI_API_KEY")
  280. print(" 可选环境变量: NANO_BANANA_TEST_PROMPT, NANO_BANANA_MODEL")
  281. print(" 通过时默认保存图片到 tests/output/nano_banana_result*.png(或多张时带序号)")
  282. tid = os.environ.get("NANO_BANANA_TOOL_ID", "nano_banana")
  283. try:
  284. tr = httpx.get(f"{BASE_URL}/tools", timeout=30)
  285. tr.raise_for_status()
  286. ids = {t["tool_id"] for t in tr.json().get("tools", [])}
  287. if tid not in ids:
  288. print(f" ERROR : 注册表中无 {tid!r},请先检查 data/registry.json")
  289. print(" [FAIL]")
  290. return
  291. print(f" tool_id: {tid} (已注册)")
  292. except Exception as e:
  293. print(f" ERROR : GET /tools 失败: {e}")
  294. print(" [FAIL]")
  295. return
  296. prompt = os.environ.get(
  297. "NANO_BANANA_TEST_PROMPT",
  298. "A minimal flat yellow banana icon on white background, no text",
  299. )
  300. params: dict[str, Any] = {"prompt": prompt}
  301. model = os.environ.get("NANO_BANANA_MODEL", "").strip()
  302. if model:
  303. params["model"] = model
  304. print(f" model: {model}")
  305. else:
  306. print(" model: (使用工具默认 / GEMINI_IMAGE_MODEL)")
  307. print(f" calling {tid} ...")
  308. try:
  309. ok, err, result = _run_tool(tid, params, timeout=180.0)
  310. if not ok:
  311. print(f" ERROR : {err}")
  312. print(" [FAIL]")
  313. return
  314. if not isinstance(result, dict):
  315. print(f" ERROR : 非 dict 结果: {type(result)}")
  316. print(" [FAIL]")
  317. return
  318. if _nano_has_image(result):
  319. n = len(result["images"]) if isinstance(result.get("images"), list) else 0
  320. print(f" Result : 含图片字段 (images 条数≈{n})")
  321. if result.get("model"):
  322. print(f" model: {result['model']}")
  323. saved = _nano_save_images_default(result)
  324. if saved:
  325. for sp in saved:
  326. print(f" saved : {sp}")
  327. else:
  328. print(
  329. " WARN : 未能从响应解析出图片字节(字段存在但无法 base64 解码)"
  330. )
  331. print(" [PASS]")
  332. return
  333. print(f" ERROR : 未识别到图片字段,keys={list(result.keys())}")
  334. print(f" 截断: {str(result)[:400]}...")
  335. print(" [FAIL]")
  336. except httpx.TimeoutException:
  337. print(" ERROR : Request timeout")
  338. print(" [FAIL]")
  339. except Exception as e:
  340. print(f" ERROR : {e}")
  341. print(" [FAIL]")
  342. def load_task_spec(task_name: str) -> dict:
  343. task_file = TASKS_DIR / f"{task_name}.json"
  344. if not task_file.exists():
  345. print(f" ERROR: Task file not found: {task_file}")
  346. print(" Available tasks:")
  347. if TASKS_DIR.exists():
  348. for f in TASKS_DIR.glob("*.json"):
  349. print(f" - {f.stem}")
  350. sys.exit(1)
  351. with open(task_file, "r", encoding="utf-8") as f:
  352. return json.load(f)
  353. def test_create_tool(task_name: str = None):
  354. print(f"=== Create Tool{f' (task={task_name!r})' if task_name else ''} ===")
  355. if task_name:
  356. task_data = load_task_spec(task_name)
  357. print(f" File : tests/tasks/{task_name}.json")
  358. print(f" Description: {task_data['description'][:80]}")
  359. else:
  360. task_data = {"description": "创建一个简单的文本计数工具,输入文本,返回字数和字符数"}
  361. print(f" Description: {task_data['description']}")
  362. resp = httpx.post(f"{BASE_URL}/create_tool", json=task_data)
  363. data = resp.json()
  364. task_id = data["task_id"]
  365. print(f" Task ID : {task_id}")
  366. print(f" Status : {data['status']}")
  367. assert data["status"] == "pending"
  368. print(" [SUBMITTED]")
  369. print(f"\n Polling task {task_id} (timeout 10min)...")
  370. for i in range(120):
  371. time.sleep(5)
  372. resp = httpx.get(f"{BASE_URL}/tasks/{task_id}", timeout=30)
  373. task = resp.json()
  374. status = task["status"]
  375. if i % 6 == 0:
  376. print(f" [{i*5}s] status={status}")
  377. if status == "completed":
  378. print(f"\n Completed!")
  379. print(f" Result : {str(task.get('result', ''))[:300]}")
  380. resp2 = httpx.post(f"{BASE_URL}/search_tools", json={})
  381. tools = resp2.json()["tools"]
  382. print(f" Registered : {[t['tool_id'] for t in tools]}")
  383. print(" [PASS]")
  384. return
  385. if status == "failed":
  386. print(f"\n Failed!")
  387. print(f" Error : {task.get('error', 'unknown')}")
  388. print(" [FAIL]")
  389. return
  390. print(f"\n Timeout after 600s")
  391. print(" [TIMEOUT]")
  392. def main():
  393. parser = argparse.ArgumentParser(description="Router API Test")
  394. parser.add_argument("--search", nargs="?", const="", metavar="KEYWORD",
  395. help="search tools, optional keyword")
  396. parser.add_argument("--status", action="store_true",
  397. help="show tools status")
  398. parser.add_argument("--select", metavar="TOOL_ID",
  399. help="call a tool by tool_id")
  400. parser.add_argument("--stitch", action="store_true",
  401. help="test image stitcher with sample images")
  402. parser.add_argument("--nano_banana", action="store_true",
  403. help="test nano_banana (Gemini); need GEMINI_API_KEY in tools/local/nano_banana/.env")
  404. parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME",
  405. help="create tool, optional task file name")
  406. parser.add_argument("--launch-env", action="store_true",
  407. help="create RunComfy launch env tool (runcomfy_launch_env)")
  408. parser.add_argument("--run-only", action="store_true",
  409. help="create RunComfy run only tool (runcomfy_run_only)")
  410. parser.add_argument("--stop-env", action="store_true",
  411. help="create RunComfy stop env tool (runcomfy_stop_env)")
  412. args = parser.parse_args()
  413. print(f"Target: {BASE_URL}\n")
  414. check_connection()
  415. # 始终跑 health check
  416. test_health()
  417. ran_any = False
  418. if args.search is not None:
  419. print()
  420. test_search_tools(args.search or None)
  421. ran_any = True
  422. if args.status:
  423. print()
  424. test_tools_status()
  425. ran_any = True
  426. if args.select:
  427. print()
  428. test_select_tool(args.select)
  429. ran_any = True
  430. if args.stitch:
  431. print()
  432. test_stitch_images()
  433. ran_any = True
  434. if args.nano_banana:
  435. print()
  436. test_nano_banana()
  437. ran_any = True
  438. if args.create is not None:
  439. print()
  440. test_create_tool(args.create or None)
  441. ran_any = True
  442. if args.launch_env:
  443. print()
  444. test_create_tool("runcomfy_launch_env")
  445. ran_any = True
  446. if args.run_only:
  447. print()
  448. test_create_tool("runcomfy_run_only")
  449. ran_any = True
  450. if args.stop_env:
  451. print()
  452. test_create_tool("runcomfy_stop_env")
  453. ran_any = True
  454. if not ran_any:
  455. print()
  456. print("No test specified. Available options:")
  457. parser.print_help()
  458. print("\n=== DONE ===")
  459. if __name__ == "__main__":
  460. main()