""" Phase 3: 策略组装 workflow 从 process.json 和 capabilities.json 读取数据, 调用 LLM 生成最终的 strategy.json """ import asyncio import json from pathlib import Path from typing import Any, Dict, Set from examples.process_pipeline.script.llm_helper import call_llm_with_retry def load_prompt_template(prompt_name: str) -> str: """从 prompts 目录加载 prompt 模板""" base_dir = Path(__file__).parent.parent prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt" with open(prompt_path, "r", encoding="utf-8") as f: content = f.read() if content.startswith("---"): parts = content.split("---", 2) if len(parts) >= 3: content = parts[2] content = content.replace("$system$", "").replace("$user$", "") return content.strip() def validate_strategy_references( strategy_data: Dict[str, Any], valid_cluster_ids: Set[str], valid_ability_ids: Set[str], cluster_step_counts: Dict[str, int], ) -> str: """ 验证 LLM 输出的匹配关系中所有引用 ID 是否真实存在 """ matches = strategy_data.get("matches", []) if not isinstance(matches, list): return "'matches' is not a list" for i, m in enumerate(matches): cid = m.get("cluster_id") if cid not in valid_cluster_ids: return f"matches[{i}].cluster_id '{cid}' not found in process.json" step_idx = m.get("step_index") max_steps = cluster_step_counts.get(cid, 0) if isinstance(step_idx, int) and (step_idx < 0 or step_idx >= max_steps): return f"matches[{i}].step_index {step_idx} out of range (cluster '{cid}' has {max_steps} steps)" caps = m.get("capabilities", []) for k, cap_id in enumerate(caps): if not isinstance(cap_id, str): return f"matches[{i}].capabilities[{k}] must be a string id" if cap_id not in valid_ability_ids: return f"matches[{i}].capabilities[{k}]: '{cap_id}' not found in capabilities.json" return None async def assemble_strategy( process_file: Path, capabilities_file: Path, output_file: Path, requirement: str, llm_call, model: str = "anthropic/claude-sonnet-4-6", ) -> Dict[str, Any]: """ 组装最终策略 Returns: 统计信息 """ with open(process_file, "r", encoding="utf-8") as f: process_data = json.load(f) with open(capabilities_file, "r", encoding="utf-8") as f: capabilities_data = json.load(f) # 提取有效的 ID 集合 clusters = process_data.get("clusters", []) capabilities = capabilities_data.get("capabilities", []) valid_cluster_ids = {cl.get("cluster_id") for cl in clusters if cl.get("cluster_id")} valid_ability_ids = {ab.get("id") for ab in capabilities if ab.get("id")} cluster_step_counts = {cl.get("cluster_id"): len(cl.get("工序步骤", [])) for cl in clusters} # 构造 prompt process_text = json.dumps(process_data, ensure_ascii=False, indent=2) capabilities_text = json.dumps(capabilities_data, ensure_ascii=False, indent=2) try: prompt_template = load_prompt_template("assemble_strategy") prompt = prompt_template.replace("%requirement%", requirement) prompt = prompt.replace("%process_data%", process_text) prompt = prompt.replace("%capabilities_data%", capabilities_text) except Exception: prompt = f"""为以下工序的每个步骤匹配对应的能力 ID。 需求:{requirement} 工序聚类: {process_text} 能力列表: {capabilities_text} 直接输出 JSON: {{"matches": [{{"cluster_id": "A", "step_index": 0, "capabilities": ["AB-01"]}}]}}""" messages = [{"role": "user", "content": prompt}] def _validate_with_refs(parsed): from examples.process_pipeline.script.validate_schema import validate_strategy err = validate_strategy(parsed) if err: return err return validate_strategy_references(parsed, valid_cluster_ids, valid_ability_ids, cluster_step_counts) match_data, total_cost = await call_llm_with_retry( llm_call=llm_call, messages=messages, model=model, temperature=0.1, max_tokens=4000, max_retries=3, validate_fn=_validate_with_refs, task_name="P3_AssembleStrategy", ) if match_data is None: match_data = {"matches": []} # ── 代码组装完整 strategy.json ── cluster_map = {cl.get("cluster_id"): cl for cl in clusters} capability_map = {ab.get("id"): ab for ab in capabilities} # 按 cluster_id 分组 matches from collections import defaultdict cluster_matches = defaultdict(dict) # {cluster_id: {step_index: [cap_ids]}} for m in match_data.get("matches", []): cid = m.get("cluster_id") sidx = m.get("step_index", 0) cluster_matches[cid][sidx] = m.get("capabilities", []) full_workflow = [] for cl in clusters: cid = cl.get("cluster_id") steps = cl.get("工序步骤", []) step_matches = cluster_matches.get(cid, {}) assembled_steps = [] for idx, step in enumerate(steps): cap_ids = step_matches.get(idx, []) matched_caps = [] for cap_id in cap_ids: cap_info = capability_map.get(cap_id) if cap_info: matched_caps.append({ "id": cap_id, "name": cap_info.get("name", ""), "description": cap_info.get("description", ""), "case_references": cap_info.get("case_references", []), "enriched_details": cap_info.get("enriched_details"), }) assembled_steps.append({ "步骤序号": step.get("步骤序号"), "步骤描述": step.get("步骤描述"), "capabilities": matched_caps, }) full_workflow.append({ "cluster_id": cid, "cluster_name": cl.get("cluster_name", ""), "关联案例": cl.get("关联案例", []), "score": cl.get("score"), "explanation": cl.get("explanation"), "steps": assembled_steps, }) output_data = { "requirement": requirement, "workflow": full_workflow, } output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, "w", encoding="utf-8") as f: json.dump(output_data, f, ensure_ascii=False, indent=2) return { "workflow_steps": sum(len(cl.get("steps", [])) for cl in full_workflow), "total_cost": total_cost, } async def main(): """命令行入口""" import argparse import sys sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from agent.llm.openrouter import OpenRouterLLM parser = argparse.ArgumentParser() parser.add_argument("--process-file", required=True) parser.add_argument("--capabilities-file", required=True) parser.add_argument("--output-file", required=True) parser.add_argument("--requirement", required=True) parser.add_argument("--model", default="anthropic/claude-sonnet-4-6") args = parser.parse_args() llm = OpenRouterLLM() result = await assemble_strategy( process_file=Path(args.process_file), capabilities_file=Path(args.capabilities_file), output_file=Path(args.output_file), requirement=args.requirement, llm_call=llm.chat, model=args.model, ) print(f"✓ Generated {result['strategies']} strategies") if __name__ == "__main__": asyncio.run(main())