# main.py — RunComfy ComfyUI schema inspector + workflow-builder HTTP service.
import os
import json
import requests
import warnings
from pathlib import Path

from dotenv import load_dotenv

# Populate os.environ from a local .env file (if present) before any
# configuration is read further down in the module.
load_dotenv()
# ---------------------------------------------------------------------------
# Tag annotations for example workflows — maps filename stems to a brief
# human-readable description so that agents can quickly decide which example
# to load. Maintained manually; new scraped files should be added here.
# Keys must match the stem produced by stripping the trailing `_api.json`
# from files under data/comfyui_examples/<category>/.
# ---------------------------------------------------------------------------
_WORKFLOW_ANNOTATIONS = {
    # --- flux ---
    "flux_dev_example": "Flux Dev txt2img (UNETLoader + SamplerCustomAdvanced + ModelSamplingFlux)",
    "flux_dev_checkpoint_example": "Flux Dev txt2img via CheckpointLoaderSimple (简化版, KSampler cfg=1)",
    "flux_schnell_example": "Flux Schnell 4-step fast txt2img (UNETLoader)",
    "flux_schnell_checkpoint_example": "Flux Schnell 4-step via CheckpointLoaderSimple",
    "flux_controlnet_example": "Flux + ControlNet (Canny) via ControlNetApplySD3",
    "flux_canny_model_example": "Flux Canny 内建模型 (InstructPixToPixConditioning)",
    "flux_depth_lora_example": "Flux Depth LoRA (LoraLoaderModelOnly + InstructPixToPixConditioning)",
    "flux_fill_inpaint_example": "Flux Fill inpaint (DifferentialDiffusion + InpaintModelConditioning)",
    "flux_fill_outpaint_example": "Flux Fill outpaint (ImagePadForOutpaint + InpaintModelConditioning)",
    "flux_redux_model_example": "Flux Redux 图像风格迁移 (CLIPVisionLoader + StyleModelApply)",
    "flux_kontext_example": "Flux Kontext 图像编辑/角色一致性 (FluxKontextImageScale + ReferenceLatent)",
    # --- flux2 ---
    "flux2_example": "Flux 2 txt2img",
    # --- sdxl ---
    "sdxl_simple_example": "SDXL Base+Refiner 两阶段 txt2img (KSamplerAdvanced)",
    "sdxl_refiner_prompt_example": "SDXL Base+Refiner 各自独立提示词",
    "sdxl_revision_text_prompts": "SDXL Revision 文本提示融合",
    "sdxl_revision_zero_positive": "SDXL Revision zero-positive 风格",
    # --- sd3 ---
    "sd3.5_simple_example": "SD3.5 基础 txt2img",
    "sd3.5_text_encoders_example": "SD3.5 三编码器 (clip_g + clip_l + t5xxl)",
    "sd3.5_large_canny_controlnet_example": "SD3.5 Large + Canny ControlNet",
    # --- controlnet ---
    "controlnet_example": "SD1.5 ControlNet (scribble) 基础用法",
    "depth_controlnet": "SD1.5 Depth ControlNet",
    "depth_t2i_adapter": "SD1.5 Depth T2I-Adapter",
    "mixing_controlnets": "SD1.5 混合双 ControlNet (openpose + scribble)",
    "2_pass_pose_worship": "SD1.5 两阶段 ControlNet + LatentUpscale Hi-Res Fix",
    # --- img2img ---
    "img2img_workflow": "SD1.5 图生图 (VAEEncode, denoise<1)",
    # --- inpaint ---
    "inpaint_example": "SD1.5 Inpaint (VAEEncodeForInpaint + 专用 checkpoint)",
    "inpain_model_cat": "SD1.5 Inpaint 猫咪涂抹",
    "inpain_model_woman": "SD1.5 Inpaint 女性涂抹",
    "inpain_model_outpainting": "SD1.5 Outpaint (ImagePadForOutpaint + VAEEncodeForInpaint)",
    "inpaint_anythingv3_woman": "SD1.5 Inpaint AnythingV3",
    "yosemite_outpaint_example": "SD1.5 Outpaint Yosemite 扩展画布",
    # --- lora ---
    "lora": "SD1.5 单 LoRA (LoraLoader)",
    "lora_multiple": "SD1.5 多 LoRA 链式堆叠",
    # --- upscale ---
    "esrgan_example": "ESRGAN 超分辨率 (UpscaleModelLoader + ImageUpscaleWithModel)",
    # --- area_composition ---
    "square_area_for_subject": "区域化构图 (ConditioningSetArea)",
    "workflow_night_evening_day_morning": "四时段区域化构图",
    # --- others ---
    "aura_flow_0.1_example": "AuraFlow 0.1 txt2img",
    "aura_flow_0.2_example": "AuraFlow 0.2 txt2img",
    "chroma_example": "Chroma 模型 txt2img",
    "cosmos_predict2_2b_t2i_example": "Cosmos Predict2 2B txt2img",
    "sdxl_edit_model": "SDXL Edit Model (InstructPixToPixConditioning)",
    "gligen_textbox_example": "GLIGEN 文本框定位 (GLIGENTextBoxApply)",
    "hidream_dev_example": "HiDream Dev txt2img",
    "hidream_e1.1_example": "HiDream E1.1 txt2img",
    "hidream_full_example": "HiDream Full txt2img",
    "hunyuan_dit_1.2_example": "HunyuanDiT 1.2 txt2img",
    "hunyuan_image_example": "Hunyuan Image txt2img",
    "hypernetwork_example": "Hypernetwork 示例",
    "lcm_basic_example": "LCM 快速采样",
    "lumina2_basic_example": "Lumina2 txt2img",
    "model_merging_basic": "模型合并 基础 (ModelMergeSimple)",
    "model_merging_3_checkpoints": "模型合并 三模型",
    "model_merging_cosxl": "模型合并 CosXL",
    "model_merging_inpaint": "模型合并 Inpaint",
    "model_merging_lora": "模型合并 LoRA",
    "noisy_latents_3_subjects": "噪声潜空间构图 三主体",
    "noisy_latents_3_subjects_": "噪声潜空间构图 三主体 (变体)",
    "omnigen2_example": "OmniGen2 txt2img",
    "qwen_image_basic_example": "Qwen Image 基础 txt2img",
    "qwen_image_edit_basic_example": "Qwen Image Edit 基础编辑",
    "qwen_image_edit_2509_basic_example": "Qwen Image Edit 2509 编辑",
    "sdxlturbo_example": "SDXL Turbo 快速采样",
    "stable_cascade__text_to_image": "Stable Cascade txt2img",
    "stable_cascade__image_to_image": "Stable Cascade img2img",
    "stable_cascade__canny_controlnet": "Stable Cascade Canny ControlNet",
    "stable_cascade__inpaint_controlnet": "Stable Cascade Inpaint ControlNet",
    "stable_cascade__image_remixing": "Stable Cascade 图像混合",
    "stable_cascade__image_remixing_multiple": "Stable Cascade 多图混合",
    "embedding_example": "Textual Inversion (embedding) 示例",
    "unclip_example": "UnCLIP 基础",
    "unclip_2pass": "UnCLIP 两阶段",
    "unclip_example_multiple": "UnCLIP 多图输入",
    "z_image_turbo_example": "Z-Image Turbo txt2img",
}
  99. class RunComfySchemaInspector:
  100. def __init__(self, server_url=None):
  101. self.server_url = server_url
  102. self.object_info = {}
  103. # We store a backup copy inside this directory in case the cloud machine is sleep/offline.
  104. self.cache_path = os.path.join(os.path.dirname(__file__), "object_info_cache.json")
  105. self._examples_dir = self._locate_examples_dir()
  106. self._hot_reload()
  107. # ------------------------------------------------------------------
  108. # Private helpers
  109. # ------------------------------------------------------------------
  110. @staticmethod
  111. def _locate_examples_dir() -> str:
  112. """Walk up from this file to find the project root containing `data/comfyui_examples`."""
  113. anchor = Path(__file__).resolve()
  114. for parent in [anchor] + list(anchor.parents):
  115. candidate = parent / "data" / "comfyui_examples"
  116. if candidate.is_dir():
  117. return str(candidate)
  118. # Fallback: assume standard repo layout (this file is 3 levels deep under repo root)
  119. fallback = Path(__file__).resolve().parents[3] / "data" / "comfyui_examples"
  120. return str(fallback)
  121. def _hot_reload(self):
  122. """Attempts to fetch fresh object_info from the server. Falls back to cached JSON if offline."""
  123. if not self.server_url:
  124. # Fallback to standard generic environment if none provided
  125. # Find an active server if possible, or use the dedicated testing ID
  126. self.server_url = "https://90f77137-ba75-400d-870f-204c614ae8a3-comfyui.runcomfy.com"
  127. print(f"[SchemaInspector] Attempting hot-reload from {self.server_url}/object_info...")
  128. try:
  129. resp = requests.get(f"{self.server_url}/object_info", timeout=10)
  130. if resp.status_code == 200:
  131. self.object_info = resp.json()
  132. # Update local cache
  133. with open(self.cache_path, "w", encoding="utf-8") as f:
  134. json.dump(self.object_info, f)
  135. print("[SchemaInspector] Successfully updated schema from remote server.")
  136. return
  137. except Exception as e:
  138. print(f"[SchemaInspector] Warning: Hot-reload failed ({e}). Machine might be offline.")
  139. if os.path.exists(self.cache_path):
  140. print("[SchemaInspector] Loading schema from local cache.")
  141. with open(self.cache_path, "r", encoding="utf-8") as f:
  142. self.object_info = json.load(f)
  143. else:
  144. warnings.warn("No active remote connection and no local cache found! Some tools will fail.")
  145. # ------------------------------------------------------------------
  146. # Schema / Model inspection (existing)
  147. # ------------------------------------------------------------------
  148. def get_node_schema(self, class_type: str) -> dict:
  149. """Returns the rigorous Required and Optional properties for a specific ComfyUI Node."""
  150. if class_type not in self.object_info:
  151. return {"error": f"Node '{class_type}' not found in the environment registry."}
  152. node_def = self.object_info[class_type]
  153. schema = {
  154. "name": class_type,
  155. "inputs": {
  156. "required": node_def.get("input", {}).get("required", {}),
  157. "optional": node_def.get("input", {}).get("optional", {})
  158. },
  159. "outputs": node_def.get("output_name", [])
  160. }
  161. return schema
  162. def search_models(self, category: str = "checkpoints", keyword: str = "") -> list:
  163. """
  164. category: 'checkpoints', 'loras', 'vaes', 'controlnets'
  165. """
  166. target_keys = {
  167. "checkpoints": ("CheckpointLoaderSimple", "ckpt_name"),
  168. "loras": ("LoraLoader", "lora_name"),
  169. "vaes": ("VAELoader", "vae_name"),
  170. "controlnets": ("ControlNetLoader", "control_net_name")
  171. }
  172. if category not in target_keys:
  173. return []
  174. node_type, prop = target_keys[category]
  175. if node_type not in self.object_info:
  176. return []
  177. try:
  178. items = self.object_info[node_type]["input"]["required"][prop][0]
  179. kw = keyword.lower()
  180. return [x for x in items if kw in x.lower()]
  181. except (KeyError, IndexError):
  182. return []
  183. def verify_workflow(self, api_json: dict) -> dict:
  184. """
  185. Validates an LLM-generated API JSON against the dynamic schema.
  186. Returns a dict containing {"valid": bool, "errors": list_of_strings}
  187. """
  188. errors = []
  189. if not isinstance(api_json, dict):
  190. return {"valid": False, "errors": ["workflow_api 必须是一个字典对象"]}
  191. for node_id, node in api_json.items():
  192. ctype = node.get("class_type")
  193. if not ctype:
  194. errors.append(f"Node '{node_id}' is missing a class_type.")
  195. continue
  196. if ctype not in self.object_info:
  197. errors.append(f"Node '{node_id}' requests non-existent class '{ctype}'.")
  198. continue
  199. expected_req = self.object_info[ctype].get("input", {}).get("required", {})
  200. actual_inputs = node.get("inputs", {})
  201. for req_key in expected_req.keys():
  202. if req_key not in actual_inputs:
  203. errors.append(f"Node '{node_id}' ({ctype}) is missing REQUIRED input '{req_key}'.")
  204. # 检查连线合法性 (检查断线)
  205. for input_key, input_value in actual_inputs.items():
  206. if isinstance(input_value, list) and len(input_value) >= 2:
  207. target_node = str(input_value[0])
  208. if target_node not in api_json:
  209. errors.append(f"Node '{node_id}' ({ctype}) 的 '{input_key}' 连向了不存在的节点: '{target_node}'")
  210. return {
  211. "valid": len(errors) == 0,
  212. "errors": errors
  213. }
  214. # ------------------------------------------------------------------
  215. # Example workflow browsing & loading (NEW)
  216. # ------------------------------------------------------------------
  217. def list_example_workflows(self, category: str = None, keyword: str = "") -> dict:
  218. """
  219. Browse the built-in example workflow library.
  220. Args:
  221. category: Filter by subdirectory name (e.g. 'flux', 'controlnet', 'inpaint').
  222. Pass None to list ALL categories and their workflows.
  223. keyword: Optional keyword filter on filename (case-insensitive).
  224. Returns:
  225. {
  226. "examples_dir": str,
  227. "categories": {
  228. "flux": [
  229. {"name": "flux_dev_example", "file": "flux_dev_example_api.json",
  230. "description": "Flux Dev txt2img ...", "path": "flux/flux_dev_example_api.json"},
  231. ...
  232. ],
  233. ...
  234. }
  235. }
  236. """
  237. base = Path(self._examples_dir)
  238. if not base.is_dir():
  239. return {"error": f"Examples directory not found: {self._examples_dir}"}
  240. kw = keyword.lower()
  241. result_categories = {}
  242. for cat_dir in sorted(base.iterdir()):
  243. if not cat_dir.is_dir():
  244. continue
  245. cat_name = cat_dir.name
  246. # Category filter
  247. if category and cat_name != category:
  248. continue
  249. entries = []
  250. for json_file in sorted(cat_dir.glob("*.json")):
  251. fname = json_file.name
  252. # strip the trailing _api.json to get the stem
  253. stem = fname.replace("_api.json", "").replace("_api.", ".")
  254. # Keyword filter on stem
  255. if kw and kw not in stem.lower():
  256. continue
  257. desc = _WORKFLOW_ANNOTATIONS.get(stem, "")
  258. entries.append({
  259. "name": stem,
  260. "file": fname,
  261. "description": desc,
  262. "path": f"{cat_name}/{fname}",
  263. })
  264. if entries:
  265. result_categories[cat_name] = entries
  266. return {
  267. "examples_dir": self._examples_dir,
  268. "categories": result_categories,
  269. }
  270. def load_example_workflow(self, name: str) -> dict:
  271. """
  272. Load a specific example workflow as a Python dict.
  273. Args:
  274. name: Can be any of the following forms:
  275. - Full relative path: "flux/flux_dev_example_api.json"
  276. - Stem (auto-resolved): "flux_dev_example"
  277. - Partial keyword: "flux_dev" (picks first match)
  278. Returns:
  279. {"name": str, "path": str, "description": str, "workflow": dict}
  280. or {"error": str} if not found.
  281. """
  282. base = Path(self._examples_dir)
  283. # Strategy 1: exact relative path
  284. full_path = base / name
  285. if full_path.is_file():
  286. return self._read_example(full_path)
  287. # Strategy 2: exact stem → category/stem_api.json
  288. for json_file in base.rglob("*.json"):
  289. stem = json_file.name.replace("_api.json", "").replace("_api.", ".")
  290. if stem == name:
  291. return self._read_example(json_file)
  292. # Strategy 3: partial keyword match (first hit)
  293. name_lower = name.lower()
  294. for json_file in sorted(base.rglob("*.json")):
  295. if name_lower in json_file.stem.lower():
  296. return self._read_example(json_file)
  297. return {"error": f"No example workflow matching '{name}' found in {self._examples_dir}"}
  298. def _read_example(self, filepath: Path) -> dict:
  299. """Read a single example JSON and return annotated result."""
  300. base = Path(self._examples_dir)
  301. rel = filepath.relative_to(base)
  302. stem = filepath.name.replace("_api.json", "").replace("_api.", ".")
  303. desc = _WORKFLOW_ANNOTATIONS.get(stem, "")
  304. with open(filepath, "r", encoding="utf-8") as f:
  305. wf = json.load(f)
  306. # Build a quick summary of nodes used
  307. node_types = sorted(set(
  308. n.get("class_type", "?") for n in wf.values() if isinstance(n, dict)
  309. ))
  310. return {
  311. "name": stem,
  312. "path": str(rel),
  313. "description": desc,
  314. "node_types_used": node_types,
  315. "node_count": len(wf),
  316. "workflow": wf,
  317. }
  318. # ===========================================================================
  319. # FastAPI 服务层 — 统一 /query 端点,通过 action 参数分发
  320. # ===========================================================================
  321. import argparse
  322. from typing import Any, Optional
  323. from fastapi import FastAPI, HTTPException
  324. from pydantic import BaseModel, Field
  325. import uvicorn
