eval_compare.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # -*- coding: utf-8 -*-
  2. """一次性:用当前 eval_prompt_template.md(新 prompt)对单帖重评,与库里旧评估对比打分。
  3. 用法: python eval_compare.py <query_id> <case_id>
  4. """
  5. import argparse
  6. import asyncio
  7. import json
  8. import sys
  9. from pathlib import Path
  10. PROJECT_ROOT = Path(__file__).resolve().parents[2] # …/Agent
  11. sys.path.insert(0, str(PROJECT_ROOT))
  12. from dotenv import load_dotenv
  13. load_dotenv()
  14. HERE = Path(__file__).resolve().parent
  15. sys.path.insert(0, str(HERE))
  16. import db
  17. from examples.process_pipeline.script.search_eval.search_and_evaluate import _attach_image_refs
  18. from examples.process_pipeline.script.llm_evaluate_sources import (
  19. _evaluate_one, build_eval_llm_call, DEFAULT_EVAL_MODEL,
  20. )
  21. def _row_to_source(row):
  22. return {
  23. "case_id": row["case_id"], "platform": row["platform"],
  24. "channel_content_id": row["channel_content_id"], "source_url": row["url"],
  25. "post": {
  26. "title": row["title"], "body_text": row["body"],
  27. "images": row["images"] or [], "like_count": row["like_count"],
  28. "publish_timestamp": row["publish_time"], "link": row["url"],
  29. },
  30. }
  def flatten_scores(blob, prefix=""):
      """Flatten a nested evaluation blob into ``{dotted_path: score}``.

      Only leaf nodes are collected: a leaf is a dict containing the ``"得分"``
      (score) key, and its value for that key is recorded under the dotted
      path of dict keys leading to it. Recursion stops at a leaf, and
      non-dict values (strings, lists, numbers) are skipped.

      Args:
          blob: The evaluation blob. Anything that is not a dict yields ``{}``.
          prefix: Path accumulated so far. Internally it always ends with one
              ``"."`` separator; the top-level call uses ``""``.

      Returns:
          Dict mapping dotted paths to scores. A blob that is itself a leaf
          maps to the empty path ``""``.
      """
      out = {}
      if not isinstance(blob, dict):
          return out
      if "得分" in blob:
          # Drop exactly the one separator added by the recursive call.
          # rstrip(".") would also eat dots belonging to the key itself, so
          # distinct keys such as "a." and "a" would collide and lose a score.
          path = prefix[:-1] if prefix.endswith(".") else prefix
          out[path] = blob.get("得分")
          return out
      for key, value in blob.items():
          if isinstance(value, dict):
              out.update(flatten_scores(value, f"{prefix}{key}."))
      return out
  async def main():
      """Re-evaluate one stored post and print a score diff against its stored evaluation.

      Command line: ``query_id`` and ``case_id`` (positional), ``--model``
      (default ``DEFAULT_EVAL_MODEL``) and ``--max-images`` (default 4).

      The post is loaded from the ``search_process`` table and re-scored
      through ``_evaluate_one``. The old and new leaf scores (see
      ``flatten_scores``), overall score, knowledge type and adoption decision
      are printed side by side. Both blobs are saved to
      ``runs/eval_compare_<case_id>.json`` next to this script.

      Returns:
          Exit code: 0 on success, 1 if the post is not found or the new
          evaluation returned nothing.
      """
      ap = argparse.ArgumentParser()
      ap.add_argument("query_id")
      ap.add_argument("case_id")
      ap.add_argument("--model", default=DEFAULT_EVAL_MODEL)
      ap.add_argument("--max-images", type=int, default=4)
      args = ap.parse_args()

      row = db.fetch_post(args.query_id, args.case_id, table="search_process")
      if not row:
          print(f"❌ {args.query_id}/{args.case_id} 不在 search_process")
          return 1
      # Previously stored evaluation; a missing/empty value compares against {}.
      old_blob = row.get("llm_evaluation") or {}

      src = _row_to_source(row)
      # NOTE(review): the positional 8 and "url" are passed straight through;
      # their meaning is defined by _attach_image_refs. It presumably fills
      # src["_image_data_urls"], which is read right below.
      await _attach_image_refs([src], args.max_images, 8, "url")
      n_img = len(src.get("_image_data_urls") or [])
      print(f"📄 {args.case_id} | {(row['title'] or '')[:40]} | 配图 {n_img} 张 | 模型 {args.model}")
      print(f"🔍 检索词: {row['query_text']}\n")

      eval_llm, model_id = build_eval_llm_call(args.model)
      sem = asyncio.Semaphore(1)  # only one post is evaluated here
      new_blob, cost = await _evaluate_one(
          src, "", eval_llm, model_id, sem,
          image_urls=src.get("_image_data_urls"), query=row["query_text"])
      if new_blob is None:
          print("❌ 新评估失败(重试耗尽)")
          return 1

      # Per-dimension table over the union of old and new leaf paths.
      # NOTE: ":<46" pads by character count, so paths containing wide (CJK)
      # characters may not line up visually in a terminal.
      old_f = flatten_scores(old_blob)
      new_f = flatten_scores(new_blob)
      keys = sorted(set(old_f) | set(new_f))
      print(f"{'维度路径':<46} {'旧分':>6} {'新分':>6} 变化")
      print("─" * 72)
      for k in keys:
          o, n = old_f.get(k), new_f.get(k)
          mark = ""
          try:
              if o is not None and n is not None and float(o) != float(n):
                  mark = f" {float(o):g}→{float(n):g}"
          except (TypeError, ValueError):
              pass  # non-numeric score: show raw values without a delta marker
          only = "" if (k in old_f and k in new_f) else (" (旧无)" if k not in old_f else " (新无)")
          print(f"{k:<46} {str(o) if o is not None else '-':>6} {str(n) if n is not None else '-':>6}{mark}{only}")
      print("─" * 72)

      o_overall, n_overall = db.overall_score(old_blob), db.overall_score(new_blob)
      o_adopt = db.is_adopted(o_overall, old_blob, row["publish_time"])
      n_adopt = db.is_adopted(n_overall, new_blob, row["publish_time"])
      print(f"{'overall_score':<46} {str(o_overall):>6} {str(n_overall):>6}")
      print(f"{'知识类型':<46} {str(old_blob.get('知识类型')):>6} | {new_blob.get('知识类型')}")
      print(f"{'是否采纳':<46} {str(o_adopt):>6} {str(n_adopt):>6}")

      # Format the cost as money only when it supports ".4f". A non-numeric
      # cost (e.g. None) must not crash the run before the blobs are saved below.
      try:
          cost_str = f"${cost:.4f}"
      except (TypeError, ValueError):
          cost_str = str(cost)
      print(f"\n💲 本次重评成本 {cost_str}")

      # Save the complete old/new blobs (including per-dimension reasons) for closer review.
      out = HERE / "runs" / f"eval_compare_{args.case_id}.json"
      # Path.write_text does not create missing parent directories.
      out.parent.mkdir(parents=True, exist_ok=True)
      out.write_text(json.dumps({"old": old_blob, "new": new_blob}, ensure_ascii=False, indent=2),
                     encoding="utf-8")
      print(f"📝 完整新旧 blob(含理由): {out}")
      return 0
  if __name__ == "__main__":
      # Run the async entry point; its integer return value becomes the exit status.
      sys.exit(asyncio.run(main()))