enrich_capabilities.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. """
  2. Phase 2.2.2: 能力丰富化
  3. 从 capabilities_temp.json 读取初步聚类的能力,
  4. 对每个能力,根据 case_references 从 source.json 提取原始帖子信息(包括图片),
  5. 调用 LLM 进行丰富化,输出到 capabilities.json
  6. """
  7. import asyncio
  8. import json
  9. import re
  10. from pathlib import Path
  11. from typing import Any, Dict, List
  12. from examples.process_pipeline.script.llm_helper import call_llm_with_retry
  13. from examples.process_pipeline.script.validate_schema import validate_capabilities_enriched
  14. def load_prompt_template(prompt_name: str) -> str:
  15. """从 prompts 目录加载 prompt 模板"""
  16. base_dir = Path(__file__).parent.parent
  17. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  18. with open(prompt_path, "r", encoding="utf-8") as f:
  19. content = f.read()
  20. if content.startswith("---"):
  21. parts = content.split("---", 2)
  22. if len(parts) >= 3:
  23. content = parts[2]
  24. content = content.replace("$system$", "").replace("$user$", "")
  25. return content.strip()
  26. async def enrich_single_capability(
  27. capability: Dict[str, Any],
  28. source_data: Dict[str, Any],
  29. llm_call,
  30. model: str
  31. ) -> Dict[str, Any]:
  32. """
  33. 丰富化单个能力
  34. Args:
  35. capability: 能力信息(包含 case_references)
  36. source_data: source.json 的完整数据
  37. llm_call: LLM 调用函数
  38. model: 模型名称
  39. Returns:
  40. 丰富化后的能力
  41. """
  42. case_refs = capability.get("case_references", [])
  43. if not case_refs:
  44. return capability
  45. # 从 source.json 中提取对应的帖子
  46. posts_content = []
  47. for ref in case_refs:
  48. # ref 格式可能是 "bili_BV1xxx" 或 "bili_BV1xxx 中的垫图操作"
  49. case_id = ref.split()[0] if " " in ref else ref
  50. # 在 source.json 中查找
  51. for src in source_data.get("sources", []):
  52. src_case_id = f"{src['platform']}_{src['channel_content_id']}"
  53. if src_case_id == case_id:
  54. post = src.get("post", {})
  55. post_info = {
  56. "case_id": case_id,
  57. "title": post.get("title", ""),
  58. "body_text": post.get("body_text", ""),
  59. "images": []
  60. }
  61. # 提取图片 URL
  62. images = post.get("images", [])
  63. if isinstance(images, list):
  64. for img in images[:5]:
  65. if isinstance(img, str):
  66. post_info["images"].append(img)
  67. elif isinstance(img, dict) and "url" in img:
  68. post_info["images"].append(img["url"])
  69. # 也尝试从 image_url_list 提取
  70. image_url_list = post.get("image_url_list", [])
  71. if isinstance(image_url_list, list):
  72. for img_obj in image_url_list[:5]:
  73. if isinstance(img_obj, dict) and "image_url" in img_obj:
  74. post_info["images"].append(img_obj["image_url"])
  75. posts_content.append(post_info)
  76. break
  77. if not posts_content:
  78. return capability
  79. # 构造 posts_content 字符串
  80. posts_text = ""
  81. for i, post in enumerate(posts_content, 1):
  82. posts_text += f"\n### 帖子 {i}({post['case_id']})\n"
  83. posts_text += f"**标题**:{post['title']}\n\n"
  84. posts_text += f"**正文**:\n{post['body_text'][:1000]}\n\n"
  85. if post['images']:
  86. posts_text += f"**图片**:{len(post['images'])} 张\n"
  87. for img_url in post['images']:
  88. posts_text += f"- {img_url}\n"
  89. posts_text += "\n"
  90. # 构造 prompt
  91. try:
  92. prompt_template = load_prompt_template("enrich_capability")
  93. prompt = prompt_template.replace("%capability_name%", capability.get("name", ""))
  94. prompt = prompt.replace("%capability_description%", capability.get("description", ""))
  95. prompt = prompt.replace("%posts_content%", posts_text)
  96. except Exception:
  97. prompt = f"""从以下帖子中提取该能力的具体执行过程和核心参数。
  98. 能力名称:{capability.get("name", "")}
  99. 能力描述:{capability.get("description", "")}
  100. 相关帖子内容:
  101. {posts_text}
  102. 输出 JSON 格式:
  103. {{"execution_process": "...", "core_parameters": "...", "effects": "...", "visual_notes": "..."}}"""
  104. messages = [{"role": "user", "content": prompt}]
  105. def _validate_enrichment(parsed):
  106. from examples.process_pipeline.script.schema_manager import validate_with_schema
  107. return validate_with_schema(parsed, "enrich_capability")
  108. enriched_data, _ = await call_llm_with_retry(
  109. llm_call=llm_call,
  110. messages=messages,
  111. model=model,
  112. temperature=0.1,
  113. max_tokens=8000,
  114. max_retries=3,
  115. validate_fn=_validate_enrichment,
  116. task_name=f"Enrich_{capability.get('name', '')[:20]}",
  117. )
  118. if enriched_data:
  119. capability["enriched_details"] = enriched_data
  120. return capability
  121. async def enrich_all_capabilities(
  122. capabilities_temp_file: Path,
  123. source_file: Path,
  124. output_file: Path,
  125. llm_call,
  126. model: str = "anthropic/claude-sonnet-4-6",
  127. ) -> Dict[str, Any]:
  128. """
  129. 丰富化所有能力
  130. Returns:
  131. 统计信息
  132. """
  133. with open(capabilities_temp_file, "r", encoding="utf-8") as f:
  134. capabilities_data = json.load(f)
  135. with open(source_file, "r", encoding="utf-8") as f:
  136. source_data = json.load(f)
  137. capabilities = capabilities_data.get("abilities", [])
  138. enriched_capabilities = []
  139. total_cost = 0.0
  140. failed_caps = []
  141. print(f"Starting enrichment for {len(capabilities)} capabilities...", flush=True)
  142. for i, cap in enumerate(capabilities, 1):
  143. # 转换字段名:ability_name -> name, ability_description -> description
  144. # 保留 ability_id
  145. normalized_cap = {
  146. "id": cap.get("ability_id", ""),
  147. "name": cap.get("ability_name", ""),
  148. "description": cap.get("ability_description", ""),
  149. "case_references": cap.get("关联案例", []),
  150. }
  151. cap_name = normalized_cap.get("name", "unknown")
  152. print(f" [{i}/{len(capabilities)}] Enriching: {cap_name}", flush=True)
  153. enriched_cap = await enrich_single_capability(normalized_cap, source_data, llm_call, model)
  154. enriched_capabilities.append(enriched_cap)
  155. if "enriched_details" in enriched_cap:
  156. total_cost += 0.01
  157. print(f" [{i}/{len(capabilities)}] ✓ {cap_name}", flush=True)
  158. else:
  159. failed_caps.append(cap_name)
  160. print(f" [{i}/{len(capabilities)}] ⚠️ Failed: {cap_name}", flush=True)
  161. if failed_caps:
  162. print(f" ⚠️ {len(failed_caps)} capabilities failed enrichment: {failed_caps}")
  163. # 输出结果
  164. output_data = {
  165. "requirement": capabilities_data.get("requirement", ""),
  166. "capabilities": enriched_capabilities
  167. }
  168. schema_err = validate_capabilities_enriched(output_data)
  169. if schema_err:
  170. raise ValueError(f"Final capabilities.json schema invalid: {schema_err}")
  171. output_file.parent.mkdir(parents=True, exist_ok=True)
  172. with open(output_file, "w", encoding="utf-8") as f:
  173. json.dump(output_data, f, ensure_ascii=False, indent=2)
  174. return {
  175. "total_capabilities": len(capabilities),
  176. "enriched": len([c for c in enriched_capabilities if "enriched_details" in c]),
  177. "total_cost": total_cost,
  178. }
  179. async def main():
  180. """命令行入口"""
  181. import argparse
  182. import sys
  183. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  184. from agent.llm.openrouter import OpenRouterLLM
  185. parser = argparse.ArgumentParser()
  186. parser.add_argument("--capabilities-temp", required=True)
  187. parser.add_argument("--source-file", required=True)
  188. parser.add_argument("--output-file", required=True)
  189. parser.add_argument("--model", default="anthropic/claude-sonnet-4-6")
  190. args = parser.parse_args()
  191. llm = OpenRouterLLM()
  192. result = await enrich_all_capabilities(
  193. capabilities_temp_file=Path(args.capabilities_temp),
  194. source_file=Path(args.source_file),
  195. output_file=Path(args.output_file),
  196. llm_call=llm.chat,
  197. model=args.model,
  198. )
  199. print(f"✓ Enriched {result['enriched']}/{result['total_capabilities']} capabilities")
  200. if __name__ == "__main__":
  201. asyncio.run(main())