| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234 |
- """
- Phase 3: strategy assembly workflow.
- Reads process.json and capabilities.json, asks the LLM to match each
- process step to capability IDs, then assembles the final strategy.json in code.
- """
import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from examples.process_pipeline.script.llm_helper import call_llm_with_retry
def load_prompt_template(prompt_name: str) -> str:
    """Load a prompt template from the ``prompts`` directory.

    The file ``<prompt_name>.prompt`` is read from the ``prompts`` folder
    located two directory levels above this file. When the text starts with
    ``---`` and contains a closing ``---``, everything up to and including
    that closing delimiter is discarded. The ``$system$`` and ``$user$``
    markers are removed and the result is returned with surrounding
    whitespace stripped.

    Raises:
        OSError: if the template file cannot be opened.
    """
    prompt_file = Path(__file__).parent.parent / "prompts" / f"{prompt_name}.prompt"
    with open(prompt_file, "r", encoding="utf-8") as handle:
        text = handle.read()

    # Drop a leading "--- ... ---" header when both delimiters are present.
    if text.startswith("---"):
        pieces = text.split("---", 2)
        if len(pieces) >= 3:
            text = pieces[2]

    for marker in ("$system$", "$user$"):
        text = text.replace(marker, "")
    return text.strip()
def validate_strategy_references(
    strategy_data: Dict[str, Any],
    valid_cluster_ids: Set[str],
    valid_ability_ids: Set[str],
    cluster_step_counts: Dict[str, int],
) -> Optional[str]:
    """Check that every ID referenced by the LLM's matches really exists.

    Args:
        strategy_data: Parsed LLM output; its ``"matches"`` entry is inspected.
        valid_cluster_ids: Cluster IDs that a match may reference.
        valid_ability_ids: Capability IDs that a match may reference.
        cluster_step_counts: Number of steps per cluster ID, used to
            range-check ``step_index``.

    Returns:
        ``None`` when every reference is valid, otherwise a message describing
        the first problem found. Malformed input (wrong container types,
        non-integer indices) is reported as a message rather than raised.
    """
    if not isinstance(strategy_data, dict):
        return "strategy data must be a JSON object"

    matches = strategy_data.get("matches", [])
    if not isinstance(matches, list):
        return "'matches' is not a list"

    for i, match in enumerate(matches):
        if not isinstance(match, dict):
            return f"matches[{i}] must be an object"

        cid = match.get("cluster_id")
        # Lists/dicts are the only unhashable values JSON can produce; testing
        # them for set membership would raise TypeError.
        if isinstance(cid, (list, dict)) or cid not in valid_cluster_ids:
            return f"matches[{i}].cluster_id '{cid}' not found in process.json"

        # An omitted step_index is left unchecked (unchanged behavior); a
        # present one must be a real integer, otherwise it could never be
        # paired with a step position later on.
        if "step_index" in match:
            step_idx = match["step_index"]
            if isinstance(step_idx, bool) or not isinstance(step_idx, int):
                return f"matches[{i}].step_index must be an integer, got {step_idx!r}"
            max_steps = cluster_step_counts.get(cid, 0)
            if not 0 <= step_idx < max_steps:
                return f"matches[{i}].step_index {step_idx} out of range (cluster '{cid}' has {max_steps} steps)"

        caps = match.get("capabilities", [])
        if not isinstance(caps, list):
            return f"matches[{i}].capabilities must be a list"
        for k, cap_id in enumerate(caps):
            if not isinstance(cap_id, str):
                return f"matches[{i}].capabilities[{k}] must be a string id"
            if cap_id not in valid_ability_ids:
                return f"matches[{i}].capabilities[{k}]: '{cap_id}' not found in capabilities.json"

    return None
def _build_assembly_prompt(requirement: str, process_text: str, capabilities_text: str) -> str:
    """Render the matching prompt from the template file, or from a built-in fallback."""
    try:
        template = load_prompt_template("assemble_strategy")
    except (OSError, ValueError):
        # Best-effort fallback when the template file is missing or cannot be
        # read/decoded. Only the load is guarded so that unrelated programming
        # errors are not silently hidden behind the fallback prompt.
        return f"""为以下工序的每个步骤匹配对应的能力 ID。
需求:{requirement}
工序聚类:
{process_text}
能力列表:
{capabilities_text}
直接输出 JSON:
{{"matches": [{{"cluster_id": "A", "step_index": 0, "capabilities": ["AB-01"]}}]}}"""

    prompt = template.replace("%requirement%", requirement)
    prompt = prompt.replace("%process_data%", process_text)
    return prompt.replace("%capabilities_data%", capabilities_text)


def _group_matches_by_cluster(match_data: Dict[str, Any]) -> Dict[Any, Dict[Any, List[Any]]]:
    """Index the LLM matches as ``{cluster_id: {step_index: [capability ids]}}``."""
    grouped: Dict[Any, Dict[Any, List[Any]]] = {}
    for match in match_data.get("matches", []):
        cid = match.get("cluster_id")
        step_index = match.get("step_index", 0)
        # A later match for the same (cluster, step) replaces an earlier one.
        # ``or []`` keeps a null capabilities value from breaking assembly.
        grouped.setdefault(cid, {})[step_index] = match.get("capabilities") or []
    return grouped


