test_evaluation_v2.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. """
  2. 测试评估V2模块
  3. 从现有run_context.json读取帖子,使用V2评估模块重新评估,生成统计报告
  4. """
  5. import asyncio
  6. import json
  7. import sys
  8. from pathlib import Path
  9. from datetime import datetime
  10. from collections import defaultdict
  11. # 导入必要的模块
  12. from knowledge_search_traverse import Post
  13. from post_evaluator_v2 import evaluate_post_v2, apply_evaluation_v2_to_post
  14. async def test_evaluation_v2(run_context_path: str, max_posts: int = 10):
  15. """
  16. 测试V2评估模块
  17. Args:
  18. run_context_path: run_context.json路径
  19. max_posts: 最多评估的帖子数量(用于快速测试)
  20. """
  21. print(f"\n{'='*80}")
  22. print(f"📊 评估V2测试 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
  23. print(f"{'='*80}\n")
  24. # 读取run_context.json
  25. print(f"📂 读取: {run_context_path}")
  26. with open(run_context_path, 'r', encoding='utf-8') as f:
  27. run_context = json.load(f)
  28. # 提取原始query
  29. original_query = run_context.get('o', '')
  30. print(f"🔍 原始Query: {original_query}\n")
  31. # 提取所有帖子 (从rounds -> search_results -> post_list)
  32. post_data_list = []
  33. rounds = run_context.get('rounds', [])
  34. for round_idx, round_data in enumerate(rounds):
  35. search_results = round_data.get('search_results', [])
  36. for search_idx, search in enumerate(search_results):
  37. post_list = search.get('post_list', [])
  38. for post_idx, post_data in enumerate(post_list):
  39. # 生成唯一ID
  40. post_id = f"r{round_idx}_s{search_idx}_p{post_idx}"
  41. post_data_list.append((round_idx, search_idx, post_id, post_data))
  42. total_posts = len(post_data_list)
  43. print(f"📝 找到 {total_posts} 个帖子 (来自 {len(rounds)} 轮)")
  44. # 限制评估数量(快速测试)
  45. if max_posts and max_posts < total_posts:
  46. post_data_list = post_data_list[:max_posts]
  47. print(f"⚡ 快速测试模式: 仅评估前 {max_posts} 个帖子\n")
  48. else:
  49. print()
  50. # 将post_data转换为Post对象
  51. posts = []
  52. for round_idx, search_idx, post_id, post_data in post_data_list:
  53. post = Post(
  54. note_id=post_data.get('note_id', post_id),
  55. title=post_data.get('title', ''),
  56. body_text=post_data.get('body_text', ''),
  57. images=post_data.get('images', []),
  58. type=post_data.get('type', 'normal')
  59. )
  60. posts.append((round_idx, search_idx, post_id, post))
  61. # 批量评估
  62. print(f"🚀 开始批量评估 (并发数: 5)...\n")
  63. semaphore = asyncio.Semaphore(5)
  64. tasks = []
  65. for round_idx, search_idx, post_id, post in posts:
  66. task = evaluate_post_v2(post, original_query, semaphore)
  67. tasks.append((round_idx, search_idx, post_id, post, task))
  68. results = []
  69. for i, (round_idx, search_idx, post_id, post, task) in enumerate(tasks, 1):
  70. print(f" [{i}/{len(tasks)}] 评估: {post.note_id}")
  71. knowledge_eval, relevance_eval = await task
  72. if knowledge_eval and relevance_eval:
  73. apply_evaluation_v2_to_post(post, knowledge_eval, relevance_eval)
  74. results.append((round_idx, search_idx, post_id, post, knowledge_eval, relevance_eval))
  75. print(f" ✅ 知识:{post.knowledge_score:.0f}分({post.knowledge_level}⭐) | 相关:{post.relevance_score:.0f}分({post.relevance_conclusion})")
  76. else:
  77. print(f" ❌ 评估失败")
  78. print(f"\n✅ 评估完成: {len(results)}/{len(posts)} 成功\n")
  79. # 更新run_context.json中的帖子数据
  80. print("💾 更新 run_context.json...")
  81. for round_idx, search_idx, post_id, post, knowledge_eval, relevance_eval in results:
  82. # 定位到对应的post_list
  83. if round_idx < len(rounds):
  84. search_results = rounds[round_idx].get('search_results', [])
  85. if search_idx < len(search_results):
  86. post_list = search_results[search_idx].get('post_list', [])
  87. # 找到对应的帖子并更新
  88. for p in post_list:
  89. if p.get('note_id') == post.note_id:
  90. # 更新顶层字段
  91. p['is_knowledge'] = post.is_knowledge
  92. p['knowledge_reason'] = post.knowledge_reason
  93. p['knowledge_score'] = post.knowledge_score
  94. p['knowledge_level'] = post.knowledge_level
  95. p['relevance_score'] = post.relevance_score
  96. p['relevance_level'] = post.relevance_level
  97. p['relevance_reason'] = post.relevance_reason
  98. p['relevance_conclusion'] = post.relevance_conclusion
  99. p['evaluation_time'] = post.evaluation_time
  100. p['evaluator_version'] = post.evaluator_version
  101. # 更新嵌套字段
  102. p['knowledge_evaluation'] = post.knowledge_evaluation
  103. p['relevance_evaluation'] = post.relevance_evaluation
  104. break
  105. # 保存更新后的run_context.json
  106. output_path = run_context_path.replace('.json', '_v2.json')
  107. with open(output_path, 'w', encoding='utf-8') as f:
  108. json.dump(run_context, f, ensure_ascii=False, indent=2)
  109. print(f"✅ 已保存: {output_path}\n")
  110. # 生成统计报告
  111. print(f"\n{'='*80}")
  112. print("📊 统计报告")
  113. print(f"{'='*80}\n")
  114. # 知识评估统计
  115. knowledge_counts = defaultdict(int)
  116. knowledge_level_counts = defaultdict(int)
  117. knowledge_scores = []
  118. for _, _, _, post, _, _ in results:
  119. if post.is_knowledge:
  120. knowledge_counts['知识内容'] += 1
  121. else:
  122. knowledge_counts['非知识内容'] += 1
  123. if post.knowledge_level:
  124. knowledge_level_counts[post.knowledge_level] += 1
  125. if post.knowledge_score is not None:
  126. knowledge_scores.append(post.knowledge_score)
  127. total = len(results)
  128. print("📚 知识评估:")
  129. print(f" 知识内容: {knowledge_counts['知识内容']:3d} / {total} ({knowledge_counts['知识内容']/total*100:.1f}%)")
  130. print(f" 非知识内容: {knowledge_counts['非知识内容']:3d} / {total} ({knowledge_counts['非知识内容']/total*100:.1f}%)")
  131. print()
  132. if knowledge_scores:
  133. avg_score = sum(knowledge_scores) / len(knowledge_scores)
  134. print(f" 平均得分: {avg_score:.1f}分")
  135. print(f" 最高得分: {max(knowledge_scores):.0f}分")
  136. print(f" 最低得分: {min(knowledge_scores):.0f}分")
  137. print()
  138. print(" 星级分布:")
  139. for level in range(1, 6):
  140. count = knowledge_level_counts.get(level, 0)
  141. bar = '★' * count
  142. print(f" {level}星: {count:3d} {bar}")
  143. print()
  144. # 相关性评估统计
  145. relevance_conclusion_counts = defaultdict(int)
  146. relevance_scores = []
  147. purpose_scores = []
  148. category_scores = []
  149. for _, _, _, post, _, _ in results:
  150. if post.relevance_conclusion:
  151. relevance_conclusion_counts[post.relevance_conclusion] += 1
  152. if post.relevance_score is not None:
  153. relevance_scores.append(post.relevance_score)
  154. if post.relevance_evaluation:
  155. if 'purpose_score' in post.relevance_evaluation:
  156. purpose_scores.append(post.relevance_evaluation['purpose_score'])
  157. if 'category_score' in post.relevance_evaluation:
  158. category_scores.append(post.relevance_evaluation['category_score'])
  159. print("🎯 相关性评估:")
  160. for conclusion in ['高度匹配', '中度匹配', '低度匹配', '不匹配']:
  161. count = relevance_conclusion_counts.get(conclusion, 0)
  162. if count > 0:
  163. print(f" {conclusion}: {count:3d} / {total} ({count/total*100:.1f}%)")
  164. print()
  165. if relevance_scores:
  166. avg_score = sum(relevance_scores) / len(relevance_scores)
  167. high_relevance = sum(1 for s in relevance_scores if s >= 70)
  168. print(f" 平均得分: {avg_score:.1f}分")
  169. print(f" 高相关性: {high_relevance} / {total} ({high_relevance/total*100:.1f}%) [≥70分]")
  170. print(f" 最高得分: {max(relevance_scores):.0f}分")
  171. print(f" 最低得分: {min(relevance_scores):.0f}分")
  172. print()
  173. if purpose_scores and category_scores:
  174. avg_purpose = sum(purpose_scores) / len(purpose_scores)
  175. avg_category = sum(category_scores) / len(category_scores)
  176. print(f" 目的性平均: {avg_purpose:.1f}分 (权重70%)")
  177. print(f" 品类平均: {avg_category:.1f}分 (权重30%)")
  178. print()
  179. # 综合分析
  180. print("🔥 高质量内容 (知识内容 + 高相关性):")
  181. high_quality = sum(
  182. 1 for _, _, _, post, _, _ in results
  183. if post.is_knowledge and post.relevance_score and post.relevance_score >= 70
  184. )
  185. print(f" {high_quality} / {total} ({high_quality/total*100:.1f}%)")
  186. print()
  187. print(f"{'='*80}\n")
  188. return results
  189. if __name__ == "__main__":
  190. if len(sys.argv) < 2:
  191. print("用法: python3 test_evaluation_v2.py <run_context.json路径> [最大评估数量]")
  192. print()
  193. print("示例:")
  194. print(" python3 test_evaluation_v2.py input/test_case/output/knowledge_search_traverse/20251112/173512_dc/run_context.json")
  195. print(" python3 test_evaluation_v2.py input/test_case/output/knowledge_search_traverse/20251112/173512_dc/run_context.json 20")
  196. sys.exit(1)
  197. run_context_path = sys.argv[1]
  198. max_posts = int(sys.argv[2]) if len(sys.argv) > 2 else None
  199. asyncio.run(test_evaluation_v2(run_context_path, max_posts))