# The HTTP application exposing the inspector's capabilities.
app = FastAPI(title="RunComfy Workflow Builder", version="1.0")

# Lazily-initialized inspector singleton (created on first request so that
# the network hot-reload does not block service startup).
_inspector: RunComfySchemaInspector | None = None
  329. def _get_inspector() -> RunComfySchemaInspector:
  330. global _inspector
  331. if _inspector is None:
  332. _inspector = RunComfySchemaInspector()
  333. return _inspector
class QueryRequest(BaseModel):
    """Request body for the unified `/` endpoint; `action` selects the operation."""

    action: str = Field(..., description=(
        "要执行的操作: "
        "search_models | get_node_schema | verify_workflow | "
        "list_examples | load_example | read_skill"
    ))
    # The fields below are interpreted per action; irrelevant ones may be omitted.
    category: Optional[str] = Field(None, description="模型分类(checkpoints/loras/vaes/controlnets) 或示例分类(flux/controlnet/...)")
    keyword: Optional[str] = Field(None, description="搜索关键词")
    class_type: Optional[str] = Field(None, description="ComfyUI 节点类型名 (get_node_schema 用)")
    name: Optional[str] = Field(None, description="示例工作流名称 (load_example 用)")
    workflow: Optional[dict[str, Any]] = Field(None, description="待验证的 API JSON (verify_workflow 用)")
