# extract_workflow.py
"""
逐 case 提取 workflow + capability(v6 版本)

从 case.json 读取,按 index 遍历每个 case,
调用 LLM 同时提取 workflow(薄壳 steps)和 capability(原子操作),
按 index 原位回填到 case.json。

v6 架构特性:
- workflow.steps 是薄壳:step_id / order / phase,不含 capability 字段
- capability 是原子操作列表:每个 capability 含 workflow_step_ref + is_alternative_to
- 步内多原子操作 + 步内 alternative 都在 capability 层表达
- standalone capability(workflow_step_ref=null)用于无 workflow 上下文的能力提及
"""
import asyncio
import json
from pathlib import Path
from typing import Any, Dict, Optional, List

from examples.process_pipeline.script.llm_helper import call_llm_with_retry

# Location of the v5 structured vocabulary file (resource/ directory next to this script).
SCRIPT_DIR = Path(__file__).resolve().parent
METHOD_VOCAB_PATH = SCRIPT_DIR / "resource" / "method_vocab_v5.json"

# Built-in fallback vocabulary, used by load_method_vocab() when the JSON file
# is missing or unreadable. Keys are vocabulary categories (process role,
# modality, primary action, action method); values are the allowed terms that
# get rendered into the LLM prompt. The Chinese keys/terms are part of the
# prompt contract — do not translate them.
DEFAULT_METHOD_VOCAB = {
    "流程角色": [
        "生成指令", "编辑指令", "约束条件", "参考素材", "控制信号",
        "区域控制", "参数配置", "模型资源", "源素材", "中间产物",
        "成品", "模板", "评估结果"
    ],
    "模态": ["文本", "图片", "视频", "音频", "特征点", "参数", "模型", "向量"],
    "主动作": [
        "生成", "编辑", "提取", "改写", "合成", "修复", "增强",
        "训练", "评估", "剪辑", "模板化", "排版", "转写", "配音",
        "匹配", "扩展", "导出"
    ],
    "动作方式": [
        "直接生成", "一致性保持", "结构约束", "质量收束", "局部重绘",
        "扩图", "换背景", "提示词反推", "模板化", "多图融合", "清晰化",
        "风格迁移", "常规编辑", "变体生成", "动画化", "镜头延展",
        "换主体", "换装", "擦除", "调色", "前后景融合", "图文合成",
        "音画合成", "分层叠加", "特征提取", "蒙版提取", "关键帧提取",
        "字幕提取", "风格提取", "片段拼接", "节奏压缩", "转场编排",
        "字幕对齐", "音画同步", "降噪", "补帧", "超分", "稳定化",
        "质感增强", "结构抽象", "变量抽象", "版式套用", "格式转换", "压缩导出"
    ],
}
  44. def load_method_vocab() -> Dict[str, list]:
  45. """从 JSON 文件加载结构化词库(v5)"""
  46. if METHOD_VOCAB_PATH.exists():
  47. try:
  48. with open(METHOD_VOCAB_PATH, "r", encoding="utf-8") as f:
  49. return json.load(f)
  50. except Exception as e:
  51. print(f"Warning: Failed to load method_vocab.json: {e}, using default")
  52. return DEFAULT_METHOD_VOCAB
  53. def load_prompt_template(prompt_name: str) -> str:
  54. base_dir = Path(__file__).parent.parent
  55. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  56. with open(prompt_path, "r", encoding="utf-8") as f:
  57. content = f.read()
  58. if content.startswith("---"):
  59. parts = content.split("---", 2)
  60. if len(parts) >= 3:
  61. content = parts[2]
  62. content = content.replace("$system$", "").replace("$user$", "")
  63. return content.strip()
  64. def render_method_vocab_block(vocab: Dict[str, list]) -> str:
  65. """渲染结构化接口词库说明(v5)"""
  66. lines = [
  67. "\n# 结构化接口词库(v5,必须遵守)",
  68. "只输出结构化 inputs / outputs / action。",
  69. "- `role/流程角色` 只写接口职责,不写具体内容 what。",
  70. "- `modality/模态` 只写媒介或数据形态;统一用 `图片`,不要写 `图像`;统一用 `文本`,不要写 `文字`。",
  71. "- `artifact_type/工件类型` 写该模态下的具体工件,如 `正向提示词`、`蒙版`。",
  72. "- `action.description` 写输入到输出之间客观发生的信息变化;`action.reasoning` 写判断依据。",
  73. "- 只有词库确实不够时才新增术语;新增术语也必须抽象、短、可复用。",
  74. "",
  75. "当前词库:",
  76. ]
  77. for key, values in vocab.items():
  78. lines.append(f"- {key}:{'、'.join(values)}")
  79. return "\n".join(lines)
  80. async def extract_workflow_from_case(
  81. case_item: Dict[str, Any],
  82. llm_call: Any,
  83. model: str = "anthropic/claude-sonnet-4-5"
  84. ) -> tuple[Optional[List[Dict[str, Any]]], float]:
  85. """
  86. 从单个 case item 提取 1+ 组 workflow_group。
  87. Returns:
  88. (workflow_groups, cost)
  89. workflow_groups 为 None 表示 skip 或提取失败
  90. """
  91. images = case_item.get("images", [])
  92. case_copy = dict(case_item)
  93. case_copy.pop("images", None)
  94. case_copy.pop("_raw", None)
  95. case_copy.pop("workflow", None)
  96. case_copy.pop("workflow_groups", None)
  97. case_copy.pop("capability", None)
  98. case_copy.pop("capabilities", None)
  99. case_copy.pop("fragments", None)
  100. if not case_copy and not images:
  101. return None, None, 0.0
  102. title = case_item.get("title", "")[:20] or "untitled"
  103. context = json.dumps(case_copy, ensure_ascii=False, indent=2)
  104. try:
  105. prompt_template = load_prompt_template("extract_workflow")
  106. method_vocab = load_method_vocab()
  107. vocab_block = render_method_vocab_block(method_vocab)
  108. if "%context%" in prompt_template:
  109. prompt = prompt_template.replace("%context%", context)
  110. else:
  111. prompt = prompt_template + f"\n\n## 帖子内容\n{context}"
  112. if "{interface_vocab}" in prompt:
  113. prompt = prompt.replace("{interface_vocab}", vocab_block)
  114. elif vocab_block not in prompt:
  115. prompt = prompt + "\n" + vocab_block
  116. except Exception as e:
  117. print(f"Warning: Failed to load prompt template: {e}, using fallback")
  118. method_vocab = load_method_vocab()
  119. vocab_block = render_method_vocab_block(method_vocab)
  120. prompt = f"""将以下帖子内容总结为AI图片生成的工序和原子操作,以JSON格式输出。
  121. # 输出格式(v7)
  122. {{
  123. "skip": false,
  124. "skip_reason": "",
  125. "workflow_groups": [
  126. {{
  127. "workflow_id": "w1",
  128. "workflow": {{
  129. "workflow_id": "w1",
  130. "steps": [
  131. {{
  132. "step_id": "s1",
  133. "order": 1,
  134. "phase": "生成"
  135. }}
  136. ]
  137. }},
  138. "capability": [
  139. {{
  140. "capability_id": "c_w1_s1_0",
  141. "action": {{"description": "生成", "reasoning": "输入为文本提示词,输出为图片,客观信息变化是生成"}},
  142. "inputs": [{{"modality": "文本", "description": "...", "relation": "[来源.原始输入]"}}],
  143. "outputs": [{{"modality": "图片", "description": "...", "relation": "[去向.最终成品]"}}],
  144. "body": "string | null",
  145. "effects": [
  146. {{
  147. "statement": "实现XXX",
  148. "criteria": "判断标准",
  149. "judge_method": "vlm",
  150. "negative_examples": []
  151. }}
  152. ],
  153. "control_target": [],
  154. "artifact_type": null,
  155. "tools": [],
  156. "apply_to_draft": {{"实质": ["..."], "形式": ["..."]}},
  157. "workflow_step_ref": {{"workflow_id": "w1", "step_id": "s1"}},
  158. "is_alternative_to": []
  159. }}
  160. ]
  161. }}
  162. ]
  163. }}
  164. {vocab_block}
  165. ## 帖子内容
  166. {context}
  167. 请严格按照上述格式输出JSON,不要包含其他内容。"""
  168. if images:
  169. image_urls = [img for img in images[:30] if isinstance(img, str) and img.startswith("http")]
  170. if image_urls:
  171. content_array = [{"type": "text", "text": prompt}]
  172. for url in image_urls:
  173. content_array.append({"type": "image_url", "image_url": {"url": url}})
  174. messages = [{"role": "user", "content": content_array}]
  175. else:
  176. messages = [{"role": "user", "content": prompt}]
  177. else:
  178. messages = [{"role": "user", "content": prompt}]
  179. result_data, cost = await call_llm_with_retry(
  180. llm_call=llm_call,
  181. messages=messages,
  182. model=model,
  183. temperature=0.1,
  184. max_tokens=10000,
  185. max_retries=3,
  186. schema_name="extract_workflow",
  187. task_name=f"Workflow_{title}",
  188. )
  189. if not result_data:
  190. return None, cost
  191. if result_data.get("skip"):
  192. return None, cost
  193. workflow_groups = result_data.get("workflow_groups")
  194. if not isinstance(workflow_groups, list):
  195. workflow_groups = []
  196. return workflow_groups, cost
  197. async def extract_workflow(
  198. case_file: Path,
  199. llm_call: Any,
  200. model: str = "anthropic/claude-sonnet-4-5",
  201. max_concurrent: int = 3,
  202. case_indices: Optional[List[int]] = None
  203. ) -> Dict[str, Any]:
  204. """
  205. 按 index 遍历 case.json,提取 workflow
  206. Args:
  207. case_file: case.json 文件路径
  208. llm_call: LLM 调用函数
  209. model: 使用的模型
  210. max_concurrent: 最大并发数
  211. case_indices: 可选,指定要处理的 case index 列表。如果为 None,处理所有 case
  212. """
  213. with open(case_file, "r", encoding="utf-8") as f:
  214. case_data = json.load(f)
  215. cases = case_data.get("cases", [])
  216. # 如果指定了 case_indices,只处理这些 case
  217. if case_indices is not None:
  218. cases_to_process = [c for c in cases if c.get("index") in case_indices]
  219. print(f"Extracting workflow from {len(cases_to_process)} cases (filtered by indices: {case_indices})...")
  220. else:
  221. cases_to_process = cases
  222. print(f"Extracting workflow from {len(cases)} cases...")
  223. semaphore = asyncio.Semaphore(max_concurrent)
  224. async def process_with_semaphore(case_item):
  225. async with semaphore:
  226. index = case_item.get("index", 0)
  227. raw = case_item.get("_raw", {})
  228. case_id = raw.get("case_id", "unknown")
  229. title = case_item.get("title", "")
  230. print(f" -> [{index}] [{case_id}] extracting workflow: {title[:60]}")
  231. workflow_groups, cost = await extract_workflow_from_case(case_item, llm_call, model)
  232. group_count = len(workflow_groups) if workflow_groups else 0
  233. capability_count = sum(len(g.get("capability") or []) for g in (workflow_groups or []) if isinstance(g, dict))
  234. status = f"ok ({group_count} workflow, {capability_count} capability)" if workflow_groups else "null"
  235. print(f" <- [{index}] [{case_id}] workflow {status}")
  236. result = dict(case_item)
  237. result["workflow_groups"] = workflow_groups if workflow_groups is not None else []
  238. result.pop("workflow", None)
  239. result.pop("capability", None)
  240. result.pop("capabilities", None)
  241. result.pop("fragments", None)
  242. return result, cost
  243. tasks = [process_with_semaphore(case) for case in cases_to_process]
  244. results_with_costs = await asyncio.gather(*tasks)
  245. results = [r[0] for r in results_with_costs]
  246. costs = [r[1] for r in results_with_costs]
  247. total_cost = sum(costs)
  248. success_count = sum(1 for r in results if r.get("workflow_groups"))
  249. failed_count = len(results) - success_count
  250. # 如果是部分更新,需要合并回原始 cases 列表
  251. if case_indices is not None:
  252. # 创建一个 index -> result 的映射
  253. result_map = {r.get("index"): r for r in results}
  254. # 更新原始 cases 列表中对应的项
  255. for i, case in enumerate(cases):
  256. if case.get("index") in result_map:
  257. cases[i] = result_map[case.get("index")]
  258. results = cases
  259. results.sort(key=lambda x: x.get("index", 0))
  260. case_data["cases"] = results
  261. case_file.parent.mkdir(parents=True, exist_ok=True)
  262. with open(case_file, "w", encoding="utf-8") as f:
  263. json.dump(case_data, f, ensure_ascii=False, indent=2)
  264. workflow_count = sum(len(r.get("workflow_groups") or []) for r in results)
  265. capability_count = sum(
  266. len(group.get("capability") or [])
  267. for r in results
  268. for group in (r.get("workflow_groups") or [])
  269. if isinstance(group, dict)
  270. )
  271. return {
  272. "total": len(results),
  273. "success": success_count,
  274. "failed": failed_count,
  275. "workflow_total": workflow_count,
  276. "capability_total": capability_count,
  277. "total_cost": total_cost,
  278. "output_file": str(case_file),
  279. }
  280. if __name__ == "__main__":
  281. import sys
  282. if len(sys.argv) < 2:
  283. print("Usage: python extract_workflow.py <output_dir>")
  284. sys.exit(1)
  285. output_dir = Path(sys.argv[1])
  286. case_file = output_dir / "case.json"
  287. if not case_file.exists():
  288. print(f"Error: {case_file} not found")
  289. sys.exit(1)
  290. print("Please use this module through run_pipeline.py")