# extract_workflow.py
  1. """
  2. 逐 case 提取 workflow (v5版本)
  3. 从 case.json 读取,按 index 遍历每个 case,
  4. 调用 LLM 提取 workflow,按 index 原位回填到 case.json
  5. v5 架构特性:
  6. - 使用结构化 inputs/outputs(role, modality, artifact_type 等10个维度)
  7. - action 对象化:{main_action, mechanism}(替代旧的 method 字符串)
  8. - Stage 1 输出 apply_to_draft(自然语言),为 Stage 2 内容树映射做准备
  9. - strategy 顶层字段(method, inputs, outputs, tools, stage)由脚本自动推导
  10. """
  11. import asyncio
  12. import json
  13. from pathlib import Path
  14. from typing import Any, Dict, Optional, List
  15. from examples.process_pipeline.script.llm_helper import call_llm_with_retry
  16. # v5 词库文件路径
  17. SCRIPT_DIR = Path(__file__).resolve().parent
  18. METHOD_VOCAB_PATH = SCRIPT_DIR / "resource" / "method_vocab_v5.json"
  19. # 默认词库(如果文件不存在时使用)
  20. DEFAULT_METHOD_VOCAB = {
  21. "流程角色": [
  22. "生成指令", "编辑指令", "约束条件", "参考素材", "控制信号",
  23. "区域控制", "参数配置", "模型资源", "源素材", "中间产物",
  24. "成品", "模板", "评估结果"
  25. ],
  26. "模态": ["文本", "图片", "视频", "音频", "特征点", "参数", "模型", "向量", "表格"],
  27. "主动作": [
  28. "生成", "编辑", "提取", "改写", "合成", "修复", "增强",
  29. "训练", "评估", "剪辑", "模板化", "排版", "转写", "配音",
  30. "匹配", "扩展", "导出"
  31. ],
  32. "动作方式": [
  33. "直接生成", "一致性保持", "结构约束", "质量收束", "局部重绘",
  34. "扩图", "换背景", "提示词反推", "模板化", "多图融合", "清晰化",
  35. "风格迁移", "常规编辑", "变体生成", "动画化", "镜头延展",
  36. "换主体", "换装", "擦除", "调色", "前后景融合", "图文合成",
  37. "音画合成", "分层叠加", "特征提取", "蒙版提取", "关键帧提取",
  38. "字幕提取", "风格提取", "片段拼接", "节奏压缩", "转场编排",
  39. "字幕对齐", "音画同步", "降噪", "补帧", "超分", "稳定化",
  40. "质感增强", "结构抽象", "变量抽象", "版式套用", "格式转换", "压缩导出"
  41. ],
  42. }
  43. def load_method_vocab() -> Dict[str, list]:
  44. """从 JSON 文件加载结构化词库(v5)"""
  45. if METHOD_VOCAB_PATH.exists():
  46. try:
  47. with open(METHOD_VOCAB_PATH, "r", encoding="utf-8") as f:
  48. return json.load(f)
  49. except Exception as e:
  50. print(f"Warning: Failed to load method_vocab.json: {e}, using default")
  51. return DEFAULT_METHOD_VOCAB
  52. def load_prompt_template(prompt_name: str) -> str:
  53. base_dir = Path(__file__).parent.parent
  54. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  55. with open(prompt_path, "r", encoding="utf-8") as f:
  56. content = f.read()
  57. if content.startswith("---"):
  58. parts = content.split("---", 2)
  59. if len(parts) >= 3:
  60. content = parts[2]
  61. content = content.replace("$system$", "").replace("$user$", "")
  62. return content.strip()
  63. def render_method_vocab_block(vocab: Dict[str, list]) -> str:
  64. """渲染结构化接口词库说明(v5)"""
  65. lines = [
  66. "\n# 结构化接口词库(v5,必须遵守)",
  67. "只输出结构化 inputs / outputs / action。",
  68. "- `role/流程角色` 只写接口职责,不写具体内容 what。",
  69. "- `modality/模态` 只写媒介或数据形态;统一用 `图片`,不要写 `图像`;统一用 `文本`,不要写 `文字`。",
  70. "- `artifact_type/工件类型` 写该模态下的具体工件,如 `正向提示词`、`蒙版`。",
  71. "- `action.main_action` 写主动作;`action.mechanism` 写动作内部机制。",
  72. "- 只有词库确实不够时才新增术语;新增术语也必须抽象、短、可复用。",
  73. "",
  74. "当前词库:",
  75. ]
  76. for key, values in vocab.items():
  77. lines.append(f"- {key}:{'、'.join(values)}")
  78. return "\n".join(lines)
