test_router_api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
"""Smoke tests for the Router core API.

Usage:
    uv run python tests/test_router_api.py                           # health check only
    uv run python tests/test_router_api.py --search                  # list all tools
    uv run python tests/test_router_api.py --search image            # search tools by keyword
    uv run python tests/test_router_api.py --status                  # show tool runtime status
    uv run python tests/test_router_api.py --select image_stitcher   # call the given tool
    uv run python tests/test_router_api.py --stitch                  # test image stitching
    uv run python tests/test_router_api.py --create                  # create a tool from the default task
    uv run python tests/test_router_api.py --create image_stitcher   # create a tool from tests/tasks/<name>.json
    uv run python tests/test_router_api.py --launch-env              # create the RunComfy launch-environment tool
    uv run python tests/test_router_api.py --run-only                # create the RunComfy run-only task tool
    uv run python tests/test_router_api.py --stop-env                # create the RunComfy environment-teardown tool

The health check always runs first. The Router service must already be
running at BASE_URL (start it with ``uv run python -m tool_agent``).
"""
  15. import argparse
  16. import base64
  17. import json
  18. import sys
  19. import time
  20. from pathlib import Path
  21. import httpx
  22. BASE_URL = "http://127.0.0.1:8001"
  23. TASKS_DIR = Path(__file__).parent / "tasks"
  24. TEST_IMAGES_DIR = TASKS_DIR / "stitcher_images"
  25. OUTPUT_DIR = Path(__file__).parent / "output"
  26. def check_connection():
  27. try:
  28. httpx.get(f"{BASE_URL}/health", timeout=3)
  29. except httpx.ConnectError:
  30. print(f"ERROR: Cannot connect to {BASE_URL}")
  31. print("Please start the service first:")
  32. print(" uv run python -m tool_agent")
  33. sys.exit(1)
  34. def test_health():
  35. print("=== Health Check ===")
  36. resp = httpx.get(f"{BASE_URL}/health")
  37. print(f" Status : {resp.status_code}")
  38. print(f" Body : {json.dumps(resp.json(), ensure_ascii=False, indent=4)}")
  39. assert resp.status_code == 200
  40. print(" [PASS]")
  41. def test_search_tools(keyword: str = None):
  42. print(f"=== Search Tools{f' (keyword={keyword!r})' if keyword else ''} ===")
  43. payload = {"keyword": keyword} if keyword else {}
  44. resp = httpx.post(f"{BASE_URL}/search_tools", json=payload)
  45. print(f" Status : {resp.status_code}")
  46. if resp.status_code != 200:
  47. print(f" Body : {resp.text}")
  48. print(" [FAIL]")
  49. return
  50. data = resp.json()
  51. print(f" Total : {data['total']}")
  52. for t in data["tools"]:
  53. print(f"\n [{t['tool_id']}]")
  54. print(f" name : {t['name']}")
  55. print(f" category : {t.get('category', '')}")
  56. print(f" state : {t['state']}")
  57. print(f" runtime : {t.get('runtime_type', '')} host_dir={t.get('host_dir', '')}")
  58. print(f" endpoint : {t.get('http_method', '')} {t.get('endpoint_path', '')} port={t.get('port')}")
  59. print(f" stream_support: {t.get('stream_support', False)}")
  60. print(f" description : {t.get('description', '')}")
  61. print(f" params ({len(t.get('params', []))}):")
  62. for p in t.get("params", []):
  63. req_mark = "*" if p["required"] else " "
  64. default_str = f" default={p['default']}" if p.get("default") is not None else ""
  65. enum_str = f" enum={p['enum']}" if p.get("enum") else ""
  66. print(f" {req_mark} {p['name']:<25} {p['type']:<12} {p.get('description', '')}{default_str}{enum_str}")
  67. if t.get("output_schema"):
  68. out_props = t["output_schema"].get("properties", {})
  69. print(f" output ({len(out_props)}):")
  70. for oname, odef in out_props.items():
  71. print(f" {oname:<25} {odef.get('type', ''):<12} {odef.get('description', '')}")
  72. print("\n [PASS]")
  73. def test_tools_status():
  74. print("=== Tools Status ===")
  75. resp = httpx.get(f"{BASE_URL}/tools/status")
  76. print(f" Status : {resp.status_code}")
  77. data = resp.json()
  78. print(f" Total : {len(data['tools'])}")
  79. for t in data["tools"]:
  80. print(f" - {t['tool_id']}")
  81. print(f" state : {t['state']}")
  82. print(f" port : {t.get('port')}")
  83. print(f" pid : {t.get('pid')}")
  84. print(f" sources: {[s['type'] for s in t.get('sources', [])]}")
  85. if t.get("last_error"):
  86. print(f" error : {t['last_error']}")
  87. print(" [PASS]")
  88. def test_select_tool(tool_id: str):
  89. print(f"=== Select Tool (tool_id={tool_id!r}) ===")
  90. resp = httpx.post(f"{BASE_URL}/select_tool", json={
  91. "tool_id": tool_id,
  92. "params": {}
  93. }, timeout=30)
  94. print(f" Status : {resp.status_code}")
  95. data = resp.json()
  96. print(f" Result :")
  97. print(f" status: {data.get('status')}")
  98. if data.get("error"):
  99. print(f" error : {data['error']}")
  100. else:
  101. result_str = json.dumps(data.get("result"), ensure_ascii=False, indent=6)
  102. print(f" result: {result_str[:500]}")
  103. print(" [PASS]")
  104. def test_stitch_images():
  105. print("=== Test Image Stitcher ===")
  106. if not TEST_IMAGES_DIR.exists():
  107. print(f" ERROR: Test images directory not found: {TEST_IMAGES_DIR}")
  108. print(" [SKIP]")
  109. return
  110. image_files = sorted(TEST_IMAGES_DIR.glob("*.png"))
  111. if len(image_files) < 2:
  112. print(f" ERROR: Need at least 2 images, found {len(image_files)}")
  113. print(" [SKIP]")
  114. return
  115. print(f" Images : {len(image_files)} found")
  116. images_b64 = []
  117. for img_path in image_files[:6]:
  118. with open(img_path, "rb") as f:
  119. images_b64.append(base64.b64encode(f.read()).decode())
  120. print(f" - {img_path.name}")
  121. print(f" Calling image_stitcher (grid, 2 columns)...")
  122. try:
  123. resp = httpx.post(f"{BASE_URL}/select_tool", json={
  124. "tool_id": "image_stitcher",
  125. "params": {
  126. "images": images_b64,
  127. "direction": "grid",
  128. "columns": 2,
  129. "spacing": 10,
  130. "background_color": "#FFFFFF",
  131. }
  132. }, timeout=60)
  133. print(f" Status : {resp.status_code}")
  134. data = resp.json()
  135. if data["status"] == "success":
  136. result = data["result"]
  137. print(f" Result :")
  138. print(f" width : {result['width']}")
  139. print(f" height: {result['height']}")
  140. OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
  141. output_path = OUTPUT_DIR / "stitched_result.png"
  142. with open(output_path, "wb") as f:
  143. f.write(base64.b64decode(result["image"]))
  144. print(f" saved : {output_path}")
  145. print(" [PASS]")
  146. else:
  147. print(f" ERROR : {data.get('error', 'unknown')}")
  148. print(" [FAIL]")
  149. except httpx.TimeoutException:
  150. print(" ERROR : Request timeout")
  151. print(" [FAIL]")
  152. except Exception as e:
  153. print(f" ERROR : {e}")
  154. print(" [FAIL]")
  155. def load_task_spec(task_name: str) -> dict:
  156. task_file = TASKS_DIR / f"{task_name}.json"
  157. if not task_file.exists():
  158. print(f" ERROR: Task file not found: {task_file}")
  159. print(" Available tasks:")
  160. if TASKS_DIR.exists():
  161. for f in TASKS_DIR.glob("*.json"):
  162. print(f" - {f.stem}")
  163. sys.exit(1)
  164. with open(task_file, "r", encoding="utf-8") as f:
  165. return json.load(f)
  166. def test_create_tool(task_name: str = None):
  167. print(f"=== Create Tool{f' (task={task_name!r})' if task_name else ''} ===")
  168. if task_name:
  169. task_data = load_task_spec(task_name)
  170. print(f" File : tests/tasks/{task_name}.json")
  171. print(f" Description: {task_data['description'][:80]}")
  172. else:
  173. task_data = {"description": "创建一个简单的文本计数工具,输入文本,返回字数和字符数"}
  174. print(f" Description: {task_data['description']}")
  175. resp = httpx.post(f"{BASE_URL}/create_tool", json=task_data)
  176. data = resp.json()
  177. task_id = data["task_id"]
  178. print(f" Task ID : {task_id}")
  179. print(f" Status : {data['status']}")
  180. assert data["status"] == "pending"
  181. print(" [SUBMITTED]")
  182. print(f"\n Polling task {task_id} (timeout 10min)...")
  183. for i in range(120):
  184. time.sleep(5)
  185. resp = httpx.get(f"{BASE_URL}/tasks/{task_id}", timeout=30)
  186. task = resp.json()
  187. status = task["status"]
  188. if i % 6 == 0:
  189. print(f" [{i*5}s] status={status}")
  190. if status == "completed":
  191. print(f"\n Completed!")
  192. print(f" Result : {str(task.get('result', ''))[:300]}")
  193. resp2 = httpx.post(f"{BASE_URL}/search_tools", json={})
  194. tools = resp2.json()["tools"]
  195. print(f" Registered : {[t['tool_id'] for t in tools]}")
  196. print(" [PASS]")
  197. return
  198. if status == "failed":
  199. print(f"\n Failed!")
  200. print(f" Error : {task.get('error', 'unknown')}")
  201. print(" [FAIL]")
  202. return
  203. print(f"\n Timeout after 600s")
  204. print(" [TIMEOUT]")
  205. def main():
  206. parser = argparse.ArgumentParser(description="Router API Test")
  207. parser.add_argument("--search", nargs="?", const="", metavar="KEYWORD",
  208. help="search tools, optional keyword")
  209. parser.add_argument("--status", action="store_true",
  210. help="show tools status")
  211. parser.add_argument("--select", metavar="TOOL_ID",
  212. help="call a tool by tool_id")
  213. parser.add_argument("--stitch", action="store_true",
  214. help="test image stitcher with sample images")
  215. parser.add_argument("--create", nargs="?", const="", metavar="TASK_NAME",
  216. help="create tool, optional task file name")
  217. parser.add_argument("--launch-env", action="store_true",
  218. help="create RunComfy launch env tool (runcomfy_launch_env)")
  219. parser.add_argument("--run-only", action="store_true",
  220. help="create RunComfy run only tool (runcomfy_run_only)")
  221. parser.add_argument("--stop-env", action="store_true",
  222. help="create RunComfy stop env tool (runcomfy_stop_env)")
  223. args = parser.parse_args()
  224. print(f"Target: {BASE_URL}\n")
  225. check_connection()
  226. # 始终跑 health check
  227. test_health()
  228. ran_any = False
  229. if args.search is not None:
  230. print()
  231. test_search_tools(args.search or None)
  232. ran_any = True
  233. if args.status:
  234. print()
  235. test_tools_status()
  236. ran_any = True
  237. if args.select:
  238. print()
  239. test_select_tool(args.select)
  240. ran_any = True
  241. if args.stitch:
  242. print()
  243. test_stitch_images()
  244. ran_any = True
  245. if args.create is not None:
  246. print()
  247. test_create_tool(args.create or None)
  248. ran_any = True
  249. if args.launch_env:
  250. print()
  251. test_create_tool("runcomfy_launch_env")
  252. ran_any = True
  253. if args.run_only:
  254. print()
  255. test_create_tool("runcomfy_run_only")
  256. ran_any = True
  257. if args.stop_env:
  258. print()
  259. test_create_tool("runcomfy_stop_env")
  260. ran_any = True
  261. if not ran_any:
  262. print()
  263. print("No test specified. Available options:")
  264. parser.print_help()
  265. print("\n=== DONE ===")
  266. if __name__ == "__main__":
  267. main()