# score_processes.py
  1. """
  2. Phase 2.1.2: 工序匹配度打分
  3. 读取 blueprint.json,对每个 blueprint 工序类进行与需求的匹配度打分,
  4. 输出到 process.json
  5. """
  6. import asyncio
  7. import json
  8. import re
  9. from pathlib import Path
  10. from typing import Any, Dict
  11. from examples.process_pipeline.script.llm_helper import call_llm_with_retry
  12. from examples.process_pipeline.script.validate_schema import validate_process, List, Optional
  13. def load_prompt_template(prompt_name: str) -> str:
  14. """从 prompts 目录加载 prompt 模板"""
  15. base_dir = Path(__file__).parent.parent
  16. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  17. with open(prompt_path, "r", encoding="utf-8") as f:
  18. content = f.read()
  19. if content.startswith("---"):
  20. parts = content.split("---", 2)
  21. if len(parts) >= 3:
  22. content = parts[2]
  23. content = content.replace("$system$", "").replace("$user$", "")
  24. return content.strip()
  25. async def score_blueprints(
  26. blueprint_file: Path,
  27. output_file: Path,
  28. requirement: str,
  29. llm_call,
  30. model: str = "anthropic/claude-sonnet-4-6",
  31. ) -> Dict[str, Any]:
  32. """
  33. 对 blueprint_temp.json 中的每个工序聚类进行匹配度打分
  34. Returns:
  35. 统计信息
  36. """
  37. with open(blueprint_file, "r", encoding="utf-8") as f:
  38. blueprint_data = json.load(f)
  39. clusters = blueprint_data.get("clusters", [])
  40. if not clusters:
  41. return {"error": "No clusters found", "scored": 0, "total_cost": 0.0}
  42. # 构造 prompt
  43. try:
  44. prompt_template = load_prompt_template("score_processes")
  45. clusters_text = json.dumps(clusters, ensure_ascii=False, indent=2)
  46. prompt = prompt_template.replace("%requirement%", requirement)
  47. prompt = prompt.replace("%clusters_data%", clusters_text)
  48. except Exception:
  49. clusters_text = json.dumps(clusters, ensure_ascii=False, indent=2)
  50. prompt = f"""对以下工序聚类进行与需求的匹配度打分。
  51. 需求:{requirement}
  52. 工序聚类:
  53. {clusters_text}
  54. 直接输出 JSON:
  55. {{"scored_clusters": [{{"cluster_id": "A", "cluster_name": "...", "score": 0.85, "explanation": "评分理由"}}]}}"""
  56. messages = [{"role": "user", "content": prompt}]
  57. def _validate_scored_output(parsed):
  58. scored = parsed.get("scored_clusters", [])
  59. if not isinstance(scored, list):
  60. return "'scored_clusters' is not a list"
  61. if len(scored) == 0:
  62. return "'scored_clusters' is empty"
  63. for i, item in enumerate(scored):
  64. if "score" not in item:
  65. return f"scored_clusters[{i}] missing 'score'"
  66. if not isinstance(item["score"], (int, float)):
  67. return f"scored_clusters[{i}].score must be a number"
  68. if "explanation" not in item or not (item.get("explanation") or "").strip():
  69. return f"scored_clusters[{i}] missing or empty 'explanation'"
  70. return None
  71. scored_data, total_cost = await call_llm_with_retry(
  72. llm_call=llm_call,
  73. messages=messages,
  74. model=model,
  75. temperature=0.1,
  76. max_tokens=4000,
  77. max_retries=3,
  78. validate_fn=_validate_scored_output,
  79. task_name="P2.1.2_ScoreProcesses",
  80. )
  81. if scored_data is None:
  82. scored_data = {"scored_clusters": []}
  83. # 把 score 和 explanation 合并回原始 clusters
  84. scored_map = {}
  85. for sc in scored_data.get("scored_clusters", []):
  86. cid = sc.get("cluster_id")
  87. if cid:
  88. scored_map[cid] = sc
  89. merged_clusters = []
  90. for cl in clusters:
  91. cid = cl.get("cluster_id")
  92. if cid in scored_map:
  93. cl["score"] = scored_map[cid].get("score", 0)
  94. cl["explanation"] = scored_map[cid].get("explanation", "")
  95. merged_clusters.append(cl)
  96. output_data = {
  97. "requirement": requirement,
  98. "clusters": merged_clusters,
  99. }
  100. output_file.parent.mkdir(parents=True, exist_ok=True)
  101. with open(output_file, "w", encoding="utf-8") as f:
  102. json.dump(output_data, f, ensure_ascii=False, indent=2)
  103. return {
  104. "scored": len(scored_map),
  105. "total_cost": total_cost,
  106. }