import re  # NOTE(review): `re` appears unused in this file, and PEP 8 puts imports at module top — confirm and remove
  80. def _infer_stage_from_action(action_obj: dict) -> str:
  81. """从 action 对象推断 stage(v5版本)"""
  82. main_action = action_obj.get("main_action", "")
  83. mechanism = action_obj.get("mechanism", "")
  84. # 根据主动作和动作方式推断阶段
  85. if main_action in ["提取", "改写", "模板化", "训练", "评估"]:
  86. return "preprocess"
  87. elif main_action in ["编辑", "修复", "增强", "剪辑", "排版"]:
  88. return "refine"
  89. elif mechanism in ["局部重绘", "扩图", "换背景", "换主体", "换装", "擦除", "调色",
  90. "前后景融合", "降噪", "补帧", "超分", "稳定化", "质感增强"]:
  91. return "refine"
  92. else:
  93. return "generate"
  94. def derive_strategy_rollup(strategy: dict) -> None:
  95. """
  96. 从 steps 自动推导 strategy 的顶层字段(v5版本):
  97. method, inputs, outputs, tools, stage
  98. v5 变化:
  99. - method 从 action.main_action 提取(不再从旧的 method 字符串解析)
  100. - stage 从 action 对象推断
  101. """
  102. steps = [s for s in (strategy.get("steps") or []) if isinstance(s, dict)]
  103. if not steps:
  104. return
  105. steps.sort(key=lambda s: s.get("order") if isinstance(s.get("order"), int) else 9999)
  106. # method = 所有步骤的 main_action 用 "-" 连接
  107. actions = []
  108. for s in steps:
  109. action_obj = s.get("action")
  110. if isinstance(action_obj, dict):
  111. main_action = action_obj.get("main_action", "")
  112. if main_action:
  113. actions.append(main_action)
  114. if actions:
  115. strategy["method"] = "-".join(actions)
  116. # inputs = 第一步的 inputs
  117. first_inputs = steps[0].get("inputs")
  118. strategy["inputs"] = first_inputs if isinstance(first_inputs, list) else []
  119. # outputs = 最后一步的 outputs
  120. last_outputs = steps[-1].get("outputs")
  121. strategy["outputs"] = last_outputs if isinstance(last_outputs, list) else []
  122. # tools = 所有步骤的 tools 去重合并
  123. tools = []
  124. for step in steps:
  125. for tool in step.get("tools") or []:
  126. if isinstance(tool, str) and tool and tool not in tools:
  127. tools.append(tool)
  128. strategy["tools"] = tools
  129. # stage = 从 action 对象推断
  130. stages = []
  131. for step in steps:
  132. action_obj = step.get("action")
  133. if isinstance(action_obj, dict):
  134. stage = _infer_stage_from_action(action_obj)
  135. if stage not in stages:
  136. stages.append(stage)
  137. strategy["stage"] = stages or ["generate"]
  138. async def extract_workflow_from_case(
  139. case_item: Dict[str, Any],
  140. llm_call: Any,
  141. model: str = "anthropic/claude-sonnet-4-5"
  142. ) -> tuple[Optional[Dict[str, Any]], float]:
  143. """
  144. 从单个 case item 提取 workflow (v5版本)
  145. v5 特性:
  146. - 结构化 inputs/outputs(role, modality, artifact_type 等)
  147. - action 对象化:{main_action, mechanism}(替代旧的 method 字符串)
  148. - 输出 apply_to_draft(自然语言),为 Stage 2 内容树映射做准备
  149. - strategy 顶层字段由 derive_strategy_rollup 自动推导
  150. """
  151. images = case_item.get("images", [])
  152. case_copy = dict(case_item)
  153. case_copy.pop("images", None)
  154. case_copy.pop("_raw", None)
  155. case_copy.pop("workflow", None)
  156. case_copy.pop("capabilities", None)
  157. if not case_copy and not images:
  158. return None, 0.0
  159. title = case_item.get("title", "")[:20] or "untitled"
  160. context = json.dumps(case_copy, ensure_ascii=False, indent=2)
  161. try:
  162. prompt_template = load_prompt_template("extract_workflow")
  163. # 添加 v5 词库说明
  164. method_vocab = load_method_vocab()
  165. vocab_block = render_method_vocab_block(method_vocab)
  166. if "%context%" in prompt_template:
  167. prompt = prompt_template.replace("%context%", context)
  168. else:
  169. prompt = prompt_template + f"\n\n## 帖子内容\n{context}"
  170. # 如果 prompt 中有 {interface_vocab} 占位符,替换为词库说明
  171. if "{interface_vocab}" in prompt:
  172. prompt = prompt.replace("{interface_vocab}", vocab_block)
  173. elif vocab_block not in prompt:
  174. # 如果 prompt 中没有词库说明,添加到末尾
  175. prompt = prompt + "\n" + vocab_block
  176. except Exception as e:
  177. print(f"Warning: Failed to load prompt template: {e}, using fallback")
  178. method_vocab = load_method_vocab()
  179. vocab_block = render_method_vocab_block(method_vocab)
  180. prompt = f"""将以下帖子内容总结为AI图片生成的工序,以JSON格式输出。
  181. # 工序提取规则(v5)
  182. - 步骤粒度是"做了什么",而非"怎么做"
  183. - 以"触发生成 / 处理的动作"为步骤边界
  184. - 若本质上只有一步,也输出一步,不要返回 strategy=null
  185. - 本阶段严禁生成 apply_to,只生成 apply_to_draft
  186. # 输出格式(v5)
  187. {{
  188. "skip": false,
  189. "skip_reason": "",
  190. "strategy": {{
  191. "steps": [
  192. {{
  193. "order": 1,
  194. "action": {{"main_action": "生成", "mechanism": "直接生成"}},
  195. "body": "string | null",
  196. "inputs": [
  197. {{
  198. "role": "生成指令",
  199. "modality": "文本",
  200. "artifact_type": "正向提示词",
  201. "control_target": ["主体", "场景"],
  202. "target_scope": ["整图"],
  203. "constraint_strength": "硬约束",
  204. "source": "原帖文本",
  205. "lifecycle": "原始输入",
  206. "description": "用于触发图片生成的完整提示词"
  207. }}
  208. ],
  209. "outputs": [...],
  210. "tools": []
  211. }}
  212. ],
  213. "effects": ["实现 XX 效果"],
  214. "criterion": null,
  215. "apply_to_draft": {{"实质": ["相关 what"], "形式": ["相关呈现方式"]}},
  216. "unstructured_what": []
  217. }}
  218. }}
  219. {vocab_block}
  220. ## 帖子内容
  221. {context}
  222. 请严格按照上述格式输出JSON,不要包含其他内容。"""
  223. if images:
  224. image_urls = [img for img in images[:9] if isinstance(img, str) and img.startswith("http")]
  225. if image_urls:
  226. content_array = [{"type": "text", "text": prompt}]
  227. for url in image_urls:
  228. content_array.append({"type": "image_url", "image_url": {"url": url}})
  229. messages = [{"role": "user", "content": content_array}]
  230. else:
  231. messages = [{"role": "user", "content": prompt}]
  232. else:
  233. messages = [{"role": "user", "content": prompt}]
  234. result_data, cost = await call_llm_with_retry(
  235. llm_call=llm_call,
  236. messages=messages,
  237. model=model,
  238. temperature=0.1,
  239. max_tokens=8000, # 从2000增加到4000,处理更长的输出
  240. max_retries=3, # 从3增加到5,增加重试机会
  241. schema_name="extract_workflow",
  242. task_name=f"Workflow_{title}",
  243. )
  244. # Stage 1 格式:{"skip": bool, "skip_reason": str, "strategy": {...}}
  245. # 如果 skip=true 或 strategy=null,返回 None
  246. if not result_data:
  247. return None, cost
  248. if result_data.get("skip"):
  249. return None, cost
  250. workflow_data = result_data.get("strategy")
  251. # 从 steps 自动推导顶层字段(v5版本)
  252. if workflow_data and isinstance(workflow_data, dict):
  253. derive_strategy_rollup(workflow_data)
  254. return workflow_data, cost
  255. async def extract_workflow(
  256. case_file: Path,
  257. llm_call: Any,
  258. model: str = "anthropic/claude-sonnet-4-5",
  259. max_concurrent: int = 3,
  260. case_indices: Optional[List[int]] = None
  261. ) -> Dict[str, Any]:
  262. """
  263. 按 index 遍历 case.json,提取 workflow
  264. Args:
  265. case_file: case.json 文件路径
  266. llm_call: LLM 调用函数
  267. model: 使用的模型
  268. max_concurrent: 最大并发数
  269. case_indices: 可选,指定要处理的 case index 列表。如果为 None,处理所有 case
  270. """
  271. with open(case_file, "r", encoding="utf-8") as f:
  272. case_data = json.load(f)
  273. cases = case_data.get("cases", [])
  274. # 如果指定了 case_indices,只处理这些 case
  275. if case_indices is not None:
  276. cases_to_process = [c for c in cases if c.get("index") in case_indices]
  277. print(f"Extracting workflow from {len(cases_to_process)} cases (filtered by indices: {case_indices})...")
  278. else:
  279. cases_to_process = cases
  280. print(f"Extracting workflow from {len(cases)} cases...")
  281. semaphore = asyncio.Semaphore(max_concurrent)
  282. async def process_with_semaphore(case_item):
  283. async with semaphore:
  284. index = case_item.get("index", 0)
  285. raw = case_item.get("_raw", {})
  286. case_id = raw.get("case_id", "unknown")
  287. title = case_item.get("title", "")
  288. print(f" -> [{index}] [{case_id}] extracting workflow: {title[:60]}")
  289. workflow, cost = await extract_workflow_from_case(case_item, llm_call, model)
  290. status = "ok" if workflow else "null"
  291. print(f" <- [{index}] [{case_id}] workflow {status}")
  292. result = dict(case_item)
  293. result["workflow"] = workflow
  294. return result, cost
  295. tasks = [process_with_semaphore(case) for case in cases_to_process]
  296. results_with_costs = await asyncio.gather(*tasks)
  297. results = [r[0] for r in results_with_costs]
  298. costs = [r[1] for r in results_with_costs]
  299. total_cost = sum(costs)
  300. success_count = sum(1 for r in results if r.get("workflow"))
  301. failed_count = len(results) - success_count
  302. # 如果是部分更新,需要合并回原始 cases 列表
  303. if case_indices is not None:
  304. # 创建一个 index -> result 的映射
  305. result_map = {r.get("index"): r for r in results}
  306. # 更新原始 cases 列表中对应的项
  307. for i, case in enumerate(cases):
  308. if case.get("index") in result_map:
  309. cases[i] = result_map[case.get("index")]
  310. results = cases
  311. results.sort(key=lambda x: x.get("index", 0))
  312. case_data["cases"] = results
  313. case_file.parent.mkdir(parents=True, exist_ok=True)
  314. with open(case_file, "w", encoding="utf-8") as f:
  315. json.dump(case_data, f, ensure_ascii=False, indent=2)
  316. return {
  317. "total": len(results),
  318. "success": success_count,
  319. "failed": failed_count,
  320. "total_cost": total_cost,
  321. "output_file": str(case_file),
  322. }
  323. if __name__ == "__main__":
  324. import sys
  325. if len(sys.argv) < 2:
  326. print("Usage: python extract_workflow.py <output_dir>")
  327. sys.exit(1)
  328. output_dir = Path(sys.argv[1])
  329. case_file = output_dir / "case.json"
  330. if not case_file.exists():
  331. print(f"Error: {case_file} not found")
  332. sys.exit(1)
  333. print("Please use this module through run_pipeline.py")