def _assemble_workflow(
    clusters: List[Dict[str, Any]],
    capabilities: List[Dict[str, Any]],
    cluster_matches: Dict[Any, Dict[Any, List[Any]]],
) -> List[Dict[str, Any]]:
    """Expand matched capability IDs into full capability records for every step."""
    capability_map = {ab.get("id"): ab for ab in capabilities}

    workflow = []
    for cluster in clusters:
        cid = cluster.get("cluster_id")
        step_matches = cluster_matches.get(cid, {})

        assembled_steps = []
        for idx, step in enumerate(cluster.get("工序步骤", [])):
            matched_caps = []
            for cap_id in step_matches.get(idx, []):
                cap_info = capability_map.get(cap_id)
                # IDs without a (non-empty) capability record are skipped.
                if cap_info:
                    matched_caps.append({
                        "id": cap_id,
                        "name": cap_info.get("name", ""),
                        "description": cap_info.get("description", ""),
                        "case_references": cap_info.get("case_references", []),
                        "enriched_details": cap_info.get("enriched_details"),
                    })
            assembled_steps.append({
                "步骤序号": step.get("步骤序号"),
                "步骤描述": step.get("步骤描述"),
                "capabilities": matched_caps,
            })

        workflow.append({
            "cluster_id": cid,
            "cluster_name": cluster.get("cluster_name", ""),
            "关联案例": cluster.get("关联案例", []),
            "score": cluster.get("score"),
            "explanation": cluster.get("explanation"),
            "steps": assembled_steps,
        })
    return workflow


async def assemble_strategy(
    process_file: Path,
    capabilities_file: Path,
    output_file: Path,
    requirement: str,
    llm_call,
    model: str = "anthropic/claude-sonnet-4-6",
) -> Dict[str, Any]:
    """Assemble the final strategy file from process clusters and capabilities.

    The LLM is asked only for step-to-capability matches; the complete
    strategy document is then built in code and written to ``output_file``
    as JSON (parent directories are created when missing).

    Args:
        process_file: JSON file whose ``"clusters"`` list holds the process clusters.
        capabilities_file: JSON file whose ``"capabilities"`` list holds the capabilities.
        output_file: Destination path of the assembled strategy JSON.
        requirement: Requirement text inserted into the prompt and the output.
        llm_call: LLM callable forwarded to ``call_llm_with_retry``.
        model: Model identifier forwarded to ``call_llm_with_retry``.

    Returns:
        Statistics: ``"workflow_steps"`` (total number of assembled steps) and
        ``"total_cost"`` (as reported by ``call_llm_with_retry``).
    """
    with open(process_file, "r", encoding="utf-8") as f:
        process_data = json.load(f)
    with open(capabilities_file, "r", encoding="utf-8") as f:
        capabilities_data = json.load(f)

    # Collect the IDs the LLM is allowed to reference.
    clusters = process_data.get("clusters", [])
    capabilities = capabilities_data.get("capabilities", [])
    valid_cluster_ids = {cl.get("cluster_id") for cl in clusters if cl.get("cluster_id")}
    valid_ability_ids = {ab.get("id") for ab in capabilities if ab.get("id")}
    cluster_step_counts = {cl.get("cluster_id"): len(cl.get("工序步骤", [])) for cl in clusters}

    prompt = _build_assembly_prompt(
        requirement,
        json.dumps(process_data, ensure_ascii=False, indent=2),
        json.dumps(capabilities_data, ensure_ascii=False, indent=2),
    )
    messages = [{"role": "user", "content": prompt}]

    def _validate_with_refs(parsed):
        # Imported at call time, exactly as the original code did.
        from examples.process_pipeline.script.validate_schema import validate_strategy

        err = validate_strategy(parsed)
        if err:
            return err
        return validate_strategy_references(parsed, valid_cluster_ids, valid_ability_ids, cluster_step_counts)

    match_data, total_cost = await call_llm_with_retry(
        llm_call=llm_call,
        messages=messages,
        model=model,
        temperature=0.1,
        max_tokens=4000,
        max_retries=3,
        validate_fn=_validate_with_refs,
        task_name="P3_AssembleStrategy",
    )
    # No usable LLM answer: still emit the workflow, just without capabilities.
    if match_data is None:
        match_data = {"matches": []}

    full_workflow = _assemble_workflow(clusters, capabilities, _group_matches_by_cluster(match_data))

    output_data = {
        "requirement": requirement,
        "workflow": full_workflow,
    }
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)

    return {
        "workflow_steps": sum(len(cl.get("steps", [])) for cl in full_workflow),
        "total_cost": total_cost,
    }
async def main():
    """Command-line entry point: parse the arguments and run the strategy assembly."""
    import argparse
    import sys

    # Put the directory three levels above this file on sys.path so the
    # ``agent`` package imported below can be found (presumably the repo root).
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
    from agent.llm.openrouter import OpenRouterLLM

    parser = argparse.ArgumentParser()
    parser.add_argument("--process-file", required=True)
    parser.add_argument("--capabilities-file", required=True)
    parser.add_argument("--output-file", required=True)
    parser.add_argument("--requirement", required=True)
    parser.add_argument("--model", default="anthropic/claude-sonnet-4-6")
    args = parser.parse_args()

    llm = OpenRouterLLM()
    result = await assemble_strategy(
        process_file=Path(args.process_file),
        capabilities_file=Path(args.capabilities_file),
        output_file=Path(args.output_file),
        requirement=args.requirement,
        llm_call=llm.chat,
        model=args.model,
    )
    # assemble_strategy returns "workflow_steps" and "total_cost"; the former
    # "strategies" key never existed and raised KeyError after the run.
    print(f"✓ Generated strategy with {result['workflow_steps']} workflow steps")


if __name__ == "__main__":
    asyncio.run(main())
|