extract_capability.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. """
  2. 逐 case 提取 capabilities (v5版本)
  3. 从 case.json 读取,按 index 遍历每个 case,
  4. 调用 LLM 提取 capabilities,按 index 原位回填到 case.json
  5. v5 架构特性:
  6. - 使用结构化 inputs/outputs(role, modality, artifact_type 等10个维度)
  7. - action 对象化:{main_action, mechanism}
  8. - Stage 1 输出 apply_to_draft(自然语言),为 Stage 2 内容树映射做准备
  9. """
  10. import asyncio
  11. import json
  12. from pathlib import Path
  13. from typing import Any, Dict, Optional, List
  14. from examples.process_pipeline.script.llm_helper import call_llm_with_retry
# Path of the v5 vocabulary file (lives next to this script under resource/).
SCRIPT_DIR = Path(__file__).resolve().parent
METHOD_VOCAB_PATH = SCRIPT_DIR / "resource" / "method_vocab_v5.json"

# Built-in default vocabulary, used as a fallback when the JSON file is absent.
# Keys are vocabulary categories (process role, modality, main action, mechanism);
# values are the allowed Chinese terms fed verbatim into the LLM prompt —
# do not translate or normalize them, the model is expected to echo them back.
DEFAULT_METHOD_VOCAB = {
    "流程角色": [
        "生成指令", "编辑指令", "约束条件", "参考素材", "控制信号",
        "区域控制", "参数配置", "模型资源", "源素材", "中间产物",
        "成品", "模板", "评估结果"
    ],
    "模态": ["文本", "图片", "视频", "音频", "特征点", "参数", "模型", "向量", "表格"],
    "主动作": [
        "生成", "编辑", "提取", "改写", "合成", "修复", "增强",
        "训练", "评估", "剪辑", "模板化", "排版", "转写", "配音",
        "匹配", "扩展", "导出"
    ],
    "动作方式": [
        "直接生成", "一致性保持", "结构约束", "质量收束", "局部重绘",
        "扩图", "换背景", "提示词反推", "模板化", "多图融合", "清晰化",
        "风格迁移", "常规编辑", "变体生成", "动画化", "镜头延展",
        "换主体", "换装", "擦除", "调色", "前后景融合", "图文合成",
        "音画合成", "分层叠加", "特征提取", "蒙版提取", "关键帧提取",
        "字幕提取", "风格提取", "片段拼接", "节奏压缩", "转场编排",
        "字幕对齐", "音画同步", "降噪", "补帧", "超分", "稳定化",
        "质感增强", "结构抽象", "变量抽象", "版式套用", "格式转换", "压缩导出"
    ],
}
  42. def load_method_vocab() -> Dict[str, List[str]]:
  43. """从 JSON 文件加载结构化词库(v5)"""
  44. if METHOD_VOCAB_PATH.exists():
  45. try:
  46. with open(METHOD_VOCAB_PATH, "r", encoding="utf-8") as f:
  47. return json.load(f)
  48. except Exception as e:
  49. print(f"Warning: Failed to load method_vocab.json: {e}, using default")
  50. return DEFAULT_METHOD_VOCAB
  51. def load_prompt_template(prompt_name: str) -> str:
  52. base_dir = Path(__file__).parent.parent
  53. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  54. with open(prompt_path, "r", encoding="utf-8") as f:
  55. content = f.read()
  56. if content.startswith("---"):
  57. parts = content.split("---", 2)
  58. if len(parts) >= 3:
  59. content = parts[2]
  60. content = content.replace("$system$", "").replace("$user$", "")
  61. return content.strip()
  62. def render_method_vocab_block(vocab: Dict[str, List[str]]) -> str:
  63. """渲染结构化接口词库说明(v5)"""
  64. lines = [
  65. "\n# 结构化接口词库(v5,必须遵守)",
  66. "只输出结构化 inputs / outputs / action。",
  67. "- `role/流程角色` 只写接口职责,不写具体内容 what。",
  68. "- `modality/模态` 只写媒介或数据形态;统一用 `图片`,不要写 `图像`;统一用 `文本`,不要写 `文字`。",
  69. "- `artifact_type/工件类型` 写该模态下的具体工件,如 `正向提示词`、`蒙版`。",
  70. "- `action.main_action` 写主动作;`action.mechanism` 写动作内部机制。",
  71. "- 只有词库确实不够时才新增术语;新增术语也必须抽象、短、可复用。",
  72. "",
  73. "当前词库:",
  74. ]
  75. for key, values in vocab.items():
  76. lines.append(f"- {key}:{'、'.join(values)}")
  77. return "\n".join(lines)
