浏览代码

Fix compute_coverage_scores script: add LLM retries with backoff, validate the JSON response, and track failed requirements

guantao 1 月之前
父节点
当前提交
4324a0a84b

+ 65 - 20
examples/process_pipeline/script/compute_coverage_scores.py

@@ -1,6 +1,7 @@
 import json
 import json
 import sys
 import sys
 import asyncio
 import asyncio
+import re
 from pathlib import Path
 from pathlib import Path
 repo_root = str(Path(__file__).parent.parent.parent.parent)
 repo_root = str(Path(__file__).parent.parent.parent.parent)
 if repo_root not in sys.path:
 if repo_root not in sys.path:
@@ -14,6 +15,7 @@ from knowhub.knowhub_db.pg_requirement_store import PostgreSQLRequirementStore
 from agent.llm.openrouter import openrouter_llm_call
 from agent.llm.openrouter import openrouter_llm_call
 
 
 OUTPUT_JSON = Path("examples/process_pipeline/script/coverage_scores.json")
 OUTPUT_JSON = Path("examples/process_pipeline/script/coverage_scores.json")
+FAILED_JSON = Path("examples/process_pipeline/script/failed_requirements.json")
 
 
 EVAL_PROMPT = """
 EVAL_PROMPT = """
 You are an expert system architecture and pipeline evaluator.
 You are an expert system architecture and pipeline evaluator.
@@ -38,7 +40,7 @@ Return your result STRICTLY as a JSON array of objects, one for each workflow ev
 DO NOT output any thinking, markdown wrapping (```json), or conversational text. Output ONLY the raw JSON array.
 DO NOT output any thinking, markdown wrapping (```json), or conversational text. Output ONLY the raw JSON array.
 """
 """
 
 
-async def process_requirement(req_desc: str, group_strats: list) -> dict:
+async def process_requirement(req_desc: str, group_strats: list, max_retries: int = 3) -> list:
     # Prepare payload to send to LLM
     # Prepare payload to send to LLM
     workflows_payload = []
     workflows_payload = []
     for s in group_strats:
     for s in group_strats:
@@ -59,25 +61,46 @@ async def process_requirement(req_desc: str, group_strats: list) -> dict:
         workflows_json=json.dumps(workflows_payload, ensure_ascii=False, indent=2)
         workflows_json=json.dumps(workflows_payload, ensure_ascii=False, indent=2)
     )
     )
     
     
-    try:
-        resp = await openrouter_llm_call(
-            messages=[{"role": "user", "content": prompt}],
-            model="anthropic/claude-sonnet-4.5",  # OpenRouter uses this to route to latest 3.5 Sonnet
-            max_tokens=4096,
-            temperature=0.1
-        )
-        content = resp["content"].strip()
-        if content.startswith("```json"):
-            content = content.replace("```json", "").replace("```", "").strip()
-        elif content.startswith("```"):
-            content = content.replace("```", "").strip()
+    for attempt in range(max_retries):
+        try:
+            resp = await openrouter_llm_call(
+                messages=[{"role": "user", "content": prompt}],
+                model="anthropic/claude-sonnet-4.5",  # OpenRouter uses this to route to latest 3.5 Sonnet
+                max_tokens=4096,
+                temperature=0.1
+            )
+            content = resp["content"].strip()
+            
+            # Extract JSON array using regex if there's surrounding text
+            json_match = re.search(r'\[.*\]', content, re.DOTALL)
+            if json_match:
+                content = json_match.group(0)
+                
+            if content.startswith("```json"):
+                content = content.replace("```json", "", 1).replace("```", "").strip()
+            elif content.startswith("```"):
+                content = content.replace("```", "", 1).replace("```", "").strip()
+                
+            parsed_json = json.loads(content)
+            
+            # Validation
+            if not isinstance(parsed_json, list):
+                raise ValueError("LLM response is not a JSON array.")
+            for item in parsed_json:
+                if "strategy_id" not in item or "coverage_score" not in item:
+                    raise ValueError("JSON array items missing required keys (strategy_id, coverage_score).")
+                    
+            return parsed_json
             
             
-        return json.loads(content)
-    except Exception as e:
-        print(f"  [Error] LLM Call failed for a requirement: {e}")
-        return []
+        except Exception as e:
+            print(f"  [Error] LLM Call failed for a requirement (Attempt {attempt+1}/{max_retries}): {e}")
+            if attempt < max_retries - 1:
+                await asyncio.sleep(2 ** attempt)  # Exponential backoff
+            else:
+                print(f"  [Fatal] Failed to evaluate requirement after {max_retries} attempts.")
+                return []
 
 
-async def main(dry_run: bool = False, force: bool = False):
+async def main(dry_run: bool = False, force: bool = False, retry_failed: bool = False):
     print("Connecting to DB...")
     print("Connecting to DB...")
     strat_store = PostgreSQLStrategyStore()
     strat_store = PostgreSQLStrategyStore()
     req_store = PostgreSQLRequirementStore()
     req_store = PostgreSQLRequirementStore()
