# test_e2e_proxy.py — end-to-end test for the RunComfy workflow proxy router.
  1. import os
  2. import sys
  3. import json
  4. import time
  5. import requests
  6. import asyncio
  7. # 添加项目根目录到 Python 路径
  8. sys.path.insert(0, str(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
  9. from agent.tools.builtin.toolhub import _preprocess_params
  10. ROUTER_URL = "http://127.0.0.1:8001/run_tool"
  11. def call_tool(tool_id, params):
  12. print(f"\n[{tool_id}] Calling...")
  13. # 就像真正的 agent 流程一样,先预处理参数(将本地文件转化为 CDN 链接)
  14. processed_params = asyncio.run(_preprocess_params(params))
  15. resp = requests.post(ROUTER_URL, json={"tool_id": tool_id, "params": processed_params})
  16. resp.raise_for_status()
  17. res = resp.json()
  18. if res.get("status") == "error":
  19. raise Exception(f"Tool error: {res.get('error')}")
  20. print(f"[{tool_id}] Success")
  21. return res.get("result", {})
  22. def test_full_workflow():
  23. print("=== RunComfy Workflow E2E Test ===")
  24. # 0. Check existing machines
  25. print("\n0. Checking for existing ComfyUI Instances...")
  26. status_res = call_tool("runcomfy_check_status", {})
  27. servers = status_res.get("servers", [])
  28. server_id = None
  29. for s in servers:
  30. if s.get("current_status") == "Ready":
  31. server_id = s.get("server_id")
  32. print(f"Found Ready machine: {server_id}")
  33. break
  34. elif s.get("current_status") in ("Starting", "Creating") and not server_id:
  35. server_id = s.get("server_id")
  36. print(f"Found Starting machine: {server_id}. Note: You might need to wait for it to be Ready.")
  37. if not server_id:
  38. raise Exception("\nNo active machine found. Expecting an already running machine. Aborting.")
  39. print("\n1. Proceeding with existing machine...")
  40. try:
  41. print("\n2. Loading custom workflow from JSON...")
  42. import os
  43. workflow_path = os.path.join(os.path.dirname(__file__), "flux_depth_controlnet_workflow.json")
  44. with open(workflow_path, "r", encoding="utf-8") as f:
  45. workflow_api = json.load(f)
  46. print("\n3. Injecting correct prompt to fix ControlNet mismatch...")
  47. for node_id, node in workflow_api.items():
  48. if node.get("class_type") == "CLIPTextEncode":
  49. current_text = node["inputs"].get("text", "")
  50. if "landscape" in current_text:
  51. new_prompt = (
  52. "A back-view of a person standing in front of a wooden painting easel, "
  53. "looking out at a beautiful landscape with majestic mountains and a serene cyan lake. "
  54. "Highly detailed, masterpiece, realistic photography, cinematic lighting, 8k resolution"
  55. )
  56. node["inputs"]["text"] = new_prompt
  57. print(f" -> Successfully updated prompt in node {node_id}")
  58. print("\n4. Executing workflow with ControlNet inputs...")
  59. # 收集输入图片文件列表传递给 executor
  60. # 我们在这里填入包含本地路径的 url
  61. # call_tool 会使用 ToolHub 的 _preprocess_params 自动把这些 url 的本地路径替换为真正的 CDN 链接
  62. input_dir = os.path.join(os.path.dirname(__file__), "input")
  63. input_files = [
  64. {
  65. "filename": "depth_map.png",
  66. "type": "images",
  67. "url": os.path.join(input_dir, "depth_map.png")
  68. },
  69. {
  70. "filename": "background_bokeh_img2.png",
  71. "type": "images",
  72. "url": os.path.join(input_dir, "background_bokeh_img2.png")
  73. },
  74. {
  75. "filename": "character_ref_back.png",
  76. "type": "images",
  77. "url": os.path.join(input_dir, "character_ref_back.png")
  78. },
  79. {
  80. "filename": "easel_blank_canvas_img4.png",
  81. "type": "images",
  82. "url": os.path.join(input_dir, "easel_blank_canvas_img4.png")
  83. }
  84. ]
  85. exec_res = call_tool("runcomfy_workflow_executor", {
  86. "server_id": server_id,
  87. "workflow_api": workflow_api,
  88. "input_files": input_files
  89. })
  90. print("Execution finished! Prompt ID:", exec_res.get("prompt_id"))
  91. images = exec_res.get("images", [])
  92. print("Generated images count:", len(images))
  93. if images:
  94. # 拿到的是 base64 或者是 CDN url
  95. img_data = images[0]
  96. if isinstance(img_data, dict):
  97. print("First image info:", img_data.get("url") or (img_data.get("data", "")[:50] + "..."))
  98. else:
  99. print("First image info (raw):", img_data[:50] + "...")
  100. output_dir = os.path.join(os.path.dirname(__file__), "output")
  101. os.makedirs(output_dir, exist_ok=True)
  102. output_path = os.path.join(output_dir, "test_output.png")
  103. try:
  104. if img_data.startswith("http"):
  105. import requests
  106. img_resp = requests.get(img_data)
  107. img_resp.raise_for_status()
  108. with open(output_path, "wb") as fh:
  109. fh.write(img_resp.content)
  110. else:
  111. import base64
  112. with open(output_path, "wb") as fh:
  113. fh.write(base64.b64decode(img_data))
  114. print(f"🎉 成功!生成的图片已保存至: {output_path}")
  115. except Exception as e:
  116. print(f"Failed to save image: {e}")
  117. finally:
  118. # 4. Cleanup
  119. print("\n4. Keep machine alive (Skip cleanup so you can check runcomfy)")
  120. # 暂时跳过 stop 方便排查
  121. # call_tool("runcomfy_stop_env", {"server_id": server_id})
  122. if __name__ == "__main__":
  123. test_full_workflow()
  124. os._exit(0)