main.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. """RunComfy Workflow HTTP API"""
import base64
import json
import os
import time
import uuid
from typing import Optional
from urllib.parse import quote

import requests
import websocket
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
  12. load_dotenv()
  13. app = FastAPI(title="RunComfy Workflow API")
  14. BASE_URL = "https://beta-api.runcomfy.net/prod/api"
  15. USER_ID = os.getenv("RUNCOMFY_USER_ID")
  16. API_TOKEN = os.getenv("API_TOKEN")
  17. HEADERS = {
  18. "Authorization": f"Bearer {API_TOKEN}",
  19. "Content-Type": "application/json",
  20. }
  21. SUBDIR_UPLOAD_MAP = {
  22. "images": {"type": "input", "subfolder": ""},
  23. "loras": {"type": "input", "subfolder": "loras"},
  24. "checkpoints": {"type": "input", "subfolder": "checkpoints"},
  25. "vae": {"type": "input", "subfolder": "vae"},
  26. }
  27. class InputFile(BaseModel):
  28. filename: str
  29. type: str
  30. base64_data: str
  31. class WorkflowRequest(BaseModel):
  32. server_id: str
  33. workflow_api: dict
  34. input_files: Optional[list[InputFile]] = None
  35. class WorkflowResponse(BaseModel):
  36. prompt_id: str
  37. images: list[str]
  38. status: str
  39. server_id: str
  40. def get_server_url(server_id: str) -> str:
  41. resp = requests.get(f"{BASE_URL}/users/{USER_ID}/servers/{server_id}", headers=HEADERS)
  42. resp.raise_for_status()
  43. data = resp.json()
  44. if data.get("current_status") != "Ready":
  45. raise Exception(f"机器未就绪: {data.get('current_status')}")
  46. return data["main_service_url"].rstrip("/")
  47. def upload_file_from_base64(comfy_url: str, filename: str, base64_data: str, file_type: str, subfolder: str):
  48. file_bytes = base64.b64decode(base64_data)
  49. files = [("image", (filename, file_bytes, "application/octet-stream"))]
  50. data = {"overwrite": "true", "type": file_type, "subfolder": subfolder}
  51. resp = requests.post(f"{comfy_url}/upload/image", data=data, files=files)
  52. resp.raise_for_status()
  53. return resp.json()["name"]
  54. def submit_prompt(comfy_url: str, workflow_api: dict, client_id: str) -> str:
  55. payload = {"prompt": workflow_api, "client_id": client_id}
  56. resp = requests.post(f"{comfy_url}/prompt", json=payload)
  57. resp.raise_for_status()
  58. return resp.json()["prompt_id"]
  59. def wait_for_completion(comfy_url: str, client_id: str, prompt_id: str, timeout: int = 600):
  60. scheme = "wss" if comfy_url.startswith("https") else "ws"
  61. ws_url = f"{scheme}://{comfy_url.split('://', 1)[-1]}/ws?clientId={client_id}"
  62. ws = websocket.WebSocket()
  63. ws.settimeout(timeout)
  64. ws.connect(ws_url)
  65. try:
  66. while True:
  67. out = ws.recv()
  68. if not out or isinstance(out, bytes):
  69. continue
  70. msg = json.loads(out)
  71. if msg.get("type") == "executing":
  72. data = msg.get("data", {})
  73. if data.get("prompt_id") == prompt_id and data.get("node") is None:
  74. break
  75. elif msg.get("type") == "execution_error":
  76. if msg.get("data", {}).get("prompt_id") == prompt_id:
  77. raise Exception(f"执行错误: {msg['data'].get('exception_message')}")
  78. finally:
  79. ws.close()
  80. def download_images_as_base64(comfy_url: str, prompt_id: str) -> list[str]:
  81. resp = requests.get(f"{comfy_url}/history/{prompt_id}")
  82. resp.raise_for_status()
  83. outputs = resp.json().get(prompt_id, {}).get("outputs", {})
  84. images = []
  85. for node_output in outputs.values():
  86. if "images" in node_output:
  87. for img in node_output["images"]:
  88. params = {"filename": img["filename"], "subfolder": img.get("subfolder", ""),
  89. "type": img.get("type", "output")}
  90. resp = requests.get(f"{comfy_url}/view", params=params)
  91. resp.raise_for_status()
  92. images.append(base64.b64encode(resp.content).decode())
  93. return images
  94. @app.post("/run", response_model=WorkflowResponse)
  95. async def run_workflow(request: WorkflowRequest):
  96. try:
  97. comfy_url = get_server_url(request.server_id)
  98. client_id = str(uuid.uuid4())
  99. if request.input_files:
  100. for file in request.input_files:
  101. mapping = SUBDIR_UPLOAD_MAP.get(file.type, {"type": "input", "subfolder": file.type})
  102. upload_file_from_base64(comfy_url, file.filename, file.base64_data,
  103. mapping["type"], mapping["subfolder"])
  104. prompt_id = submit_prompt(comfy_url, request.workflow_api, client_id)
  105. wait_for_completion(comfy_url, client_id, prompt_id)
  106. images = download_images_as_base64(comfy_url, prompt_id)
  107. return WorkflowResponse(prompt_id=prompt_id, images=images, status="Success", server_id=request.server_id)
  108. except Exception as e:
  109. raise HTTPException(status_code=500, detail=str(e))
  110. if __name__ == "__main__":
  111. import uvicorn
  112. port = int(os.getenv("PORT", "8000"))
  113. uvicorn.run(app, host="0.0.0.0", port=port)