_reeval_one.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # -*- coding: utf-8 -*-
  2. """一次性:用当前 eval_prompt_template.md 对单条已存帖子重评(复用生产评估链路 evaluate_posts)。
  3. 支持 --escalate-model 演示 sonnet+flash-lite 组合(模糊带升级)。"""
  4. import argparse, asyncio, json, sys
  5. from datetime import datetime
  6. from pathlib import Path
  7. PROJECT_ROOT = Path(__file__).resolve().parents[3] # …/Agent
  8. sys.path.insert(0, str(PROJECT_ROOT))
  9. from dotenv import load_dotenv
  10. load_dotenv()
  11. MW = Path(__file__).resolve().parent
  12. sys.path.insert(0, str(MW))
  13. import db
  14. from examples.process_pipeline.script.search_eval.search_and_evaluate import evaluate_posts
  15. from examples.process_pipeline.script.llm_evaluate_sources import (
  16. _EVAL_PRODUCT_FIELDS, build_eval_llm_call, DEFAULT_EVAL_MODEL,
  17. )
  18. def _load(query_id):
  19. return json.loads((MW / "runs" / "search_process" / f"{query_id}.json")
  20. .read_text(encoding="utf-8"))
  21. def _save(query_id, data):
  22. (MW / "runs" / "search_process" / f"{query_id}.json").write_text(
  23. json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
  24. async def main():
  25. ap = argparse.ArgumentParser()
  26. ap.add_argument("--query-id", required=True)
  27. ap.add_argument("--case-id", required=True)
  28. ap.add_argument("--query", default="")
  29. ap.add_argument("--model", default=DEFAULT_EVAL_MODEL)
  30. ap.add_argument("--escalate-model", default="")
  31. ap.add_argument("--escalate-band", type=float, nargs=2, default=[4.0, 6.0])
  32. ap.add_argument("--max-images", type=int, default=4)
  33. ap.add_argument("--persist", action="store_true",
  34. help="把新评估写回 DB(overall_score/knowledge_type/llm_evaluation),落库前先备份旧值")
  35. a = ap.parse_args()
  36. data = _load(a.query_id)
  37. query = a.query or data.get("query", "")
  38. src = next((s for s in data.get("results", []) if s.get("case_id") == a.case_id), None)
  39. if not src:
  40. raise SystemExit(f"未找到 case_id={a.case_id}")
  41. for k in _EVAL_PRODUCT_FIELDS:
  42. src.pop(k, None)
  43. llm_call, model_id = build_eval_llm_call(a.model)
  44. esc_llm = esc_model = None
  45. if a.escalate_model:
  46. esc_llm, esc_model = build_eval_llm_call(a.escalate_model)
  47. print(f"▶ 重评 {a.case_id} 初评={model_id}"
  48. + (f" 升级={esc_model} 带[{a.escalate_band[0]:g},{a.escalate_band[1]:g}]" if esc_model else "")
  49. + f" query={query!r}\n")
  50. sources, cost = await evaluate_posts(
  51. [src], "", llm_call, model_id, max_concurrent=1,
  52. include_images=True, max_images=a.max_images, image_mode="url", query=query,
  53. escalate_llm=esc_llm, escalate_model=esc_model, escalate_band=tuple(a.escalate_band),
  54. )
  55. ev = sources[0]["llm_evaluation"]
  56. overall = db.overall_score(ev)
  57. pub = (src.get("post") or {}).get("publish_timestamp", "")
  58. adopted = db.is_adopted(overall, ev, pub)
  59. print("\n" + "=" * 60)
  60. print(f"最终评估模型 = {sources[0].get('_escalated') or model_id}")
  61. print(f"综合分(overall_score) = {overall}")
  62. print(f" · 和内容制作知识相关 = {((ev.get('相关性') or {}).get('和内容制作知识相关') or {}).get('得分')}")
  63. print(f" · 实现完整性/可复现门槛 = {db._repro_score(ev)} (门槛 <4 → 不采纳)")
  64. print(f" · 意图可控性 = {db._fixed_dim_score(ev, '意图可控性')} (暂只采分)")
  65. print(f"采纳判定(is_adopted) = {adopted}")
  66. print(f"总成本 ≈ ${cost:.4f}")
  67. if a.persist:
  68. if not isinstance(ev, dict) or ev.get("_error"):
  69. raise SystemExit("评估结果异常(_error),拒绝落库")
  70. # 1) 备份旧 DB 行(overall_score/knowledge_type/llm_evaluation/publish_time)
  71. old = next((p for p in db.fetch_posts(a.query_id, "process")
  72. if p["case_id"] == a.case_id), None)
  73. if old is None:
  74. raise SystemExit(f"DB 无此行,无法落库: query={a.query_id} case={a.case_id}")
  75. ts = datetime.now().strftime("%Y%m%d_%H%M%S")
  76. bpath = (MW / "runs" / "search_process"
  77. / f"{a.query_id}.{a.case_id}.score_backup.{ts}.json")
  78. bpath.write_text(json.dumps({
  79. "query_id": a.query_id, "case_id": a.case_id,
  80. "old_overall_score": old.get("overall_score"),
  81. "old_knowledge_type": old.get("knowledge_type"),
  82. "old_llm_evaluation": old.get("llm_evaluation"),
  83. "old_adopted": old.get("adopted"),
  84. }, ensure_ascii=False, indent=2), encoding="utf-8")
  85. # 2) 写回 DB(派生列 overall_score/knowledge_type 由 update_post_eval 重算)
  86. n = db.update_post_eval(a.query_id, a.case_id, ev, table="search_process")
  87. # 3) 同步 runs json,保持后续重评输入一致
  88. src["llm_evaluation"] = ev
  89. _save(a.query_id, data)
  90. print(f"\n💾 旧值已备份 → {bpath.name}")
  91. print(f"✅ DB 已更新 {n} 行(overall={overall} 采纳={adopted})")
  92. if __name__ == "__main__":
  93. asyncio.run(main())