async def extract_capabilities_from_case_item(
    case_item: Dict[str, Any],
    llm_call: Any,
    model: str = "anthropic/claude-sonnet-4-5"
) -> tuple[Optional[List[Dict[str, Any]]], float]:
    """
    Extract capabilities from a single case item (v5).

    v5 specifics:
    - structured inputs/outputs (role, modality, artifact_type, ...)
    - action as an object: {main_action, mechanism}
    - emits apply_to_draft (natural language) to prepare Stage 2 content-tree mapping

    Returns:
        (capabilities, cost): the list of extracted capability dicts, or None
        when the case is empty / the LLM failed / the model chose to skip;
        cost is the accumulated LLM spend in dollars (0.0 if no call was made).
    """
    images = case_item.get("images", [])
    # Work on a copy with bulky / derived fields removed so the serialized
    # context stays small and never echoes previous extraction results.
    case_copy = dict(case_item)
    case_copy.pop("images", None)
    case_copy.pop("_raw", None)
    case_copy.pop("workflow", None)
    case_copy.pop("capabilities", None)
    if not case_copy and not images:
        # Nothing to extract from.
        return None, 0.0
    title = case_item.get("title", "")[:20] or "untitled"
    context = json.dumps(case_copy, ensure_ascii=False, indent=2)
    try:
        prompt_template = load_prompt_template("extract_capability")
        # Append the v5 vocabulary instructions.
        method_vocab = load_method_vocab()
        vocab_block = render_method_vocab_block(method_vocab)
        if "%context%" in prompt_template:
            prompt = prompt_template.replace("%context%", context)
        else:
            prompt = prompt_template + f"\n\n案例数据:\n```json\n{context}\n```"
        # If the prompt has an {interface_vocab} placeholder, substitute the vocab there.
        if "{interface_vocab}" in prompt:
            prompt = prompt.replace("{interface_vocab}", vocab_block)
        elif vocab_block not in prompt:
            # Otherwise, if the vocab text isn't already present, append it.
            prompt = prompt + "\n" + vocab_block
    except Exception as e:
        # Template missing/unreadable: fall back to an inline prompt so the
        # extraction still runs.
        print(f"Warning: Failed to load prompt template: {e}, using fallback")
        method_vocab = load_method_vocab()
        vocab_block = render_method_vocab_block(method_vocab)
        prompt = f"""请从以下案例中提取该案例包含的原子能力,以 JSON 格式输出。
# 输出格式(v5)
{{
"skip": false,
"skip_reason": "",
"capabilities": [
{{
"inputs": [
{{
"role": "生成指令",
"modality": "文本",
"artifact_type": "正向提示词",
"control_target": ["主体", "场景"],
"target_scope": ["整图"],
"constraint_strength": "硬约束",
"source": "原帖文本",
"lifecycle": "原始输入",
"description": "用于触发图片生成的完整提示词"
}}
],
"outputs": [...],
"action": {{"main_action": "生成", "mechanism": "直接生成"}},
"body": "具体做法",
"effects": ["实现 XX 效果"],
"stage": ["generate"],
"tools": [],
"criterion": null,
"apply_to_draft": {{"实质": ["相关 what"], "形式": ["相关呈现方式"]}},
"unstructured_what": []
}}
]
}}
{vocab_block}
案例数据:
{context}
请严格按照上述格式输出JSON,不要包含其他内容。"""
    if images:
        # Only forward real http(s) URLs, capped at 9 images per request.
        image_urls = [img for img in images[:9] if isinstance(img, str) and img.startswith("http")]
        if image_urls:
            content_array = [{"type": "text", "text": prompt}]
            for url in image_urls:
                content_array.append({"type": "image_url", "image_url": {"url": url}})
            messages = [{"role": "user", "content": content_array}]
        else:
            messages = [{"role": "user", "content": prompt}]
    else:
        messages = [{"role": "user", "content": prompt}]
    result_data, cost = await call_llm_with_retry(
        llm_call=llm_call,
        messages=messages,
        model=model,
        temperature=0.1,
        max_tokens=8000,
        max_retries=3,
        schema_name="extract_capability",
        task_name=f"Capability_{title}",
    )
    # Stage 1 response shape: {"skip": bool, "skip_reason": str, "capabilities": [...]}
    # A skip=true response is treated the same as no result.
    if not result_data:
        return None, cost
    if result_data.get("skip"):
        return None, cost
    capabilities_data = result_data.get("capabilities", [])
    return capabilities_data, cost
  184. async def extract_capability(
  185. case_file: Path,
  186. llm_call: Any,
  187. model: str = "anthropic/claude-sonnet-4-5",
  188. max_concurrent: int = 3
  189. ) -> Dict[str, Any]:
  190. """
  191. 按 index 遍历 case.json,提取 capabilities
  192. """
  193. with open(case_file, "r", encoding="utf-8") as f:
  194. case_data = json.load(f)
  195. cases = case_data.get("cases", [])
  196. print(f"Extracting capabilities from {len(cases)} cases...")
  197. semaphore = asyncio.Semaphore(max_concurrent)
  198. async def process_with_semaphore(case_item):
  199. async with semaphore:
  200. index = case_item.get("index", 0)
  201. raw = case_item.get("_raw", {})
  202. case_id = raw.get("case_id", "unknown")
  203. title = case_item.get("title", "")
  204. print(f" -> [{index}] [{case_id}] extracting capabilities: {title[:60]}")
  205. capabilities_data, cost = await extract_capabilities_from_case_item(case_item, llm_call, model)
  206. status = "ok" if capabilities_data else "null"
  207. count = len(capabilities_data) if capabilities_data else 0
  208. print(f" <- [{index}] [{case_id}] capabilities {status} (count={count})")
  209. result = dict(case_item)
  210. result["capabilities"] = capabilities_data
  211. return result, cost
  212. tasks = [process_with_semaphore(case) for case in cases]
  213. results_with_costs = await asyncio.gather(*tasks)
  214. results = [r[0] for r in results_with_costs]
  215. costs = [r[1] for r in results_with_costs]
  216. total_cost = sum(costs)
  217. success_count = sum(1 for r in results if r.get("capabilities"))
  218. failed_count = len(results) - success_count
  219. results.sort(key=lambda x: x.get("index", 0))
  220. case_data["cases"] = results
  221. case_file.parent.mkdir(parents=True, exist_ok=True)
  222. with open(case_file, "w", encoding="utf-8") as f:
  223. json.dump(case_data, f, ensure_ascii=False, indent=2)
  224. return {
  225. "total": len(results),
  226. "success": success_count,
  227. "failed": failed_count,
  228. "total_cost": total_cost,
  229. "output_file": str(case_file),
  230. }
  231. if __name__ == "__main__":
  232. import sys
  233. if len(sys.argv) < 2:
  234. print("Usage: python extract_capability.py <output_dir>")
  235. sys.exit(1)
  236. output_dir = Path(sys.argv[1])
  237. case_file = output_dir / "case.json"
  238. if not case_file.exists():
  239. print(f"Error: {case_file} not found")
  240. sys.exit(1)
  241. print("Please use this module through run_pipeline.py")