@@ -100,10 +123,23 @@ async def main(dry_run: bool = False, force: bool = False):
             
             
     processed_req_ids = set(output_data.keys())
     processed_req_ids = set(output_data.keys())
     
     
+    failed_req_ids = set()
+    if FAILED_JSON.exists() and not force:
+        try:
+            with open(FAILED_JSON, "r", encoding="utf-8") as f:
+                failed_req_ids = set(json.load(f))
+            print(f"Loaded {len(failed_req_ids)} previously failed requirements.")
+        except:
+            print("Failed to load failed_requirements.json.")
+            
     total_reqs = len(output_data)
     total_reqs = len(output_data)
 
 
     # Filter out already processed requirements
     # Filter out already processed requirements
-    pending_requirements = [r for r in requirements if r["id"] not in processed_req_ids]
+    if retry_failed:
+        pending_requirements = [r for r in requirements if r["id"] in failed_req_ids and r["id"] not in processed_req_ids]
+        print("Retry-failed mode enabled. Only processing previously failed requirements.")
+    else:
+        pending_requirements = [r for r in requirements if r["id"] not in processed_req_ids]
     
     
     print(f"Starting LLM coverage semantic evaluation using Sonnet 4.5 via OpenRouter...")
     print(f"Starting LLM coverage semantic evaluation using Sonnet 4.5 via OpenRouter...")
     print(f"Total Requirements remaining to evaluate: {len(pending_requirements)} (out of {len(requirements)})")
     print(f"Total Requirements remaining to evaluate: {len(pending_requirements)} (out of {len(requirements)})")
@@ -138,8 +174,12 @@ async def main(dry_run: bool = False, force: bool = False):
         for idx, (req_id, req_desc, _) in enumerate(tasks):
         for idx, (req_id, req_desc, _) in enumerate(tasks):
             evaluations = results[idx]
             evaluations = results[idx]
             if not evaluations:
             if not evaluations:
+                failed_req_ids.add(req_id)
                 continue
                 continue
                 
                 
+            if req_id in failed_req_ids:
+                failed_req_ids.remove(req_id)
+                
             strat_results = []
             strat_results = []
             for ev in evaluations:
             for ev in evaluations:
                 sid = ev.get("strategy_id")
                 sid = ev.get("strategy_id")
@@ -197,14 +237,19 @@ async def main(dry_run: bool = False, force: bool = False):
         # Save incrementally after every batch to prevent data loss
         # Save incrementally after every batch to prevent data loss
         with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
         with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
             json.dump(output_data, f, ensure_ascii=False, indent=2)
             json.dump(output_data, f, ensure_ascii=False, indent=2)
+        with open(FAILED_JSON, "w", encoding="utf-8") as f:
+            json.dump(list(failed_req_ids), f, ensure_ascii=False, indent=2)
 
 
     print(f"Evaluated {total_reqs} requirements overall.")
     print(f"Evaluated {total_reqs} requirements overall.")
     print(f"Results {"simulated (DB untouched)" if dry_run else "and DB updates"} successfully saved to: {OUTPUT_JSON}")
     print(f"Results {"simulated (DB untouched)" if dry_run else "and DB updates"} successfully saved to: {OUTPUT_JSON}")
+    if failed_req_ids:
+        print(f"WARNING: {len(failed_req_ids)} requirements failed during evaluation. They have been saved to {FAILED_JSON}")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     import argparse
     import argparse
     parser = argparse.ArgumentParser()
     parser = argparse.ArgumentParser()
     parser.add_argument("--dry-run", action="store_true", help="Calculate scores and save to JSON only, do not write to DB")
     parser.add_argument("--dry-run", action="store_true", help="Calculate scores and save to JSON only, do not write to DB")
     parser.add_argument("--force", action="store_true", help="Discard existing JSON and rerun all requirements from scratch")
     parser.add_argument("--force", action="store_true", help="Discard existing JSON and rerun all requirements from scratch")
+    parser.add_argument("--retry-failed", action="store_true", help="Only retry requirements that are listed in the failed JSON file")
     args = parser.parse_args()
     args = parser.parse_args()
-    asyncio.run(main(args.dry_run, args.force))
+    asyncio.run(main(args.dry_run, args.force, args.retry_failed))

文件差异内容过多而无法显示
+ 408 - 271
examples/process_pipeline/script/coverage_scores.json


+ 1 - 0
examples/process_pipeline/script/failed_requirements.json

@@ -0,0 +1 @@
+[]

部分文件因为文件数量过多而无法显示