evaluate_existing_results.py

#!/usr/bin/env python3
"""
Offline evaluation script: evaluate the posts already stored in a run_context.json.

Workflow:
1. Load run_context.json
2. Extract every post from all search results
3. Evaluate them in batch (skipping posts that are already evaluated)
4. Update run_context.json in place
5. Print an evaluation report
"""
import asyncio
import json
import argparse
from pathlib import Path
from datetime import datetime
from typing import Dict, List

from pydantic import BaseModel, Field

from post_evaluator import evaluate_all_posts, apply_evaluation_to_post, get_relevance_level
# ============================================================================
# Data models (kept consistent with knowledge_search_traverse.py)
# ============================================================================
class Post(BaseModel):
    """A post (must stay consistent with the Post model in the main program)."""
    title: str = ""
    body_text: str = ""
    type: str = "normal"  # "video" or "normal"
    images: list[str] = Field(default_factory=list)
    video: str = ""
    interact_info: dict = Field(default_factory=dict)
    note_id: str = ""
    note_url: str = ""
    # Evaluation fields
    is_knowledge: bool | None = None
    knowledge_reason: str = ""
    relevance_score: float | None = None
    relevance_level: str = ""
    relevance_reason: str = ""
    evaluation_time: str = ""


class Search(BaseModel):
    """A search result."""
    text: str
    score_with_o: float = 0.0
    reason: str = ""
    from_q: dict | None = None
    post_list: list[Post] = Field(default_factory=list)
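# Shape of the run_context.json this script expects, inferred from
# extract_all_posts() and update_posts_in_context() below (any other top-level
# fields written by the main program are left untouched):
#
#   {
#     "o": "<original question>",
#     "rounds": [
#       {
#         "search_results": [
#           {"text": "...", "post_list": [{...Post fields...}, ...]}
#         ]
#       }
#     ]
#   }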
# ============================================================================
# Helper functions
# ============================================================================
def load_run_context(file_path: str) -> dict:
    """
    Load run_context.json.

    Args:
        file_path: path to run_context.json

    Returns:
        The run_context dict.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_run_context(file_path: str, run_context: dict):
    """
    Save run_context.json.

    Args:
        file_path: path to run_context.json
        run_context: the run_context dict
    """
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(run_context, f, ensure_ascii=False, indent=2)
    print(f"\n✅ Saved to: {file_path}")
def extract_all_posts(run_context: dict) -> List[Post]:
    """
    Extract all posts from the run_context.

    Args:
        run_context: the run_context dict

    Returns:
        A list of all posts (Post objects), de-duplicated by note_id.
    """
    all_posts = []
    post_id_set = set()  # used for de-duplication
    # Walk through every round
    for round_data in run_context.get("rounds", []):
        # Fetch this round's search results
        search_results = round_data.get("search_results", [])
        for search_data in search_results:
            for post_data in search_data.get("post_list", []):
                note_id = post_data.get("note_id", "")
                # Skip duplicates
                if note_id and note_id not in post_id_set:
                    post_id_set.add(note_id)
                    post = Post(**post_data)
                    all_posts.append(post)
    return all_posts
def update_posts_in_context(run_context: dict, evaluated_posts: Dict[str, Post]):
    """
    Write the evaluated posts back into the run_context.

    Args:
        run_context: the run_context dict
        evaluated_posts: evaluated posts keyed by note_id, i.e. {note_id: Post}
    """
    updated_count = 0
    # Walk through every round
    for round_data in run_context.get("rounds", []):
        # Fetch this round's search results
        search_results = round_data.get("search_results", [])
        for search_data in search_results:
            for i, post_data in enumerate(search_data.get("post_list", [])):
                note_id = post_data.get("note_id", "")
                if note_id in evaluated_posts:
                    # Replace the stored post with the evaluated one
                    evaluated_post = evaluated_posts[note_id]
                    # .dict() is the Pydantic v1-style API; on Pydantic v2 it still works
                    # (with a deprecation warning), where model_dump() is preferred
                    search_data["post_list"][i] = evaluated_post.dict()
                    updated_count += 1
    print(f"✅ Updated evaluation data for {updated_count} posts")
def generate_evaluation_report(posts: List[Post]):
    """
    Print an evaluation report.

    Args:
        posts: list of posts
    """
    print("\n" + "=" * 60)
    print("📊 Evaluation report")
    print("=" * 60)

    total = len(posts)
    evaluated = sum(1 for p in posts if p.is_knowledge is not None)
    unevaluated = total - evaluated
    print(f"\nTotal posts: {total}")
    print(f"Evaluated: {evaluated}")
    print(f"Not evaluated: {unevaluated}")
    if evaluated == 0:
        print("\n⚠️ No evaluation data")
        return

    # Knowledge-judgement statistics
    knowledge_count = sum(1 for p in posts if p.is_knowledge is True)
    non_knowledge_count = sum(1 for p in posts if p.is_knowledge is False)
    print("\n--- Knowledge judgement ---")
    print(f"Knowledge content: {knowledge_count} ({knowledge_count/evaluated*100:.1f}%)")
    print(f"Non-knowledge content: {non_knowledge_count} ({non_knowledge_count/evaluated*100:.1f}%)")

    # Relevance statistics
    scores = [p.relevance_score for p in posts if p.relevance_score is not None]
    if scores:
        avg_score = sum(scores) / len(scores)
        max_score = max(scores)
        min_score = min(scores)
        # The level labels are the Chinese strings stored during evaluation:
        # "高度相关" = highly relevant, "中度相关" = moderately relevant, "低度相关" = weakly relevant
        high_count = sum(1 for p in posts if p.relevance_level == "高度相关")
        mid_count = sum(1 for p in posts if p.relevance_level == "中度相关")
        low_count = sum(1 for p in posts if p.relevance_level == "低度相关")
        print("\n--- Relevance ---")
        print(f"Average score: {avg_score:.2f}")
        print(f"Highest score: {max_score:.2f}")
        print(f"Lowest score: {min_score:.2f}")
        print(f"\nHighly relevant: {high_count} ({high_count/evaluated*100:.1f}%)")
        print(f"Moderately relevant: {mid_count} ({mid_count/evaluated*100:.1f}%)")
        print(f"Weakly relevant: {low_count} ({low_count/evaluated*100:.1f}%)")

    # High-quality posts (knowledge content + highly relevant)
    high_quality = sum(1 for p in posts
                       if p.is_knowledge is True
                       and p.relevance_level == "高度相关")
    print("\n--- High-quality posts ---")
    print(f"Knowledge + highly relevant: {high_quality} ({high_quality/evaluated*100:.1f}%)")
    print("\n" + "=" * 60)
# ============================================================================
# Main entry point
# ============================================================================
async def main():
    parser = argparse.ArgumentParser(description='Offline evaluation of the posts in an existing run_context.json')
    parser.add_argument(
        '--input',
        type=str,
        required=True,
        help='path to run_context.json'
    )
    parser.add_argument(
        '--query',
        type=str,
        default=None,
        help='original question (required if the JSON does not contain one)'
    )
    parser.add_argument(
        '--force',
        action='store_true',
        help='force re-evaluation of all posts (including already evaluated ones)'
    )
    parser.add_argument(
        '--report-only',
        action='store_true',
        help='only print the evaluation report, do not run new evaluations'
    )
    args = parser.parse_args()

    # Load the run_context
    print(f"📂 Loading file: {args.input}")
    run_context = load_run_context(args.input)

    # Resolve the original question
    original_query = args.query or run_context.get("o", "")
    if not original_query:
        print("❌ Error: cannot determine the original question, please pass it with --query")
        return
    print(f"❓ Original question: {original_query}")

    # Extract all posts
    all_posts = extract_all_posts(run_context)
    print(f"📋 Extracted {len(all_posts)} posts")
    if len(all_posts) == 0:
        print("⚠️ No post data found")
        return

    # Report-only mode
    if args.report_only:
        generate_evaluation_report(all_posts)
        return

    # Select the posts that need evaluation
    if args.force:
        posts_to_evaluate = all_posts
        print("🔄 Forcing re-evaluation of all posts")
    else:
        posts_to_evaluate = [p for p in all_posts if p.is_knowledge is None]
        print(f"🆕 Found {len(posts_to_evaluate)} unevaluated posts")
    if len(posts_to_evaluate) == 0:
        print("✅ All posts are already evaluated; use --force to re-evaluate")
        generate_evaluation_report(all_posts)
        return

    # Batch evaluation
    print("\n🚀 Starting evaluation...")
    evaluation_results = await evaluate_all_posts(posts_to_evaluate, original_query)

    # Apply the evaluation results back onto the Post objects
    evaluated_posts = {}
    for post in posts_to_evaluate:
        if post.note_id in evaluation_results:
            apply_evaluation_to_post(post, evaluation_results[post.note_id])
            evaluated_posts[post.note_id] = post

    # Update the run_context and save it back to the same file
    update_posts_in_context(run_context, evaluated_posts)
    save_run_context(args.input, run_context)

    # Print the report (all_posts now carries the updated evaluations)
    generate_evaluation_report(all_posts)


if __name__ == "__main__":
    asyncio.run(main())
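# Programmatic use (a sketch mirroring main() above; the file name and query are
# placeholders, and evaluate_all_posts/apply_evaluation_to_post come from post_evaluator):
#
#   ctx = load_run_context("run_context.json")
#   posts = extract_all_posts(ctx)
#   results = asyncio.run(evaluate_all_posts(posts, "original question"))
#   for p in posts:
#       if p.note_id in results:
#           apply_evaluation_to_post(p, results[p.note_id])
#   generate_evaluation_report(posts)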