_batch_reeval_q0000.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
# -*- coding: utf-8 -*-
"""Batch re-evaluate the q0000 posts that are currently adopted (``db.is_adopted``).

A flash-lite + sonnet combination is used: posts whose initial score lands in
an ambiguous band are escalated for a second opinion. Afterwards only the
score-related DB fields (overall_score / knowledge_type / llm_evaluation) are
overwritten.

The old values are backed up first to
runs/search_process/q0000.score_backup.<ts>.json so the change can be rolled back."""
import asyncio, copy, json, sys
from datetime import datetime
from pathlib import Path
# parents[3] is the fourth ancestor directory of this script (presumably the
# project root). It is put on sys.path so the `examples.` imports below resolve.
PROJECT_ROOT = Path(__file__).resolve().parents[3]
sys.path.insert(0, str(PROJECT_ROOT))
from dotenv import load_dotenv
# Load variables from a .env file into the environment before the evaluator
# modules are used (presumably LLM credentials -- confirm).
load_dotenv()
# This script's own directory. It is prepended to sys.path so `import db`
# resolves a sibling module, and it is the base for the runs/ paths in main().
MW = Path(__file__).resolve().parent
sys.path.insert(0, str(MW))
import db
from examples.process_pipeline.script.search_eval.search_and_evaluate import evaluate_posts
from examples.process_pipeline.script.llm_evaluate_sources import (
    _EVAL_PRODUCT_FIELDS, build_eval_llm_call,
)
# Query whose posts are re-evaluated; TABLE is both the DB table and the
# runs/ subdirectory name.
QUERY_ID = "q0000"
TABLE = "search_process"
# Initial evaluator model, and the model used to re-check escalated posts.
INIT_MODEL = "gemini-flash-lite"
ESC_MODEL = "sonnet"
# Passed to evaluate_posts as escalate_band: initial scores inside this range
# are re-checked by ESC_MODEL (boundary handling is defined by evaluate_posts).
BAND = (4.0, 6.0)
  24. def _load_db_rows():
  25. conn = db._conn()
  26. try:
  27. with conn.cursor() as c:
  28. c.execute(f"SELECT case_id, overall_score, knowledge_type, publish_time, "
  29. f"llm_evaluation FROM {TABLE} WHERE query_id=%s", (QUERY_ID,))
  30. return c.fetchall()
  31. finally:
  32. conn.close()
  33. def _update_scores(case_id, overall, knowledge_type, evaluation):
  34. conn = db._conn()
  35. try:
  36. with conn.cursor() as c:
  37. c.execute(
  38. f"UPDATE {TABLE} SET overall_score=%s, knowledge_type=%s, llm_evaluation=%s, "
  39. f"updated_at=CURRENT_TIMESTAMP WHERE query_id=%s AND case_id=%s",
  40. (overall, db._j(knowledge_type or []), db._j(evaluation), QUERY_ID, case_id))
  41. finally:
  42. conn.close()
  43. async def main():
  44. rows = _load_db_rows()
  45. def _ev(r):
  46. e = r["llm_evaluation"]
  47. return json.loads(e) if isinstance(e, str) else (e or {})
  48. adopted = [r for r in rows if db.is_adopted(r["overall_score"], _ev(r), r["publish_time"])]
  49. adopted_ids = {r["case_id"] for r in adopted}
  50. print(f"q0000 共 {len(rows)} 帖,当前命中 {len(adopted)} 帖 → 重评这些\n")
  51. # 备份旧得分字段
  52. ts = datetime.now().strftime("%Y%m%d_%H%M%S")
  53. backup = [{"case_id": r["case_id"], "overall_score": r["overall_score"],
  54. "knowledge_type": r["knowledge_type"], "publish_time": r["publish_time"],
  55. "llm_evaluation": _ev(r)} for r in adopted]
  56. bpath = MW / "runs" / TABLE / f"{QUERY_ID}.score_backup.{ts}.json"
  57. bpath.write_text(json.dumps(backup, ensure_ascii=False, indent=2), encoding="utf-8")
  58. print(f"💾 旧得分已备份 → {bpath.name}\n")
  59. # 从 runs json 取完整帖子(含配图)作为重评输入
  60. data = json.loads((MW / "runs" / TABLE / f"{QUERY_ID}.json").read_text(encoding="utf-8"))
  61. query = data.get("query", "")
  62. by_id = {s["case_id"]: s for s in data.get("results", [])}
  63. missing = [cid for cid in adopted_ids if cid not in by_id]
  64. if missing:
  65. print(f"⚠️ runs json 缺 {len(missing)} 条,将跳过: {missing}")
  66. targets = []
  67. for cid in adopted_ids:
  68. if cid not in by_id:
  69. continue
  70. s = copy.deepcopy(by_id[cid])
  71. for k in _EVAL_PRODUCT_FIELDS:
  72. s.pop(k, None)
  73. s.pop("_image_data_urls", None)
  74. targets.append(s)
  75. eval_llm, eval_model = build_eval_llm_call(INIT_MODEL)
  76. esc_llm, esc_model = build_eval_llm_call(ESC_MODEL)
  77. print(f"🧠 组合评估:{eval_model} 初评 → {esc_model} 复核(带 [{BAND[0]:g},{BAND[1]:g}])\n")
  78. sources, cost = await evaluate_posts(
  79. targets, "", eval_llm, eval_model, max_concurrent=4,
  80. include_images=True, max_images=4, image_mode="url", query=query,
  81. escalate_llm=esc_llm, escalate_model=esc_model, escalate_band=BAND)
  82. # 旧分查表
  83. old_by_id = {r["case_id"]: r for r in adopted}
  84. report = []
  85. for s in sources:
  86. cid = s["case_id"]
  87. ev = s["llm_evaluation"]
  88. if not isinstance(ev, dict) or ev.get("_error"):
  89. print(f" ⚠️ 评估失败,跳过更新: {cid}")
  90. continue
  91. kt = ev.get("知识类型") or []
  92. ov = db.overall_score(ev)
  93. pub = (s.get("post") or {}).get("publish_timestamp") or old_by_id[cid]["publish_time"]
  94. new_adopt = db.is_adopted(ov, ev, pub)
  95. _update_scores(cid, ov, kt, ev) # 定向替换 DB
  96. by_id[cid]["llm_evaluation"] = ev # 同步 runs json
  97. report.append({
  98. "case_id": cid, "escalated": bool(s.get("_escalated")),
  99. "old_overall": old_by_id[cid]["overall_score"], "new_overall": ov,
  100. "repro": db._fixed_dim_score(ev, "可复现性"),
  101. "intent": db._fixed_dim_score(ev, "意图可控性"),
  102. "new_adopted": new_adopt,
  103. "title": (s.get("post") or {}).get("title", "")[:22],
  104. })
  105. # 同步 runs json
  106. (MW / "runs" / TABLE / f"{QUERY_ID}.json").write_text(
  107. json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
  108. # 报告
  109. print("\n" + "=" * 92)
  110. print(f"{'case_id':26} {'升级':4} {'旧综':>5} {'新综':>5} {'复现':>4} {'意图':>4} {'命中':>5} 标题")
  111. still = 0
  112. for r in sorted(report, key=lambda x: x["new_overall"]):
  113. still += int(r["new_adopted"])
  114. print(f"{r['case_id'][:26]:26} {'★' if r['escalated'] else ' ':^4} "
  115. f"{(r['old_overall'] or 0):5.2f} {(r['new_overall'] or 0):5.2f} "
  116. f"{str(r['repro']):>4} {str(r['intent']):>4} "
  117. f"{'是' if r['new_adopted'] else '否':>4} {r['title']}")
  118. esc_n = sum(r["escalated"] for r in report)
  119. print("=" * 92)
  120. print(f"重评 {len(report)} 帖 · 升级 sonnet {esc_n} 帖 · 命中 {len(adopted)}→{still} · "
  121. f"总成本 ${cost:.4f}")
  122. print(f"DB 已更新,旧值备份在 {bpath.name}")
  123. if __name__ == "__main__":
  124. asyncio.run(main())