convert_workflow.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #!/usr/bin/env python3
  2. """将 ComfyUI workflow.json (UI格式) 转换为 workflow_api.json (API格式)
  3. 用法:
  4. python convert_workflow.py refcontrol_pose.json
  5. python convert_workflow.py refcontrol_pose.json -o output_api.json
  6. """
  7. import argparse
  8. import json
  9. from pathlib import Path
  10. def convert(workflow: dict) -> dict:
  11. """
  12. workflow.json → workflow_api.json
  13. UI格式: nodes[] 每个节点有 id/type/inputs[]/outputs[]/widgets_values[]/links[]
  14. API格式: {node_id: {class_type, inputs: {param: value_or_link}}}
  15. 连接关系: links[] = [link_id, src_node, src_slot, dst_node, dst_slot]
  16. widgets_values 按顺序对应节点没有连线的输入参数
  17. """
  18. nodes = workflow.get("nodes", [])
  19. links = workflow.get("links", [])
  20. # 建立 link_id → [src_node_id, src_slot] 的映射
  21. link_map = {}
  22. for link in links:
  23. link_id, src_node, src_slot, dst_node, dst_slot = link[:5]
  24. link_map[link_id] = [str(src_node), src_slot]
  25. api = {}
  26. for node in nodes:
  27. node_id = str(node["id"])
  28. class_type = node.get("type", "")
  29. # 跳过纯 UI 节点(Note、MarkdownNote 等无实际计算的节点)
  30. if class_type in ("Note", "MarkdownNote", "PrimitiveNode"):
  31. continue
  32. # 跳过 mode=4(disabled)的节点
  33. if node.get("mode") == 4:
  34. continue
  35. inputs_def = node.get("inputs", []) # 有连线的输入槽
  36. widgets_values = node.get("widgets_values", []) # 无连线的参数值
  37. inputs = {}
  38. # 先处理有连线的输入槽
  39. linked_slots = set()
  40. for inp in inputs_def:
  41. link_id = inp.get("link")
  42. if link_id is not None and link_id in link_map:
  43. inputs[inp["name"]] = link_map[link_id]
  44. linked_slots.add(inp["name"])
  45. # 再处理 widgets_values(按顺序填入没有连线的参数)
  46. # 需要知道节点有哪些 widget 参数,通过排除已连线的输入来推断
  47. # widgets 对应的参数名无法从 workflow.json 直接获取,只能按顺序赋值
  48. # 用 widget_idx_0, widget_idx_1 ... 作为 key,实际使用时按需重命名
  49. for i, val in enumerate(widgets_values):
  50. key = f"widget_{i}"
  51. inputs[key] = val
  52. api[node_id] = {
  53. "class_type": class_type,
  54. "inputs": inputs,
  55. }
  56. return api
  57. def main():
  58. parser = argparse.ArgumentParser(description="Convert workflow.json to workflow_api.json")
  59. parser.add_argument("input", help="输入 workflow.json 路径")
  60. parser.add_argument("-o", "--output", help="输出路径,默认在原文件名加 _api 后缀")
  61. args = parser.parse_args()
  62. input_path = Path(args.input)
  63. output_path = Path(args.output) if args.output else input_path.with_stem(input_path.stem + "_api")
  64. with open(input_path, "r", encoding="utf-8") as f:
  65. workflow = json.load(f)
  66. # 支持两种格式:直接是 nodes[] 的对象,或者包含 nodes[] 的对象
  67. if "nodes" not in workflow:
  68. print("ERROR: 不是有效的 workflow.json,缺少 nodes 字段")
  69. return
  70. api = convert(workflow)
  71. with open(output_path, "w", encoding="utf-8") as f:
  72. json.dump(api, f, indent=2, ensure_ascii=False)
  73. print(f"转换完成: {output_path}")
  74. print(f"节点数: {len(api)}")
  75. for node_id, node in api.items():
  76. print(f" [{node_id}] {node['class_type']}")
  77. if __name__ == "__main__":
  78. main()