#!/usr/bin/env python3
"""
Offline evaluation script: evaluate the posts stored in an existing run_context.json.

Workflow:
1. Load run_context.json
2. Extract all posts from every search result
3. Evaluate them in batch (skipping posts that were already evaluated)
4. Write the updated posts back into run_context.json
5. Print an evaluation report
"""

import asyncio
import json
import argparse
from pathlib import Path
from datetime import datetime
from typing import Dict, List

from pydantic import BaseModel, Field

from post_evaluator import evaluate_all_posts, apply_evaluation_to_post, get_relevance_level
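
# Assumed post_evaluator interface, inferred only from the call sites in this
# script (not verified against post_evaluator itself):
#   - evaluate_all_posts(posts, original_query): async, returns evaluation
#     results keyed by note_id
#   - apply_evaluation_to_post(post, result): fills the evaluation fields on a
#     Post in place
#   - get_relevance_level: imported above but not called directly in this script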

# ============================================================================
# Data models (kept consistent with knowledge_search_traverse.py)
# ============================================================================

class Post(BaseModel):
    """Post (must stay consistent with the Post model in the main program)."""
    title: str = ""
    body_text: str = ""
    type: str = "normal"  # video / normal
    images: list[str] = Field(default_factory=list)
    video: str = ""
    interact_info: dict = Field(default_factory=dict)
    note_id: str = ""
    note_url: str = ""
    # Evaluation fields
    is_knowledge: bool | None = None
    knowledge_reason: str = ""
    relevance_score: float | None = None
    relevance_level: str = ""
    relevance_reason: str = ""
    evaluation_time: str = ""


class Search(BaseModel):
    """Search result."""
    text: str
    score_with_o: float = 0.0
    reason: str = ""
    from_q: dict | None = None
    post_list: list[Post] = Field(default_factory=list)
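
# Shape of run_context.json assumed by this script, inferred from the accessors
# below (a minimal sketch; any additional fields are simply ignored):
#
# {
#   "o": "<original question>",
#   "rounds": [
#     {
#       "search_results": [
#         {
#           "text": "...", "score_with_o": 0.0, "reason": "", "from_q": null,
#           "post_list": [ { <Post fields> }, ... ]
#         }
#       ]
#     }
#   ]
# }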

# ============================================================================
# Helper functions
# ============================================================================

def load_run_context(file_path: str) -> dict:
    """
    Load run_context.json.

    Args:
        file_path: path to the run_context.json file

    Returns:
        The run_context dict.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_run_context(file_path: str, run_context: dict):
    """
    Save run_context.json.

    Args:
        file_path: path to the run_context.json file
        run_context: the run_context dict
    """
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(run_context, f, ensure_ascii=False, indent=2)
    print(f"\n✅ Saved to: {file_path}")


def extract_all_posts(run_context: dict) -> List[Post]:
    """
    Extract all posts from the run_context.

    Args:
        run_context: the run_context dict

    Returns:
        A deduplicated list of Post objects.
    """
    all_posts = []
    post_id_set = set()  # used to deduplicate by note_id

    # Walk every round
    for round_data in run_context.get("rounds", []):
        # Walk every search result in the round
        search_results = round_data.get("search_results", [])
        for search_data in search_results:
            for post_data in search_data.get("post_list", []):
                note_id = post_data.get("note_id", "")
                # Keep only the first occurrence of each note_id
                if note_id and note_id not in post_id_set:
                    post_id_set.add(note_id)
                    post = Post(**post_data)
                    all_posts.append(post)

    return all_posts


def update_posts_in_context(run_context: dict, evaluated_posts: Dict[str, Post]):
    """
    Write evaluated posts back into the run_context.

    Args:
        run_context: the run_context dict
        evaluated_posts: evaluated posts keyed by note_id ({note_id: Post})
    """
    updated_count = 0

    for round_data in run_context.get("rounds", []):
        search_results = round_data.get("search_results", [])
        for search_data in search_results:
            for i, post_data in enumerate(search_data.get("post_list", [])):
                note_id = post_data.get("note_id", "")
                if note_id in evaluated_posts:
                    # Replace the stored dict with the evaluated Post's data.
                    # Note: .dict() is the pydantic v1 API; on pydantic v2 use model_dump().
                    evaluated_post = evaluated_posts[note_id]
                    search_data["post_list"][i] = evaluated_post.dict()
                    updated_count += 1

    print(f"✅ Updated evaluation data for {updated_count} posts")


def generate_evaluation_report(posts: List[Post]):
    """
    Print an evaluation report.

    Args:
        posts: list of posts
    """
    print("\n" + "=" * 60)
    print("📊 Evaluation report")
    print("=" * 60)

    total = len(posts)
    evaluated = sum(1 for p in posts if p.is_knowledge is not None)
    unevaluated = total - evaluated

    print(f"\nTotal posts: {total}")
    print(f"Evaluated: {evaluated}")
    print(f"Not evaluated: {unevaluated}")

    if evaluated == 0:
        print("\n⚠️ No evaluation data")
        return

    # Knowledge classification
    knowledge_count = sum(1 for p in posts if p.is_knowledge is True)
    non_knowledge_count = sum(1 for p in posts if p.is_knowledge is False)

    print("\n--- Knowledge classification ---")
    print(f"Knowledge content: {knowledge_count} ({knowledge_count/evaluated*100:.1f}%)")
    print(f"Non-knowledge content: {non_knowledge_count} ({non_knowledge_count/evaluated*100:.1f}%)")

    # Relevance statistics
    scores = [p.relevance_score for p in posts if p.relevance_score is not None]
    if scores:
        avg_score = sum(scores) / len(scores)
        max_score = max(scores)
        min_score = min(scores)

        # These Chinese labels must match the values stored in relevance_level
        # (presumably written by apply_evaluation_to_post): "高度相关" = highly
        # relevant, "中度相关" = moderately relevant, "低度相关" = weakly relevant.
        high_count = sum(1 for p in posts if p.relevance_level == "高度相关")
        mid_count = sum(1 for p in posts if p.relevance_level == "中度相关")
        low_count = sum(1 for p in posts if p.relevance_level == "低度相关")

        print("\n--- Relevance ---")
        print(f"Average score: {avg_score:.2f}")
        print(f"Highest score: {max_score:.2f}")
        print(f"Lowest score: {min_score:.2f}")

        print(f"\nHighly relevant: {high_count} ({high_count/evaluated*100:.1f}%)")
        print(f"Moderately relevant: {mid_count} ({mid_count/evaluated*100:.1f}%)")
        print(f"Weakly relevant: {low_count} ({low_count/evaluated*100:.1f}%)")

    # High-quality posts (knowledge + highly relevant)
    high_quality = sum(1 for p in posts
                       if p.is_knowledge is True
                       and p.relevance_level == "高度相关")
    print("\n--- High-quality posts ---")
    print(f"Knowledge + highly relevant: {high_quality} ({high_quality/evaluated*100:.1f}%)")
    print("\n" + "=" * 60)


# ============================================================================
# Main entry point
# ============================================================================

async def main():
    parser = argparse.ArgumentParser(
        description='Offline evaluation of the posts stored in an existing run_context.json'
    )
    parser.add_argument(
        '--input',
        type=str,
        required=True,
        help='Path to the run_context.json file'
    )
    parser.add_argument(
        '--query',
        type=str,
        default=None,
        help='Original question (required if the JSON does not contain one)'
    )
    parser.add_argument(
        '--force',
        action='store_true',
        help='Force re-evaluation of all posts, including already evaluated ones'
    )
    parser.add_argument(
        '--report-only',
        action='store_true',
        help='Only print the evaluation report; do not run new evaluations'
    )
    args = parser.parse_args()

    # Load run_context
    print(f"📂 Loading file: {args.input}")
    run_context = load_run_context(args.input)

    # Resolve the original question ("o" is the run_context field holding it)
    original_query = args.query or run_context.get("o", "")
    if not original_query:
        print("❌ Error: no original question available; please pass one with --query")
        return
    print(f"❓ Original question: {original_query}")

    # Extract all posts
    all_posts = extract_all_posts(run_context)
    print(f"📋 Extracted {len(all_posts)} posts")

    if len(all_posts) == 0:
        print("⚠️ No post data found")
        return

    # Report-only mode
    if args.report_only:
        generate_evaluation_report(all_posts)
        return

    # Select the posts that still need evaluation
    if args.force:
        posts_to_evaluate = all_posts
        print("🔄 Forcing re-evaluation of all posts")
    else:
        posts_to_evaluate = [p for p in all_posts if p.is_knowledge is None]
        print(f"🆕 Found {len(posts_to_evaluate)} unevaluated posts")

    if len(posts_to_evaluate) == 0:
        print("✅ All posts are already evaluated; use --force to re-evaluate")
        generate_evaluation_report(all_posts)
        return

    # Batch evaluation
    print("\n🚀 Starting evaluation...")
    evaluation_results = await evaluate_all_posts(posts_to_evaluate, original_query)

    # Apply the evaluation results to the Post objects
    evaluated_posts = {}
    for post in posts_to_evaluate:
        if post.note_id in evaluation_results:
            apply_evaluation_to_post(post, evaluation_results[post.note_id])
            evaluated_posts[post.note_id] = post

    # Write the results back into run_context and save
    update_posts_in_context(run_context, evaluated_posts)
    save_run_context(args.input, run_context)

    # Report (all_posts holds the same Post objects, so it reflects the updates)
    generate_evaluation_report(all_posts)


if __name__ == "__main__":
    asyncio.run(main())