# test_router_api.py
"""Smoke tests for the core Router HTTP API.

Usage:
    uv run python tests/test_router_api.py                          # health check only
    uv run python tests/test_router_api.py --search                 # list tools
    uv run python tests/test_router_api.py --search image           # search tools by keyword
    uv run python tests/test_router_api.py --status                 # show tool runtime status
    uv run python tests/test_router_api.py --select image_stitcher  # call a tool via POST /run_tool
    uv run python tests/test_router_api.py --stitch                 # test image stitching
    uv run python tests/test_router_api.py --create                 # create a tool from the default task
    uv run python tests/test_router_api.py --create image_stitcher  # create a tool from a named task file
    uv run python tests/test_router_api.py --launch-env             # create the RunComfy launch-environment tool
    uv run python tests/test_router_api.py --run-only               # create the RunComfy run-only tool
    uv run python tests/test_router_api.py --stop-env               # create the RunComfy environment-teardown tool

The health check (GET /health) always runs before any of the optional checks.
"""
  15. import argparse
  16. import base64
  17. import json
  18. import sys
  19. import time
  20. from pathlib import Path
  21. from typing import Any
  22. import httpx
  23. BASE_URL = "http://127.0.0.1:8001"
  24. TASKS_DIR = Path(__file__).parent / "tasks"
  25. TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"
  26. OUTPUT_DIR = Path(__file__).parent / "output"
  27. def check_connection():
  28. try:
  29. httpx.get(f"{BASE_URL}/health", timeout=3)
  30. except httpx.ConnectError:
  31. print(f"ERROR: Cannot connect to {BASE_URL}")
  32. print("Please start the service first:")
  33. print(" uv run python -m tool_agent")
  34. sys.exit(1)
  35. def test_health():
  36. print("=== Health Check ===")
  37. resp = httpx.get(f"{BASE_URL}/health")
  38. print(f" Status : {resp.status_code}")
  39. print(f" Body : {json.dumps(resp.json(), ensure_ascii=False, indent=4)}")
  40. assert resp.status_code == 200
  41. print(" [PASS]")
  42. def test_search_tools(keyword: str = None):
  43. print(f"=== Search Tools{f' (keyword={keyword!r})' if keyword else ''} ===")
  44. payload = {"keyword": keyword} if keyword else {}
  45. resp = httpx.post(f"{BASE_URL}/search_tools", json=payload)
  46. print(f" Status : {resp.status_code}")
  47. if resp.status_code != 200:
  48. print(f" Body : {resp.text}")
  49. print(" [FAIL]")
  50. return
  51. data = resp.json()
  52. print(f" Total : {data['total']}")
  53. for t in data["tools"]:
  54. print(f"\n [{t['tool_id']}]")
  55. print(f" name : {t['name']}")
  56. print(f" category : {t.get('category', '')}")
  57. print(f" state : {t['state']}")
  58. print(f" runtime : {t.get('runtime_type', '')} host_dir={t.get('host_dir', '')}")
  59. print(f" endpoint : {t.get('http_method', '')} {t.get('endpoint_path', '')} port={t.get('port')}")
  60. print(f" stream_support: {t.get('stream_support', False)}")
  61. print(f" description : {t.get('description', '')}")
  62. print(f" params ({len(t.get('params', []))}):")
  63. for p in t.get("params", []):
  64. req_mark = "*" if p["required"] else " "
  65. default_str = f" default={p['default']}" if p.get("default") is not None else ""
  66. enum_str = f" enum={p['enum']}" if p.get("enum") else ""
  67. print(f" {req_mark} {p['name']:<25} {p['type']:<12} {p.get('description', '')}{default_str}{enum_str}")
  68. if t.get("output_schema"):
  69. out_props = t["output_schema"].get("properties", {})
  70. print(f" output ({len(out_props)}):")
  71. for oname, odef in out_props.items():
  72. print(f" {oname:<25} {odef.get('type', ''):<12} {odef.get('description', '')}")
  73. print("\n [PASS]")
  74. def test_tools_status():
  75. print("=== Tools Status ===")
  76. resp = httpx.get(f"{BASE_URL}/tools/status")
  77. print(f" Status : {resp.status_code}")
  78. data = resp.json()
  79. print(f" Total : {len(data['tools'])}")
  80. for t in data["tools"]:
  81. print(f" - {t['tool_id']}")
  82. print(f" state : {t['state']}")
  83. print(f" port : {t.get('port')}")
  84. print(f" pid : {t.get('pid')}")
  85. print(f" sources: {[s['type'] for s in t.get('sources', [])]}")
  86. if t.get("last_error"):
  87. print(f" error : {t['last_error']}")
  88. print(" [PASS]")
  89. def _run_tool(
  90. tool_id: str, params: dict[str, Any], timeout: float = 120.0
  91. ) -> tuple[bool, str | None, Any]:
  92. """POST /run_tool。成功返回 (True, None, result);失败 (False, message, None)。"""
  93. resp = httpx.post(
  94. f"{BASE_URL}/run_tool",
  95. json={"tool_id": tool_id, "params": params},
  96. timeout=timeout,
  97. )
  98. print(f" Status : {resp.status_code}")
  99. if resp.status_code != 200:
  100. return False, f"HTTP {resp.status_code}: {resp.text[:300]}", None
  101. try:
  102. data = resp.json()
  103. except Exception as e:
  104. return False, f"Invalid JSON: {e}", None
  105. if data.get("status") != "success":
  106. return False, data.get("error") or str(data), None
  107. result = data.get("result")
  108. if isinstance(result, dict) and result.get("status") == "error":
  109. return False, str(result.get("error", result)), None
  110. return True, None, result
  111. def test_select_tool(tool_id: str):
  112. print(f"=== Run Tool (tool_id={tool_id!r}) ===")
  113. ok, err, result = _run_tool(tool_id, {}, timeout=30)
  114. print(f" Result :")
  115. if not ok:
  116. print(f" error : {err}")
  117. print(" [FAIL]")
  118. return
  119. result_str = json.dumps(result, ensure_ascii=False, indent=6)
  120. print(f" body: {result_str[:500]}")
  121. print(" [PASS]")
  122. def test_stitch_images():
  123. print("=== Test Image Stitcher ===")
  124. if not TEST_IMAGES_DIR.exists():
  125. print(f" ERROR: Test images directory not found: {TEST_IMAGES_DIR}")
  126. print(" [SKIP]")
  127. return
  128. image_files = sorted(TEST_IMAGES_DIR.glob("*.png"))
  129. if len(image_files) < 2:
  130. print(f" ERROR: Need at least 2 images, found {len(image_files)}")
  131. print(" [SKIP]")
  132. return
  133. print(f" Images : {len(image_files)} found")
  134. images_b64 = []
  135. for img_path in image_files[:6]:
  136. with open(img_path, "rb") as f:
  137. images_b64.append(base64.b64encode(f.read()).decode())
  138. print(f" - {img_path.name}")
  139. print(f" Calling image_stitcher (grid, 2 columns)...")
  140. try:
  141. ok, err, result = _run_tool(
  142. "image_stitcher",
  143. {
  144. "images": images_b64,
  145. "direction": "grid",
  146. "columns": 2,
  147. "spacing": 10,
  148. "background_color": "#FFFFFF",
  149. },
  150. timeout=120.0,
  151. )
  152. if not ok:
  153. print(f" ERROR : {err}")
  154. print(" [FAIL]")
  155. return
  156. if not isinstance(result, dict) or "image" not in result:
  157. print(f" ERROR : 缺少 image 字段: {result!r}")
  158. print(" [FAIL]")
  159. return
  160. print(f" Result :")
  161. print(f" width : {result.get('width')}")
  162. print(f" height: {result.get('height')}")
  163. OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
  164. output_path = OUTPUT_DIR / "stitched_result.png"
  165. with open(output_path, "wb") as f:
  166. f.write(base64.b64decode(result["image"]))
  167. print(f" saved : {output_path}")
  168. print(" [PASS]")
  169. except httpx.TimeoutException:
  170. print(" ERROR : Request timeout")
  171. print(" [FAIL]")
  172. except Exception as e:
  173. print(f" ERROR : {e}")
  174. print(" [FAIL]")
  175. def load_task_spec(task_name: str) -> dict:
  176. task_file = TASKS_DIR / f"{task_name}.json"
  177. if not task_file.exists():
  178. print(f" ERROR: Task file not found: {task_file}")
  179. print(" Available tasks:")
  180. if TASKS_DIR.exists():
  181. for f in TASKS_DIR.glob("*.json"):
  182. print(f" - {f.stem}")
  183. sys.exit(1)
  184. with open(task_file, "r", encoding="utf-8") as f:
  185. return json.load(f)
  186. def test_create_tool(task_name: str = None):
  187. print(f"=== Create Tool{f' (task={task_name!r})' if task_name else ''} ===")
  188. if task_name:
  189. task_data = load_task_spec(task_name)
  190. print(f" File : tests/tasks/{task_name}.json")
  191. print(f" Description: {task_data['description'][:80]}")
  192. else:
  193. task_data = {"description": "创建一个简单的文本计数工具,输入文本,返回字数和字符数"}
  194. print(f" Description: {task_data['description']}")
  195. resp = httpx.post(f"{BASE_URL}/create_tool", json=task_data)
  196. data = resp.json()
  197. task_id = data["task_id"]
  198. print(f" Task ID : {task_id}")
  199. print(f" Status : {data['status']}")
  200. assert data["status"] == "pending"
  201. print(" [SUBMITTED]")
  202. print(f"\n Polling task {task_id} (timeout 10min)...")
  203. for i in range(120):
  204. time.sleep(5)
  205. resp = httpx.get(f"{BASE_URL}/tasks/{task_id}", timeout=30)
  206. task = resp.json()
  207. status = task["status"]
  208. if i % 6 == 0:
  209. print(f" [{i*5}s] status={status}")
  210. if status == "completed":
  211. print(f"\n Completed!")
  212. print(f" Result : {str(task.get('result', ''))[:300]}")
  213. resp2 = httpx.post(f"{BASE_URL}/search_tools", json={})
  214. tools = resp2.json()["tools"]
  215. print(f" Registered : {[t['tool_id'] for t in tools]}")
  216. print(" [PASS]")
  217. return
  218. if status == "failed":
  219. print(f"\n Failed!")
  220. print(f" Error : {task.get('error', 'unknown')}")
  221. print(" [FAIL]")
  222. return
  223. print(f"\n Timeout after 600s")
  224. print(" [TIMEOUT]")
  225. def main():
  226. parser = argparse.ArgumentParser(description="Router API Test")
  227. parser.add_argument("--search", nargs="?", const="", metavar="KEYWORD",
  228. help="search tools, optional keyword")
  229. parser.add_argument("--status", action="store_true",
  230. help="show tools status")
  231. parser.add_argument("--select", metavar="TOOL_ID",
  232. help="call a tool by tool_id")
  233. parser.add_argument("--stitch", action="store_true",
  234. help="test image stitcher with sample images")
  235. parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME",
  236. help="create tool, optional task file name")
  237. parser.add_argument("--launch-env", action="store_true",
  238. help="create RunComfy launch env tool (runcomfy_launch_env)")
  239. parser.add_argument("--run-only", action="store_true",
  240. help="create RunComfy run only tool (runcomfy_run_only)")
  241. parser.add_argument("--stop-env", action="store_true",
  242. help="create RunComfy stop env tool (runcomfy_stop_env)")
  243. args = parser.parse_args()
  244. print(f"Target: {BASE_URL}\n")
  245. check_connection()
  246. # 始终跑 health check
  247. test_health()
  248. ran_any = False
  249. if args.search is not None:
  250. print()
  251. test_search_tools(args.search or None)
  252. ran_any = True
  253. if args.status:
  254. print()
  255. test_tools_status()
  256. ran_any = True
  257. if args.select:
  258. print()
  259. test_select_tool(args.select)
  260. ran_any = True
  261. if args.stitch:
  262. print()
  263. test_stitch_images()
  264. ran_any = True
  265. if args.create is not None:
  266. print()
  267. test_create_tool(args.create or None)
  268. ran_any = True
  269. if args.launch_env:
  270. print()
  271. test_create_tool("runcomfy_launch_env")
  272. ran_any = True
  273. if args.run_only:
  274. print()
  275. test_create_tool("runcomfy_run_only")
  276. ran_any = True
  277. if args.stop_env:
  278. print()
  279. test_create_tool("runcomfy_stop_env")
  280. ran_any = True
  281. if not ran_any:
  282. print()
  283. print("No test specified. Available options:")
  284. parser.print_help()
  285. print("\n=== DONE ===")
  286. if __name__ == "__main__":
  287. main()