# main.py
"""RunComfy Workflow HTTP API"""
import base64
import json
import os
import uuid
from typing import Optional

import requests
import websocket  # NOTE(review): appears unused in this file — confirm before removing
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Pull RUNCOMFY_USER_ID / API_TOKEN from a local .env file, if present.
load_dotenv()

app = FastAPI(title="RunComfy Workflow API")

# RunComfy management API (server status lookup), not the ComfyUI instance itself.
BASE_URL = "https://beta-api.runcomfy.net/prod/api"
USER_ID = os.getenv("RUNCOMFY_USER_ID")
API_TOKEN = os.getenv("API_TOKEN")
HEADERS = {
    "Authorization": f"Bearer {API_TOKEN}",
    "Content-Type": "application/json",
}

# Maps an InputFile.type to the ComfyUI /upload/image destination
# (upload "type" plus subfolder). Unknown types fall back to a subfolder
# named after the type itself (see run_workflow).
SUBDIR_UPLOAD_MAP = {
    "images": {"type": "input", "subfolder": ""},
    "loras": {"type": "input", "subfolder": "loras"},
    "checkpoints": {"type": "input", "subfolder": "checkpoints"},
    "vae": {"type": "input", "subfolder": "vae"},
}
class InputFile(BaseModel):
    """A file to stage onto the ComfyUI server before running the workflow.

    Exactly one of `base64_data` or `url` should be provided; `url` wins
    when both are set (see run_workflow).
    """

    # Name the file will be saved under on the server.
    filename: str
    # Upload category; looked up in SUBDIR_UPLOAD_MAP to pick the subfolder.
    type: str
    # Raw file content, base64-encoded.
    base64_data: Optional[str] = None
    # Remote location to download the file from instead of inlining it.
    url: Optional[str] = None
class WorkflowRequest(BaseModel):
    """Request body for POST /run."""

    # RunComfy server (machine) ID; resolved to a ComfyUI URL via get_server_url.
    server_id: str
    # ComfyUI "API format" workflow graph: {node_id: {class_type, inputs, ...}}.
    workflow_api: dict
    # Optional files to upload to the server before submitting the workflow.
    input_files: Optional[list[InputFile]] = None
class WorkflowResponse(BaseModel):
    """Response body for POST /run."""

    # ComfyUI prompt ID assigned to the submitted workflow.
    prompt_id: str
    # Direct /view URLs for every image the workflow produced.
    images: list[str]
    # Human-readable status; "Success" on the happy path.
    status: str
    # Echo of the server the workflow ran on.
    server_id: str
  41. def get_server_url(server_id: str) -> str:
  42. resp = requests.get(f"{BASE_URL}/users/{USER_ID}/servers/{server_id}", headers=HEADERS)
  43. resp.raise_for_status()
  44. data = resp.json()
  45. if data.get("current_status") != "Ready":
  46. raise Exception(f"机器未就绪: {data.get('current_status')}")
  47. return data["main_service_url"].rstrip("/")
  48. def upload_file_bytes(comfy_url: str, filename: str, file_bytes: bytes, file_type: str, subfolder: str):
  49. files = [("image", (filename, file_bytes, "application/octet-stream"))]
  50. data = {"overwrite": "true", "type": file_type, "subfolder": subfolder}
  51. resp = requests.post(f"{comfy_url}/upload/image", data=data, files=files)
  52. resp.raise_for_status()
  53. return resp.json()["name"]
  54. def verify_workflow_api(workflow_api: dict):
  55. if not isinstance(workflow_api, dict):
  56. raise ValueError("workflow_api 必须是一个字典对象存放各个节点")
  57. for node_id, node_data in workflow_api.items():
  58. if not isinstance(node_data, dict):
  59. raise ValueError(f"节点 {node_id} 数据异常:必须是一个字典")
  60. if "class_type" not in node_data:
  61. raise ValueError(f"节点 {node_id} 数据异常:缺少必填字段 'class_type'")
  62. inputs = node_data.get("inputs", {})
  63. if not isinstance(inputs, dict):
  64. raise ValueError(f"节点 {node_id} (class_type: {node_data['class_type']}) 数据异常:inputs 必须是一个字典")
  65. for input_key, input_value in inputs.items():
  66. # 检测连线格式:必须是 [node_id, port_index] 且 node_id 存在于 workflow_api 中
  67. if isinstance(input_value, list) and len(input_value) >= 2:
  68. target_node = str(input_value[0])
  69. if target_node not in workflow_api:
  70. raise ValueError(f"节点 {node_id} ({node_data['class_type']}) 的输入 '{input_key}' 引用了不存在的节点 ID: {target_node}。图存在断线!")
  71. def submit_prompt(comfy_url: str, workflow_api: dict, client_id: str) -> str:
  72. payload = {"prompt": workflow_api, "client_id": client_id}
  73. resp = requests.post(f"{comfy_url}/prompt", json=payload)
  74. try:
  75. resp.raise_for_status()
  76. except requests.exceptions.HTTPError as e:
  77. error_details = resp.text
  78. raise Exception(f"提交工作流失败: 400 Bad Request. 你的 workflow_api JSON 结构可能存在致命错误。\nComfyUI 返回的详细诊断信息: {error_details}")
  79. return resp.json()["prompt_id"]
