# assemble_strategy_workflow.py (7.6 KB)
"""
Phase 3: strategy assembly workflow.

Reads data from process.json and capabilities.json, then calls an LLM to
generate the final strategy.json.
"""
import asyncio
import json
from pathlib import Path
from typing import Any, Dict, Optional, Set

from examples.process_pipeline.script.llm_helper import call_llm_with_retry
  11. def load_prompt_template(prompt_name: str) -> str:
  12. """从 prompts 目录加载 prompt 模板"""
  13. base_dir = Path(__file__).parent.parent
  14. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  15. with open(prompt_path, "r", encoding="utf-8") as f:
  16. content = f.read()
  17. if content.startswith("---"):
  18. parts = content.split("---", 2)
  19. if len(parts) >= 3:
  20. content = parts[2]
  21. content = content.replace("$system$", "").replace("$user$", "")
  22. return content.strip()
  23. def validate_strategy_references(
  24. strategy_data: Dict[str, Any],
  25. valid_cluster_ids: Set[str],
  26. valid_ability_ids: Set[str],
  27. cluster_step_counts: Dict[str, int],
  28. ) -> str:
  29. """
  30. 验证 LLM 输出的匹配关系中所有引用 ID 是否真实存在
  31. """
  32. matches = strategy_data.get("matches", [])
  33. if not isinstance(matches, list):
  34. return "'matches' is not a list"
  35. for i, m in enumerate(matches):
  36. cid = m.get("cluster_id")
  37. if cid not in valid_cluster_ids:
  38. return f"matches[{i}].cluster_id '{cid}' not found in process.json"
  39. step_idx = m.get("step_index")
  40. max_steps = cluster_step_counts.get(cid, 0)
  41. if isinstance(step_idx, int) and (step_idx < 0 or step_idx >= max_steps):
  42. return f"matches[{i}].step_index {step_idx} out of range (cluster '{cid}' has {max_steps} steps)"
  43. caps = m.get("capabilities", [])
  44. for k, cap_id in enumerate(caps):
  45. if not isinstance(cap_id, str):
  46. return f"matches[{i}].capabilities[{k}] must be a string id"
  47. if cap_id not in valid_ability_ids:
  48. return f"matches[{i}].capabilities[{k}]: '{cap_id}' not found in capabilities.json"
  49. return None
  50. async def assemble_strategy(
  51. process_file: Path,
  52. capabilities_file: Path,
  53. output_file: Path,
  54. requirement: str,
  55. llm_call,
  56. model: str = "anthropic/claude-sonnet-4-6",
  57. ) -> Dict[str, Any]:
  58. """
  59. 组装最终策略
  60. Returns:
  61. 统计信息
  62. """
  63. with open(process_file, "r", encoding="utf-8") as f:
  64. process_data = json.load(f)
  65. with open(capabilities_file, "r", encoding="utf-8") as f:
  66. capabilities_data = json.load(f)
  67. # 提取有效的 ID 集合
  68. clusters = process_data.get("clusters", [])
  69. capabilities = capabilities_data.get("capabilities", [])
  70. valid_cluster_ids = {cl.get("cluster_id") for cl in clusters if cl.get("cluster_id")}
  71. valid_ability_ids = {ab.get("id") for ab in capabilities if ab.get("id")}
  72. cluster_step_counts = {cl.get("cluster_id"): len(cl.get("工序步骤", [])) for cl in clusters}
  73. # 构造 prompt
  74. process_text = json.dumps(process_data, ensure_ascii=False, indent=2)
  75. capabilities_text = json.dumps(capabilities_data, ensure_ascii=False, indent=2)
  76. try:
  77. prompt_template = load_prompt_template("assemble_strategy")
  78. prompt = prompt_template.replace("%requirement%", requirement)
  79. prompt = prompt.replace("%process_data%", process_text)
  80. prompt = prompt.replace("%capabilities_data%", capabilities_text)
  81. except Exception:
  82. prompt = f"""为以下工序的每个步骤匹配对应的能力 ID。
  83. 需求:{requirement}
  84. 工序聚类:
  85. {process_text}
  86. 能力列表:
  87. {capabilities_text}
  88. 直接输出 JSON:
  89. {{"matches": [{{"cluster_id": "A", "step_index": 0, "capabilities": ["AB-01"]}}]}}"""
  90. messages = [{"role": "user", "content": prompt}]
  91. def _validate_with_refs(parsed):
  92. from examples.process_pipeline.script.validate_schema import validate_strategy
  93. err = validate_strategy(parsed)
  94. if err:
  95. return err
  96. return validate_strategy_references(parsed, valid_cluster_ids, valid_ability_ids, cluster_step_counts)
  97. match_data, total_cost = await call_llm_with_retry(
  98. llm_call=llm_call,
  99. messages=messages,
  100. model=model,
  101. temperature=0.1,
  102. max_tokens=4000,
  103. max_retries=3,
  104. validate_fn=_validate_with_refs,
  105. task_name="P3_AssembleStrategy",
  106. )
  107. if match_data is None:
  108. match_data = {"matches": []}
  109. # ── 代码组装完整 strategy.json ──
  110. cluster_map = {cl.get("cluster_id"): cl for cl in clusters}
  111. capability_map = {ab.get("id"): ab for ab in capabilities}
  112. # 按 cluster_id 分组 matches
  113. from collections import defaultdict
  114. cluster_matches = defaultdict(dict) # {cluster_id: {step_index: [cap_ids]}}
  115. for m in match_data.get("matches", []):
  116. cid = m.get("cluster_id")
  117. sidx = m.get("step_index", 0)
  118. cluster_matches[cid][sidx] = m.get("capabilities", [])
  119. full_workflow = []
  120. for cl in clusters:
  121. cid = cl.get("cluster_id")
  122. steps = cl.get("工序步骤", [])
  123. step_matches = cluster_matches.get(cid, {})
  124. assembled_steps = []
  125. for idx, step in enumerate(steps):
  126. cap_ids = step_matches.get(idx, [])
  127. matched_caps = []
  128. for cap_id in cap_ids:
  129. cap_info = capability_map.get(cap_id)
  130. if cap_info:
  131. matched_caps.append({
  132. "id": cap_id,
  133. "name": cap_info.get("name", ""),
  134. "description": cap_info.get("description", ""),
  135. "case_references": cap_info.get("case_references", []),
  136. "enriched_details": cap_info.get("enriched_details"),
  137. })
  138. assembled_steps.append({
  139. "步骤序号": step.get("步骤序号"),
  140. "步骤描述": step.get("步骤描述"),
  141. "capabilities": matched_caps,
  142. })
  143. full_workflow.append({
  144. "cluster_id": cid,
  145. "cluster_name": cl.get("cluster_name", ""),
  146. "关联案例": cl.get("关联案例", []),
  147. "score": cl.get("score"),
  148. "explanation": cl.get("explanation"),
  149. "steps": assembled_steps,
  150. })
  151. output_data = {
  152. "requirement": requirement,
  153. "workflow": full_workflow,
  154. }
  155. output_file.parent.mkdir(parents=True, exist_ok=True)
  156. with open(output_file, "w", encoding="utf-8") as f:
  157. json.dump(output_data, f, ensure_ascii=False, indent=2)
  158. return {
  159. "workflow_steps": sum(len(cl.get("steps", [])) for cl in full_workflow),
  160. "total_cost": total_cost,
  161. }
  162. async def main():
  163. """命令行入口"""
  164. import argparse
  165. import sys
  166. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  167. from agent.llm.openrouter import OpenRouterLLM
  168. parser = argparse.ArgumentParser()
  169. parser.add_argument("--process-file", required=True)
  170. parser.add_argument("--capabilities-file", required=True)
  171. parser.add_argument("--output-file", required=True)
  172. parser.add_argument("--requirement", required=True)
  173. parser.add_argument("--model", default="anthropic/claude-sonnet-4-6")
  174. args = parser.parse_args()
  175. llm = OpenRouterLLM()
  176. result = await assemble_strategy(
  177. process_file=Path(args.process_file),
  178. capabilities_file=Path(args.capabilities_file),
  179. output_file=Path(args.output_file),
  180. requirement=args.requirement,
  181. llm_call=llm.chat,
  182. model=args.model,
  183. )
  184. print(f"✓ Generated {result['strategies']} strategies")
  185. if __name__ == "__main__":
  186. asyncio.run(main())