@app.get("/health")
def health():
    """Liveness probe: static OK payload, does not touch the inspector."""
    return {"status": "ok"}
  349. @app.post("/")
  350. def query(req: QueryRequest):
  351. """统一入口 — 根据 action 分发到对应的内部方法。"""
  352. inspector = _get_inspector()
  353. # ---------- search_models ----------
  354. if req.action == "search_models":
  355. results = inspector.search_models(
  356. category=req.category or "checkpoints",
  357. keyword=req.keyword or ""
  358. )
  359. return {"action": "search_models", "count": len(results), "models": results}
  360. # ---------- get_node_schema ----------
  361. elif req.action == "get_node_schema":
  362. if not req.class_type:
  363. raise HTTPException(400, "get_node_schema 需要 class_type 参数")
  364. schema = inspector.get_node_schema(req.class_type)
  365. return {"action": "get_node_schema", **schema}
  366. # ---------- verify_workflow ----------
  367. elif req.action == "verify_workflow":
  368. if not req.workflow:
  369. raise HTTPException(400, "verify_workflow 需要 workflow 参数 (API JSON dict)")
  370. result = inspector.verify_workflow(req.workflow)
  371. return {"action": "verify_workflow", **result}
  372. # ---------- list_examples ----------
  373. elif req.action == "list_examples":
  374. result = inspector.list_example_workflows(
  375. category=req.category,
  376. keyword=req.keyword or ""
  377. )
  378. return {"action": "list_examples", **result}
  379. # ---------- load_example ----------
  380. elif req.action == "load_example":
  381. if not req.name:
  382. raise HTTPException(400, "load_example 需要 name 参数")
  383. result = inspector.load_example_workflow(req.name)
  384. if "error" in result:
  385. raise HTTPException(404, result["error"])
  386. return {"action": "load_example", **result}
  387. # ---------- read_skill ----------
  388. elif req.action == "read_skill":
  389. skill_path = Path(__file__).parent / "skill.md"
  390. if not skill_path.exists():
  391. raise HTTPException(404, "skill.md not found")
  392. content = skill_path.read_text(encoding="utf-8")
  393. return {"action": "read_skill", "content": content}
  394. else:
  395. raise HTTPException(400, (
  396. f"未知 action: '{req.action}'。"
  397. "支持: search_models, get_node_schema, verify_workflow, "
  398. "list_examples, load_example, read_skill"
  399. ))
if __name__ == "__main__":
    # CLI entry point: start the HTTP service on the requested port.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8010)
    args = parser.parse_args()
    # Bind on all interfaces so the service is reachable from other hosts/containers.
    uvicorn.run(app, host="0.0.0.0", port=args.port)