"""RunComfy Workflow HTTP API""" import base64 import json import os import uuid from typing import Optional import requests import websocket from dotenv import load_dotenv from fastapi import FastAPI, HTTPException from pydantic import BaseModel load_dotenv() app = FastAPI(title="RunComfy Workflow API") BASE_URL = "https://beta-api.runcomfy.net/prod/api" USER_ID = os.getenv("RUNCOMFY_USER_ID") API_TOKEN = os.getenv("API_TOKEN") HEADERS = { "Authorization": f"Bearer {API_TOKEN}", "Content-Type": "application/json", } SUBDIR_UPLOAD_MAP = { "images": {"type": "input", "subfolder": ""}, "loras": {"type": "input", "subfolder": "loras"}, "checkpoints": {"type": "input", "subfolder": "checkpoints"}, "vae": {"type": "input", "subfolder": "vae"}, } class InputFile(BaseModel): filename: str type: str base64_data: str class WorkflowRequest(BaseModel): server_id: str workflow_api: dict input_files: Optional[list[InputFile]] = None class WorkflowResponse(BaseModel): prompt_id: str images: list[str] status: str server_id: str def get_server_url(server_id: str) -> str: resp = requests.get(f"{BASE_URL}/users/{USER_ID}/servers/{server_id}", headers=HEADERS) resp.raise_for_status() data = resp.json() if data.get("current_status") != "Ready": raise Exception(f"机器未就绪: {data.get('current_status')}") return data["main_service_url"].rstrip("/") def upload_file_from_base64(comfy_url: str, filename: str, base64_data: str, file_type: str, subfolder: str): file_bytes = base64.b64decode(base64_data) files = [("image", (filename, file_bytes, "application/octet-stream"))] data = {"overwrite": "true", "type": file_type, "subfolder": subfolder} resp = requests.post(f"{comfy_url}/upload/image", data=data, files=files) resp.raise_for_status() return resp.json()["name"] def submit_prompt(comfy_url: str, workflow_api: dict, client_id: str) -> str: payload = {"prompt": workflow_api, "client_id": client_id} resp = requests.post(f"{comfy_url}/prompt", json=payload) resp.raise_for_status() return resp.json()["prompt_id"] def wait_for_completion(comfy_url: str, client_id: str, prompt_id: str, timeout: int = 600): scheme = "wss" if comfy_url.startswith("https") else "ws" ws_url = f"{scheme}://{comfy_url.split('://', 1)[-1]}/ws?clientId={client_id}" ws = websocket.WebSocket() ws.settimeout(timeout) ws.connect(ws_url) try: while True: out = ws.recv() if not out or isinstance(out, bytes): continue msg = json.loads(out) if msg.get("type") == "executing": data = msg.get("data", {}) if data.get("prompt_id") == prompt_id and data.get("node") is None: break elif msg.get("type") == "execution_error": if msg.get("data", {}).get("prompt_id") == prompt_id: raise Exception(f"执行错误: {msg['data'].get('exception_message')}") finally: ws.close() def download_images_as_base64(comfy_url: str, prompt_id: str) -> list[str]: resp = requests.get(f"{comfy_url}/history/{prompt_id}") resp.raise_for_status() outputs = resp.json().get(prompt_id, {}).get("outputs", {}) images = [] for node_output in outputs.values(): if "images" in node_output: for img in node_output["images"]: params = {"filename": img["filename"], "subfolder": img.get("subfolder", ""), "type": img.get("type", "output")} resp = requests.get(f"{comfy_url}/view", params=params) resp.raise_for_status() images.append(base64.b64encode(resp.content).decode()) return images @app.post("/run", response_model=WorkflowResponse) async def run_workflow(request: WorkflowRequest): try: comfy_url = get_server_url(request.server_id) client_id = str(uuid.uuid4()) if request.input_files: for file in request.input_files: mapping = SUBDIR_UPLOAD_MAP.get(file.type, {"type": "input", "subfolder": file.type}) upload_file_from_base64(comfy_url, file.filename, file.base64_data, mapping["type"], mapping["subfolder"]) prompt_id = submit_prompt(comfy_url, request.workflow_api, client_id) wait_for_completion(comfy_url, client_id, prompt_id) images = download_images_as_base64(comfy_url, prompt_id) return WorkflowResponse(prompt_id=prompt_id, images=images, status="Success", server_id=request.server_id) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn port = int(os.getenv("PORT", "8000")) uvicorn.run(app, host="0.0.0.0", port=port)