check_workflow.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. #!/usr/bin/env python3
  2. """检查 workflow_api.json 需要哪些外部输入
  3. 分析所有 LoadImage / LoadVideo 等输入节点,列出需要提供的文件。
  4. 检查所有节点连接是否完整,发现缺失的输入。
  5. 用法:
  6. python check_workflow.py workflow_api.json
  7. python check_workflow.py workflow_api.json --images ref.png pose.jpg
  8. """
  9. import argparse
  10. import json
  11. from pathlib import Path
  12. # 需要外部文件输入的节点类型,以及对应的文件参数名
  13. FILE_INPUT_NODES = {
  14. "LoadImage": ["image"],
  15. "LoadVideo": ["video"],
  16. "LoadVideoPath": ["video"],
  17. "LoadImageMask": ["image"],
  18. "LoadImageOutput": ["image"], # 从 output 目录加载,特殊处理
  19. "VHS_LoadVideo": ["video"],
  20. "VHS_LoadImages": ["directory"],
  21. }
  22. # 纯 UI / 注释节点,无实际输入
  23. SKIP_NODES = {"Note", "MarkdownNote", "PrimitiveNode"}
  24. def analyze(workflow: dict) -> dict:
  25. node_ids = set(workflow.keys())
  26. # 收集所有被其他节点引用的 [node_id, slot] 连接
  27. referenced = set()
  28. for node_id, node in workflow.items():
  29. for key, val in node.get("inputs", {}).items():
  30. if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str):
  31. referenced.add(val[0])
  32. issues = []
  33. file_inputs = [] # 需要上传的文件
  34. widget_params = [] # widget_* 占位参数(转换不准确的)
  35. for node_id, node in workflow.items():
  36. class_type = node.get("class_type", "")
  37. if class_type in SKIP_NODES:
  38. continue
  39. inputs = node.get("inputs", {})
  40. # 检查连接引用是否存在
  41. for key, val in inputs.items():
  42. if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str):
  43. ref_node = val[0]
  44. if ref_node not in node_ids:
  45. issues.append(f"节点 [{node_id}] {class_type}: 输入 '{key}' 引用了不存在的节点 [{ref_node}]")
  46. # 检查文件输入节点
  47. if class_type in FILE_INPUT_NODES:
  48. param_names = FILE_INPUT_NODES[class_type]
  49. for param in param_names:
  50. # 可能是真实参数名,也可能被转成了 widget_*
  51. value = inputs.get(param)
  52. if value is None:
  53. # 找 widget_* 里的值
  54. widget_vals = {k: v for k, v in inputs.items() if k.startswith("widget_")}
  55. value = next(iter(widget_vals.values()), None) if widget_vals else None
  56. is_output_node = class_type == "LoadImageOutput"
  57. file_inputs.append({
  58. "node_id": node_id,
  59. "class_type": class_type,
  60. "param": param,
  61. "current_value": value,
  62. "is_output": is_output_node,
  63. })
  64. # 标记 widget_* 占位参数(说明转换不准确)
  65. widget_keys = [k for k in inputs if k.startswith("widget_")]
  66. if widget_keys and class_type not in FILE_INPUT_NODES:
  67. widget_params.append({
  68. "node_id": node_id,
  69. "class_type": class_type,
  70. "params": {k: inputs[k] for k in widget_keys},
  71. })
  72. return {
  73. "file_inputs": file_inputs,
  74. "widget_params": widget_params,
  75. "issues": issues,
  76. }
  77. def check_files_exist(file_inputs: list, provided: list[Path]) -> list:
  78. provided_names = {p.name for p in provided if p.exists()}
  79. missing = []
  80. for fi in file_inputs:
  81. if fi["is_output"]:
  82. continue # LoadImageOutput 从服务器 output 目录读,不需要上传
  83. val = fi["current_value"]
  84. if val and isinstance(val, str):
  85. # 去掉 [output] 等后缀
  86. filename = val.split(" ")[0]
  87. if filename not in provided_names:
  88. missing.append({**fi, "filename": filename})
  89. return missing
  90. def main():
  91. parser = argparse.ArgumentParser(description="检查 workflow_api.json 所需输入")
  92. parser.add_argument("workflow", help="workflow_api.json 路径")
  93. parser.add_argument("--input-dir", default="input", metavar="DIR", help="输入文件目录,默认 input/")
  94. args = parser.parse_args()
  95. workflow_path = Path(args.workflow)
  96. if not workflow_path.exists():
  97. print(f"ERROR: 文件不存在: {workflow_path}")
  98. return 1
  99. with open(workflow_path, "r", encoding="utf-8") as f:
  100. workflow = json.load(f)
  101. input_dir = Path(args.input_dir)
  102. all_input_files = list(input_dir.rglob("*")) if input_dir.exists() else []
  103. print(f"=== 检查 {workflow_path.name} ===\n")
  104. print(f"节点总数: {len(workflow)}")
  105. print(f"input 目录: {input_dir} ({'存在' if input_dir.exists() else '不存在'})")
  106. if all_input_files:
  107. print(f" 已有文件:")
  108. for f in sorted(all_input_files):
  109. if f.is_file():
  110. print(f" - {f.relative_to(input_dir)}")
  111. result = analyze(workflow)
  112. # ── 文件输入 ──
  113. print(f"\n── 文件输入节点 ({len(result['file_inputs'])} 个) ──")
  114. if result["file_inputs"]:
  115. for fi in result["file_inputs"]:
  116. tag = "[output目录]" if fi["is_output"] else "[需上传]"
  117. print(f" [{fi['node_id']}] {fi['class_type']}.{fi['param']}")
  118. print(f" 当前值: {fi['current_value']} {tag}")
  119. else:
  120. print(" 无")
  121. # ── widget_* 警告 ──
  122. if result["widget_params"]:
  123. print(f"\n── ⚠️ widget_* 占位参数(参数名不准确,建议用 ComfyUI 导出 API 格式)──")
  124. for wp in result["widget_params"]:
  125. print(f" [{wp['node_id']}] {wp['class_type']}")
  126. for k, v in wp["params"].items():
  127. print(f" {k}: {v}")
  128. # ── 连接问题 ──
  129. if result["issues"]:
  130. print(f"\n── ❌ 连接问题 ({len(result['issues'])} 个) ──")
  131. for issue in result["issues"]:
  132. print(f" {issue}")
  133. # ── 文件缺失检查 ──
  134. missing = check_files_exist(result["file_inputs"], all_input_files)
  135. print(f"\n── 文件准备状态 ──")
  136. needs_upload = [fi for fi in result["file_inputs"] if not fi["is_output"]]
  137. if not needs_upload:
  138. print(" 无需上传文件")
  139. else:
  140. for fi in needs_upload:
  141. filename = (fi["current_value"] or "").split(" ")[0]
  142. found = any(f.is_file() and f.name == filename for f in all_input_files)
  143. status = "✓ 已提供" if found else "✗ 缺失"
  144. print(f" {status} {filename} (节点 [{fi['node_id']}] {fi['class_type']})")
  145. # ── 最终结论 ──
  146. print()
  147. has_error = bool(result["issues"]) or bool(missing)
  148. has_warn = bool(result["widget_params"])
  149. if has_error:
  150. print("❌ 检查未通过,请将缺失文件放入 input/ 目录后再提交")
  151. return 1
  152. elif has_warn:
  153. print("⚠️ 存在 widget_* 占位参数,建议用 ComfyUI 导出正确的 API 格式,否则可能运行出错")
  154. return 0
  155. else:
  156. print("✓ 检查通过,可以提交")
  157. return 0
  158. if __name__ == "__main__":
  159. exit(main())