compute_coverage_scores.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import json
  2. import sys
  3. import asyncio
  4. from pathlib import Path
  5. repo_root = str(Path(__file__).parent.parent.parent.parent)
  6. if repo_root not in sys.path:
  7. sys.path.insert(0, repo_root)
  8. from dotenv import load_dotenv
  9. load_dotenv()
  10. from knowhub.knowhub_db.pg_strategy_store import PostgreSQLStrategyStore
  11. from knowhub.knowhub_db.pg_requirement_store import PostgreSQLRequirementStore
  12. from agent.llm.openrouter import openrouter_llm_call
  13. OUTPUT_JSON = Path("examples/process_pipeline/script/coverage_scores.json")
# Prompt template for the coverage judge, filled in with str.format:
#   {req_desc}       - the requirement's description text
#   {workflows_json} - pretty-printed JSON list of
#                      {"strategy_id": ..., "workflow": [...]} objects
# Literal braces in the example output below are doubled ({{ / }}) so that
# str.format emits them as single braces instead of treating them as fields.
# The model is asked for a bare JSON array with one object per workflow.
EVAL_PROMPT = """
You are an expert system architecture and pipeline evaluator.
You will be provided with a User Requirement and multiple alternative Proposed Pipeline Workflows created to resolve that requirement.
Your task is to evaluate how well each workflow semantically covers and resolves the user's needs.
User Requirement:
{req_desc}
Proposed Workflows:
{workflows_json}
For each workflow, assign a `coverage_score` between 0.00 and 1.00 (1.00 = completely and deeply resolves the core requirement).
Return your result STRICTLY as a JSON array of objects, one for each workflow evaluated, containing:
[
{{
"strategy_id": "<exact strategy_id from the input>",
"coverage_score": 0.85,
"explanation": "<1-2 sentence justification on what it covers well and what it might be missing>"
}}
]
DO NOT output any thinking, markdown wrapping (```json), or conversational text. Output ONLY the raw JSON array.
"""
  33. async def process_requirement(req_desc: str, group_strats: list) -> dict:
  34. # Prepare payload to send to LLM
  35. workflows_payload = []
  36. for s in group_strats:
  37. body_data = s.get("body") or {}
  38. if isinstance(body_data, str):
  39. try:
  40. body_data = json.loads(body_data)
  41. except:
  42. body_data = {}
  43. workflows_payload.append({
  44. "strategy_id": s["id"],
  45. "workflow": body_data.get("workflow", [])
  46. })
  47. prompt = EVAL_PROMPT.format(
  48. req_desc=req_desc,
  49. workflows_json=json.dumps(workflows_payload, ensure_ascii=False, indent=2)
  50. )
  51. try:
  52. resp = await openrouter_llm_call(
  53. messages=[{"role": "user", "content": prompt}],
  54. model="anthropic/claude-sonnet-4.5", # OpenRouter uses this to route to latest 3.5 Sonnet
  55. max_tokens=4096,
  56. temperature=0.1
  57. )
  58. content = resp["content"].strip()
  59. if content.startswith("```json"):
  60. content = content.replace("```json", "").replace("```", "").strip()
  61. elif content.startswith("```"):
  62. content = content.replace("```", "").strip()
  63. return json.loads(content)
  64. except Exception as e:
  65. print(f" [Error] LLM Call failed for a requirement: {e}")
  66. return []
  67. async def main(dry_run: bool = False, force: bool = False):
  68. print("Connecting to DB...")
  69. strat_store = PostgreSQLStrategyStore()
  70. req_store = PostgreSQLRequirementStore()
  71. requirements = req_store.list_all(limit=10000)
  72. strategies = strat_store.list_all(limit=10000)
  73. strat_map = {s["id"]: s for s in strategies}
  74. output_data = {}
  75. if OUTPUT_JSON.exists() and not force:
  76. try:
  77. with open(OUTPUT_JSON, "r", encoding="utf-8") as f:
  78. output_data = json.load(f)
  79. print(f"Loaded existing coverage scores for {len(output_data)} requirements. Resuming...")
  80. except:
  81. print("Failed to load existing JSON, starting fresh.")
  82. elif force:
  83. print("Force run enabled. Discarding existing records and starting completely fresh.")
  84. processed_req_ids = set(output_data.keys())
  85. total_reqs = len(output_data)
  86. # Filter out already processed requirements
  87. pending_requirements = [r for r in requirements if r["id"] not in processed_req_ids]
  88. print(f"Starting LLM coverage semantic evaluation using Sonnet 4.5 via OpenRouter...")
  89. print(f"Total Requirements remaining to evaluate: {len(pending_requirements)} (out of {len(requirements)})")
  90. # Process in batches of 10 concurrent requests
  91. batch_size = 10
  92. for i in range(0, len(pending_requirements), batch_size):
  93. batch_reqs = pending_requirements[i:i+batch_size]
  94. tasks = []
  95. print(f"Evaluating Batch {i//batch_size + 1} (Reqs {i+1} to min({i+batch_size}, {len(pending_requirements)}))")
  96. for req in batch_reqs:
  97. req_id = req["id"]
  98. req_desc = req.get("description", "Unknown Description")
  99. req_strat_ids = req.get("strategy_ids") or []
  100. group_strats = [strat_map[sid] for sid in req_strat_ids if sid in strat_map]
  101. if not group_strats:
  102. continue
  103. tasks.append((req_id, req_desc, process_requirement(req_desc, group_strats)))
  104. if not tasks:
  105. continue
  106. # Execute batch concurrently
  107. results = await asyncio.gather(*(t[2] for t in tasks))
  108. # Map results back
  109. for idx, (req_id, req_desc, _) in enumerate(tasks):
  110. evaluations = results[idx]
  111. if not evaluations:
  112. continue
  113. strat_results = []
  114. for ev in evaluations:
  115. sid = ev.get("strategy_id")
  116. if sid in strat_map:
  117. strat_info = strat_map[sid]
  118. is_selected = (strat_info.get("status") == "published")
  119. strat_results.append({
  120. "strategy_id": sid,
  121. "strategy_name": strat_info.get("name", ""),
  122. "is_selected": is_selected,
  123. "coverage_score": ev.get("coverage_score", 0),
  124. "explanation": ev.get("explanation", "")
  125. })
  126. if strat_results:
  127. strat_results.sort(key=lambda x: x["coverage_score"], reverse=True)
  128. output_data[req_id] = {
  129. "requirement_desc": req_desc,
  130. "strategies": strat_results
  131. }
  132. total_reqs += 1
  133. # Write the calculated score and explanation directly back to the database body
  134. updated_count = 0
  135. for ev in strat_results:
  136. sid = ev["strategy_id"]
  137. if sid in strat_map:
  138. strat_info = strat_map[sid]
  139. body_data = strat_info.get("body") or {}
  140. if isinstance(body_data, str):
  141. try:
  142. body_data = json.loads(body_data)
  143. except:
  144. body_data = {}
  145. body_data.setdefault("coverage_evaluations", {})
  146. body_data["coverage_evaluations"][req_id] = {
  147. "score": ev["coverage_score"],
  148. "explanation": ev["explanation"]
  149. }
  150. if not dry_run:
  151. try:
  152. strat_store.update(sid, {"body": json.dumps(body_data, ensure_ascii=False)})
  153. strat_map[sid]["body"] = body_data # Update local map cache
  154. updated_count += 1
  155. except Exception as e:
  156. print(f" [Error] Failed to update body for strategy {sid}: {e}")
  157. else:
  158. updated_count += 1
  159. tag_word = "[DRY-RUN] Simulated updating" if dry_run else "Updated"
  160. print(f" -> Processed requirement {req_id}: {tag_word} DB body for {updated_count} strategies.")
  161. # Save incrementally after every batch to prevent data loss
  162. with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
  163. json.dump(output_data, f, ensure_ascii=False, indent=2)
  164. print(f"Evaluated {total_reqs} requirements overall.")
  165. print(f"Results {"simulated (DB untouched)" if dry_run else "and DB updates"} successfully saved to: {OUTPUT_JSON}")
  166. if __name__ == "__main__":
  167. import argparse
  168. parser = argparse.ArgumentParser()
  169. parser.add_argument("--dry-run", action="store_true", help="Calculate scores and save to JSON only, do not write to DB")
  170. parser.add_argument("--force", action="store_true", help="Discard existing JSON and rerun all requirements from scratch")
  171. args = parser.parse_args()
  172. asyncio.run(main(args.dry_run, args.force))