run_workflow.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. #!/usr/bin/env python3
  2. """RunComfy Server API + ComfyUI Backend API
  3. 流程:启动机器 → 上传 input/ 目录文件 → 提交 workflow → WebSocket 监听 → 下载结果 → 关机
  4. input/ 目录结构:
  5. input/
  6. ├── images/ → 上传到 ComfyUI input/(LoadImage 节点用)
  7. ├── loras/ → 上传到 ComfyUI input/loras/(type=input,不是 models/loras/)
  8. ├── checkpoints/ → 上传到 ComfyUI input/checkpoints/
  9. ├── vae/ → 上传到 ComfyUI input/vae/
  10. └── (其他文件) → 上传到 ComfyUI input/
  11. 用法:
  12. python run_workflow.py workflow_api.json
  13. python run_workflow.py workflow_api.json --input-dir ./input
  14. python run_workflow.py workflow_api.json --input-dir ./input --server-type large
  15. """
  16. import argparse
  17. import json
  18. import os
  19. import sys
  20. import time
  21. import urllib.parse
  22. import uuid
  23. from pathlib import Path
  24. import requests
  25. import websocket
  26. from dotenv import load_dotenv
  27. from check_workflow import analyze, check_files_exist
  28. load_dotenv(Path(__file__).parent.parent.parent / ".env")
  29. BASE_URL = "https://beta-api.runcomfy.net/prod/api"
  30. USER_ID = os.getenv("RUNCOMFY_USER_ID")
  31. API_TOKEN = os.getenv("API_TOKEN")
  32. HEADERS = {
  33. "Authorization": f"Bearer {API_TOKEN}",
  34. "Content-Type": "application/json",
  35. }
  36. DEFAULT_VERSION_ID = "90f77137-ba75-400d-870f-204c614ae8a3" # RunComfy/ComfyUI-NodesLoaded
  37. # input/ 子目录 → ComfyUI 上传类型和 subfolder 映射
  38. # type="input" 对应 ComfyUI 的 input 目录
  39. # type="model" 暂无官方支持,lora 等模型走 subfolder 区分
  40. SUBDIR_UPLOAD_MAP = {
  41. "images": {"type": "input", "subfolder": ""},
  42. "loras": {"type": "input", "subfolder": "loras"},
  43. "checkpoints": {"type": "input", "subfolder": "checkpoints"},
  44. "vae": {"type": "input", "subfolder": "vae"},
  45. "controlnet": {"type": "input", "subfolder": "controlnet"},
  46. "upscale": {"type": "input", "subfolder": "upscale_models"},
  47. }
  48. # ── 机器管理 ──────────────────────────────────────────────
  49. def launch_machine(version_id: str, server_type: str = "medium", duration: int = 3600) -> str:
  50. payload = {
  51. "workflow_version_id": version_id,
  52. "server_type": server_type,
  53. "estimated_duration": duration,
  54. }
  55. resp = requests.post(f"{BASE_URL}/users/{USER_ID}/servers", headers=HEADERS, json=payload)
  56. if not resp.ok:
  57. print(f" HTTP {resp.status_code}: {resp.text}")
  58. resp.raise_for_status()
  59. print(f" 响应: {resp.json()}")
  60. server_id = resp.json()["server_id"]
  61. print(f"机器已创建: {server_id} (type={server_type})")
  62. return server_id
  63. def wait_for_ready(server_id: str, timeout: int = 300) -> str:
  64. print("等待机器就绪...")
  65. start = time.time()
  66. while time.time() - start < timeout:
  67. resp = requests.get(f"{BASE_URL}/users/{USER_ID}/servers/{server_id}", headers=HEADERS)
  68. resp.raise_for_status()
  69. data = resp.json()
  70. status = data.get("current_status", "")
  71. print(f" 状态: {status}")
  72. if status == "Ready":
  73. url = data["main_service_url"].rstrip("/")
  74. print(f" 就绪: {url}")
  75. return url
  76. if status in ("Error", "Failed"):
  77. raise Exception(f"机器启动失败: {status}")
  78. time.sleep(5)
  79. raise TimeoutError(f"等待超时 ({timeout}s)")
  80. def stop_machine(server_id: str):
  81. resp = requests.delete(f"{BASE_URL}/users/{USER_ID}/servers/{server_id}", headers=HEADERS)
  82. resp.raise_for_status()
  83. print(f"机器已关闭: {server_id}")
  84. # ── 文件上传 ──────────────────────────────────────────────
  85. def upload_file(comfy_url: str, file_path: Path, file_type: str = "input", subfolder: str = "") -> str:
  86. """上传文件到 ComfyUI,返回服务器上的实际文件名"""
  87. with open(file_path, "rb") as f:
  88. files = [("image", (file_path.name, f, "application/octet-stream"))]
  89. data = {"overwrite": "true", "type": file_type, "subfolder": subfolder}
  90. resp = requests.post(f"{comfy_url}/upload/image", data=data, files=files)
  91. resp.raise_for_status()
  92. server_name = resp.json()["name"]
  93. subfolder_str = f" → {subfolder}/{server_name}" if subfolder else f" → {server_name}"
  94. print(f" 上传: {file_path.name}{subfolder_str}")
  95. return server_name
  96. def upload_input_dir(comfy_url: str, input_dir: Path) -> dict[str, str]:
  97. """
  98. 扫描 input_dir,按子目录上传文件,返回 {原文件名: 服务器文件名} 映射
  99. - input/images/ → type=input, subfolder=""
  100. - input/loras/ → type=input, subfolder="loras"
  101. - input/*.png → type=input, subfolder=""(根目录文件)
  102. """
  103. if not input_dir.exists():
  104. print(f" input 目录不存在: {input_dir}")
  105. return {}
  106. uploaded = {}
  107. IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
  108. VIDEO_EXTS = {".mp4", ".avi", ".mov", ".webm"}
  109. MODEL_EXTS = {".safetensors", ".ckpt", ".pt", ".pth", ".gguf"}
  110. ALL_EXTS = IMAGE_EXTS | VIDEO_EXTS | MODEL_EXTS
  111. # 根目录文件 → input
  112. for f in input_dir.iterdir():
  113. if f.is_file() and f.suffix.lower() in ALL_EXTS:
  114. server_name = upload_file(comfy_url, f, "input", "")
  115. uploaded[f.name] = server_name
  116. # 子目录文件
  117. for subdir in input_dir.iterdir():
  118. if not subdir.is_dir():
  119. continue
  120. mapping = SUBDIR_UPLOAD_MAP.get(subdir.name, {"type": "input", "subfolder": subdir.name})
  121. for f in subdir.iterdir():
  122. if f.is_file() and f.suffix.lower() in ALL_EXTS:
  123. server_name = upload_file(comfy_url, f, mapping["type"], mapping["subfolder"])
  124. uploaded[f.name] = server_name
  125. return uploaded
  126. # ── 提交 workflow ─────────────────────────────────────────
  127. def submit_prompt(comfy_url: str, workflow_api: dict, client_id: str) -> str:
  128. payload = {"prompt": workflow_api, "client_id": client_id}
  129. resp = requests.post(f"{comfy_url}/prompt", json=payload)
  130. resp.raise_for_status()
  131. data = resp.json()
  132. if data.get("node_errors"):
  133. print(f" 节点错误: {data['node_errors']}")
  134. prompt_id = data["prompt_id"]
  135. print(f"任务已提交: {prompt_id}")
  136. return prompt_id
  137. # ── WebSocket 监听 ────────────────────────────────────────
  138. def wait_for_completion(comfy_url: str, client_id: str, prompt_id: str, timeout: int = 600):
  139. scheme = "wss" if comfy_url.startswith("https") else "ws"
  140. ws_url = f"{scheme}://{comfy_url.split('://', 1)[-1]}/ws?clientId={client_id}"
  141. print("WebSocket 监听中...")
  142. ws = websocket.WebSocket()
  143. ws.settimeout(timeout)
  144. ws.connect(ws_url)
  145. try:
  146. while True:
  147. out = ws.recv()
  148. if not out or isinstance(out, bytes):
  149. continue
  150. msg = json.loads(out)
  151. msg_type = msg.get("type", "")
  152. data = msg.get("data", {})
  153. if msg_type == "executing":
  154. node = data.get("node")
  155. if data.get("prompt_id") == prompt_id and node is None:
  156. print(" 执行完成")
  157. break
  158. if node:
  159. print(f" 执行节点: {node}")
  160. elif msg_type == "progress":
  161. value = data.get("value", 0)
  162. max_val = data.get("max", 1)
  163. print(f" 进度: {value}/{max_val}")
  164. elif msg_type == "execution_error":
  165. if data.get("prompt_id") == prompt_id:
  166. raise Exception(f"执行错误: {data.get('exception_message', 'unknown')}")
  167. finally:
  168. ws.close()
  169. # ── 下载结果 ──────────────────────────────────────────────
  170. def download_outputs(comfy_url: str, prompt_id: str, output_dir: Path) -> list[str]:
  171. resp = requests.get(f"{comfy_url}/history/{prompt_id}")
  172. resp.raise_for_status()
  173. data = resp.json().get(prompt_id, {})
  174. outputs = data.get("outputs", {})
  175. output_dir.mkdir(parents=True, exist_ok=True)
  176. saved = []
  177. for node_id, node_output in outputs.items():
  178. if "images" in node_output:
  179. for image in node_output["images"]:
  180. params = {"filename": image["filename"], "subfolder": image.get("subfolder", ""), "type": image.get("temp") or image.get("type", "output")}
  181. resp = requests.get(f"{comfy_url}/view?{urllib.parse.urlencode(params)}")
  182. resp.raise_for_status()
  183. out_path = output_dir / image["filename"]
  184. out_path.write_bytes(resp.content)
  185. print(f" 图片: {out_path}")
  186. saved.append(str(out_path))
  187. if "gifs" in node_output:
  188. for video in node_output["gifs"]:
  189. params = {"filename": video["filename"], "subfolder": video.get("subfolder", ""), "format": video.get("format", "mp4")}
  190. resp = requests.get(f"{comfy_url}/view?{urllib.parse.urlencode(params)}")
  191. resp.raise_for_status()
  192. out_path = output_dir / video["filename"]
  193. out_path.write_bytes(resp.content)
  194. print(f" 视频: {out_path}")
  195. saved.append(str(out_path))
  196. return saved
  197. # ── 主流程 ────────────────────────────────────────────────
  198. def main():
  199. parser = argparse.ArgumentParser(description="RunComfy workflow runner")
  200. parser.add_argument("workflow", help="workflow_api.json 路径")
  201. parser.add_argument("--input-dir", default="input", metavar="DIR",
  202. help="输入文件目录,默认 input/。子目录 images/loras/checkpoints/vae/ 自动上传到对应位置")
  203. parser.add_argument("--version-id", default=DEFAULT_VERSION_ID, help="RunComfy workflow version_id")
  204. parser.add_argument("--server-type", default="medium",
  205. choices=["medium", "large", "extra-large", "2x-large", "2xl-turbo"])
  206. parser.add_argument("--duration", type=int, default=3600, help="预估运行时长(秒),默认3600")
  207. parser.add_argument("--keep-alive", action="store_true", help="完成后不自动关机")
  208. parser.add_argument("--server-id", metavar="ID", help="复用已有机器,跳过启动步骤")
  209. parser.add_argument("--skip-upload", action="store_true", help="跳过文件上传,直接提交 workflow")
  210. parser.add_argument("--output-dir", default="output", metavar="DIR", help="结果下载目录,默认 output/")
  211. args = parser.parse_args()
  212. if not USER_ID or not API_TOKEN:
  213. print("ERROR: 请设置 RUNCOMFY_USER_ID 和 API_TOKEN 环境变量")
  214. sys.exit(1)
  215. print(f"USER_ID : {USER_ID}")
  216. print(f"API_TOKEN: {API_TOKEN[:8]}...")
  217. workflow_path = Path(args.workflow)
  218. if not workflow_path.exists():
  219. print(f"ERROR: 文件不存在: {workflow_path}")
  220. sys.exit(1)
  221. with open(workflow_path, "r", encoding="utf-8") as f:
  222. workflow_api = json.load(f)
  223. # 提交前 check
  224. input_dir = Path(args.input_dir)
  225. all_input_files = list(input_dir.rglob("*")) if input_dir.exists() else []
  226. result = analyze(workflow_api)
  227. missing = check_files_exist(result["file_inputs"], all_input_files)
  228. if result["issues"] or missing:
  229. print("❌ workflow 检查未通过:")
  230. for issue in result["issues"]:
  231. print(f" {issue}")
  232. for fi in missing:
  233. print(f" 缺少文件: {fi['filename']} (节点 [{fi['node_id']}] {fi['class_type']})")
  234. print(f" 请将该文件放入 {input_dir}/ 目录")
  235. sys.exit(1)
  236. if result["widget_params"]:
  237. print("⚠️ 存在 widget_* 占位参数,参数名可能不准确,继续运行...")
  238. else:
  239. print("✓ workflow 检查通过")
  240. client_id = str(uuid.uuid4())
  241. server_id = args.server_id # None if not provided
  242. try:
  243. # 1. 启动机器(或复用已有机器)
  244. if server_id:
  245. print(f"复用已有机器: {server_id}")
  246. comfy_url = wait_for_ready(server_id)
  247. else:
  248. server_id = launch_machine(args.version_id, args.server_type, args.duration)
  249. comfy_url = wait_for_ready(server_id)
  250. # 2. 上传 input 目录
  251. if args.skip_upload:
  252. print("跳过文件上传 (--skip-upload)")
  253. else:
  254. print(f"\n上传 input 目录: {input_dir}")
  255. upload_input_dir(comfy_url, input_dir)
  256. # 3. 提交 workflow
  257. print(f"\n提交 workflow...")
  258. prompt_id = submit_prompt(comfy_url, workflow_api, client_id)
  259. # 4. 监听执行进度
  260. wait_for_completion(comfy_url, client_id, prompt_id)
  261. # 5. 下载结果
  262. print(f"\n下载结果...")
  263. saved = download_outputs(comfy_url, prompt_id, Path(args.output_dir))
  264. print(f"\n完成,共 {len(saved)} 个文件")
  265. if args.keep_alive:
  266. print(f"\n--keep-alive 模式,机器保持运行: {server_id}")
  267. print(f"ComfyUI URL: {comfy_url}")
  268. else:
  269. print("\n关闭机器...")
  270. stop_machine(server_id)
  271. except Exception as e:
  272. print(f"\n错误: {e}")
  273. print(f"机器 {server_id} 未自动关闭,请手动处理")
  274. sys.exit(1)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()