test_create_runcomfy_atomic.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. """批量触发生成三个 RunComfy 原子化工具
  2. 用法:
  3. uv run python tests/test_create_runcomfy_atomic.py
  4. """
import json
import sys
import time
from pathlib import Path

import httpx
  9. BASE_URL = "http://127.0.0.1:8001"
  10. TASKS_DIR = Path(__file__).parent / "tasks"
  11. # 我们刚才写的三个原子化任务书
  12. TASKS = [
  13. "runcomfy_launch_env",
  14. "runcomfy_run_only",
  15. "runcomfy_stop_env"
  16. ]
  17. def check_connection():
  18. try:
  19. httpx.get(f"{BASE_URL}/health", timeout=3)
  20. except httpx.ConnectError:
  21. print(f"ERROR: Cannot connect to {BASE_URL}")
  22. print("Please start the service first:")
  23. print(" uv run python -m tool_agent")
  24. sys.exit(1)
  25. def submit_task(task_name: str) -> str:
  26. task_file = TASKS_DIR / f"{task_name}.json"
  27. if not task_file.exists():
  28. print(f"ERROR: Task file not found: {task_file}")
  29. sys.exit(1)
  30. import json
  31. with open(task_file, "r", encoding="utf-8") as f:
  32. task_data = json.load(f)
  33. print(f"\n[{task_name}] Submitting...")
  34. resp = httpx.post(f"{BASE_URL}/create_tool", json=task_data, timeout=30)
  35. resp.raise_for_status()
  36. data = resp.json()
  37. task_id = data["task_id"]
  38. print(f"[{task_name}] Task ID: {task_id}")
  39. return task_id
  40. def poll_tasks(task_ids: dict[str, str], timeout: int = 900):
  41. print("\n=== Polling Tasks ===")
  42. pending = set(task_ids.values())
  43. interval = 10
  44. steps = timeout // interval
  45. for i in range(steps):
  46. if not pending:
  47. print("\nAll tasks finished!")
  48. break
  49. time.sleep(interval)
  50. elapsed = (i + 1) * interval
  51. for task_name, task_id in list(task_ids.items()):
  52. if task_id not in pending:
  53. continue
  54. resp = httpx.get(f"{BASE_URL}/tasks/{task_id}", timeout=30)
  55. status = resp.json()["status"]
  56. if status in ("completed", "failed"):
  57. print(f"\n[{task_name}] Finished with status: {status}")
  58. pending.remove(task_id)
  59. elif elapsed % 30 == 0:
  60. print(f"[{elapsed}s] {task_name}: {status}")
  61. if pending:
  62. print(f"\nTimeout! Still pending: {pending}")
  63. def main():
  64. check_connection()
  65. # 1. 批量提交
  66. task_ids = {}
  67. for task_name in TASKS:
  68. task_id = submit_task(task_name)
  69. task_ids[task_name] = task_id
  70. # 2. 并行轮询
  71. poll_tasks(task_ids)
  72. if __name__ == "__main__":
  73. main()