def wait_for_completion(comfy_url: str, client_id: str, prompt_id: str, timeout: int = 1200):
    """Block until `prompt_id` appears in ComfyUI's /history, or raise.

    Polls /history and /queue over plain HTTP. `client_id` is accepted for
    API symmetry with submit_prompt but is not used in this polling path.

    Raises:
        Exception: if the task vanishes from the queue without landing in
        history (a node likely errored), if consecutive connection failures
        exceed the limit, or if `timeout` seconds elapse.
    """
    import time
    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry
    start_time = time.time()
    # Session with automatic retries on transient 5xx responses.
    session = requests.Session()
    retry_strategy = Retry(total=3, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    adapter = HTTPAdapter(max_retries=retry_strategy)
    session.mount("https://", adapter)
    session.mount("http://", adapter)
    error_count = 0  # consecutive network-failure counter; reset on a clean poll
    while time.time() - start_time < timeout:
        try:
            # 1. Check history to see if done
            resp = session.get(f"{comfy_url}/history/{prompt_id}", timeout=15)
            if resp.status_code == 200:
                data = resp.json()
                if prompt_id in data:
                    return  # Successfully completed
            # 2. Check queue to make sure it didn't error out completely
            resp_q = session.get(f"{comfy_url}/queue", timeout=15)
            if resp_q.status_code == 200:
                q_data = resp_q.json()
                running = q_data.get("queue_running", [])
                pending = q_data.get("queue_pending", [])
                is_active = False
                for task in running + pending:
                    # task elements are usually: [id, prompt_id, prompt, extra_data, prompt_out]
                    if len(task) > 1 and str(task[1]) == str(prompt_id):
                        is_active = True
                        break
                if not is_active:
                    # It disappeared from queue but isn't in history!
                    # Give it a tiny sleep to avoid race conditions
                    time.sleep(2)
                    resp_cf = session.get(f"{comfy_url}/history/{prompt_id}", timeout=15)
                    if resp_cf.status_code == 200 and prompt_id in resp_cf.json():
                        return
                    # Not a RequestException, so this escapes the except below
                    # and propagates to the caller as intended.
                    raise Exception("执行中止:任务已从队列消失且未成功写入历史(可能某个节点出错。请检查工作流输入参数。)")
            error_count = 0
        except requests.exceptions.RequestException as e:
            error_count += 1
            print(f"Polling HTTP Error (count={error_count}): {e}")
            if error_count > 10:
                raise Exception(f"与 ComfyUI 服务器断开连接次数达到上限: {e}")
        # Sleep before polling again
        time.sleep(5)
    raise Exception(f"任务执行超时 (未在 {timeout} 秒内完成)")
  129. def get_comfy_image_urls(comfy_url: str, prompt_id: str) -> list[str]:
  130. resp = requests.get(f"{comfy_url}/history/{prompt_id}")
  131. resp.raise_for_status()
  132. outputs = resp.json().get(prompt_id, {}).get("outputs", {})
  133. images = []
  134. import urllib.parse
  135. for node_output in outputs.values():
  136. if "images" in node_output:
  137. for img in node_output["images"]:
  138. params = {"filename": img["filename"], "subfolder": img.get("subfolder", ""),
  139. "type": img.get("type", "output")}
  140. query = urllib.parse.urlencode(params)
  141. url = f"{comfy_url}/view?{query}"
  142. images.append(url)
  143. return images
  144. @app.post("/run", response_model=WorkflowResponse)
  145. async def run_workflow(request: WorkflowRequest):
  146. try:
  147. comfy_url = get_server_url(request.server_id)
  148. client_id = str(uuid.uuid4())
  149. if request.input_files:
  150. for file in request.input_files:
  151. mapping = SUBDIR_UPLOAD_MAP.get(file.type, {"type": "input", "subfolder": file.type})
  152. file_bytes = None
  153. if file.url:
  154. # 从 CDN 下载
  155. resp = requests.get(file.url)
  156. resp.raise_for_status()
  157. file_bytes = resp.content
  158. elif file.base64_data:
  159. # Base64 解码
  160. file_bytes = base64.b64decode(file.base64_data)
  161. if file_bytes:
  162. upload_file_bytes(comfy_url, file.filename, file_bytes,
  163. mapping["type"], mapping["subfolder"])
  164. else:
  165. raise Exception(f"Input file {file.filename} must have either 'url' or 'base64_data'")
  166. verify_workflow_api(request.workflow_api)
  167. prompt_id = submit_prompt(comfy_url, request.workflow_api, client_id)
  168. wait_for_completion(comfy_url, client_id, prompt_id)
  169. images = get_comfy_image_urls(comfy_url, prompt_id)
  170. return WorkflowResponse(prompt_id=prompt_id, images=images, status="Success", server_id=request.server_id)
  171. except Exception as e:
  172. raise HTTPException(status_code=500, detail=str(e))
  173. if __name__ == "__main__":
  174. import uvicorn
  175. import argparse
  176. parser = argparse.ArgumentParser()
  177. parser.add_argument("--port", type=int, default=8000)
  178. args = parser.parse_args()
  179. uvicorn.run(app, host="0.0.0.0", port=args.port)