""" 从 run_context_v3.json 中提取 topN 帖子并进行多模态解析 功能: 1. 读取 run_context_v3.json 2. 提取所有帖子,按 final_score 排序,取 topN 3. 使用 multimodal_extractor 进行图片内容解析 4. 保存结果到独立的 JSON 文件 参数化配置: - top_n: 提取前N个帖子(默认10) - max_concurrent: 最大并发数(默认5) """ import argparse import asyncio import json import os import sys from pathlib import Path from typing import Optional # 导入必要的模块 from knowledge_search_traverse import Post from multimodal_extractor import extract_all_posts def load_run_context(json_path: str) -> dict: """加载 run_context_v3.json 文件""" with open(json_path, 'r', encoding='utf-8') as f: return json.load(f) def extract_all_posts_from_context(context_data: dict) -> list[dict]: """从 context 数据中提取所有帖子(按note_id去重,保留得分最高的)""" # 使用字典进行去重,key为note_id posts_dict = {} # 遍历所有轮次 for round_data in context_data.get('rounds', []): # 遍历搜索结果 for search_result in round_data.get('search_results', []): # 遍历帖子列表 for post in search_result.get('post_list', []): note_id = post.get('note_id') if not note_id: continue # 如果是新帖子,直接添加 if note_id not in posts_dict: posts_dict[note_id] = post else: # 如果已存在,比较final_score,保留得分更高的 existing_score = posts_dict[note_id].get('final_score') current_score = post.get('final_score') # 如果当前帖子的分数更高,或者现有帖子没有分数,则替换 if existing_score is None or (current_score is not None and current_score > existing_score): posts_dict[note_id] = post # 返回去重后的帖子列表 return list(posts_dict.values()) def filter_and_sort_topn(posts: list[dict], top_n: int = 10) -> list[dict]: """过滤并排序,获取 final_score topN 的帖子""" # 过滤掉 final_score 为 null 的帖子 valid_posts = [p for p in posts if p.get('final_score') is not None] # 按 final_score 降序排序 sorted_posts = sorted(valid_posts, key=lambda x: x.get('final_score', 0), reverse=True) # 取前N个 topn = sorted_posts[:top_n] return topn def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]: """将字典数据转换为 Post 对象""" post_objects = [] for post_dict in post_dicts: # 创建 Post 对象,设置默认 type="normal" post = Post( note_id=post_dict.get('note_id', ''), note_url=post_dict.get('note_url', ''), title=post_dict.get('title', ''), body_text=post_dict.get('body_text', ''), type='normal', # 默认值,因为原数据缺少此字段 images=post_dict.get('images', []), video=post_dict.get('video', ''), interact_info=post_dict.get('interact_info', {}), ) post_objects.append(post) return post_objects def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]): """保存多模态解析结果到 JSON 文件""" # 构建输出数据 output_data = { 'total_extracted': len(results), 'extraction_results': [] } # 遍历每个解析结果 for note_id, extraction in results.items(): # 找到对应的原始帖子数据 original_post = None for post in topn_posts: if post.get('note_id') == note_id: original_post = post break # 构建结果条目 result_entry = { 'note_id': extraction.note_id, 'note_url': extraction.note_url, 'title': extraction.title, 'body_text': extraction.body_text, 'type': extraction.type, 'extraction_time': extraction.extraction_time, 'final_score': original_post.get('final_score') if original_post else None, 'images': [ { 'image_index': img.image_index, 'original_url': img.original_url, 'description': img.description, 'extract_text': img.extract_text } for img in extraction.images ] } output_data['extraction_results'].append(result_entry) # 保存到文件 with open(output_path, 'w', encoding='utf-8') as f: json.dump(output_data, f, ensure_ascii=False, indent=2) print(f"\n✅ 结果已保存到: {output_path}") async def main(context_file_path: str, output_file_path: str, top_n: int = 10, max_concurrent: int = 5): """主函数 Args: context_file_path: run_context_v3.json 文件路径 output_file_path: 输出文件路径 top_n: 提取前N个帖子(默认10) max_concurrent: 最大并发数(默认5) """ print("=" * 80) print(f"多模态解析 - Top{top_n} 帖子") print("=" * 80) # 1. 加载数据 print(f"\n📂 加载文件: {context_file_path}") context_data = load_run_context(context_file_path) # 2. 提取所有帖子 print(f"\n🔍 提取所有帖子...") all_posts = extract_all_posts_from_context(context_data) print(f" 去重后共找到 {len(all_posts)} 个唯一帖子") # 3. 过滤并排序获取 topN print(f"\n📊 筛选 top{top_n} 帖子...") topn_posts = filter_and_sort_topn(all_posts, top_n) if len(topn_posts) == 0: print(" ⚠️ 没有找到有效的帖子") return print(f" Top{top_n} 帖子得分范围: {topn_posts[-1].get('final_score')} ~ {topn_posts[0].get('final_score')}") # 打印 topN 列表 print(f"\n Top{top_n} 帖子列表:") for i, post in enumerate(topn_posts, 1): print(f" {i}. [{post.get('final_score')}] {post.get('title')[:40]}... ({post.get('note_id')})") # 4. 转换为 Post 对象 print(f"\n🔄 转换为 Post 对象...") post_objects = convert_to_post_objects(topn_posts) print(f" 成功转换 {len(post_objects)} 个 Post 对象") # 5. 进行多模态解析 print(f"\n🖼️ 开始多模态图片内容解析...") print(f" (并发限制: {max_concurrent})") extraction_results = await extract_all_posts( post_objects, max_concurrent=max_concurrent ) # 6. 保存结果 print(f"\n💾 保存解析结果...") save_extraction_results(extraction_results, output_file_path, topn_posts) print("\n" + "=" * 80) print("✅ 处理完成!") print("=" * 80) if __name__ == "__main__": # 创建命令行参数解析器 parser = argparse.ArgumentParser( description='从 run_context_v3.json 中提取 topN 帖子并进行多模态解析', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=''' 示例用法: # 使用默认参数 (top10, 并发5) python3 extract_topn_multimodal.py # 提取前20个帖子 python3 extract_topn_multimodal.py --top-n 20 # 自定义并发数 python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10 # 指定输入输出文件 python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30 ''' ) # 默认路径配置 DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/run_context_v3.json" DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/multimodal_extraction_topn.json" # 添加参数 parser.add_argument( '-i', '--input', dest='context_file', default=DEFAULT_CONTEXT_FILE, help=f'输入的 run_context_v3.json 文件路径 (默认: {DEFAULT_CONTEXT_FILE})' ) parser.add_argument( '-o', '--output', dest='output_file', default=DEFAULT_OUTPUT_FILE, help=f'输出的 JSON 文件路径 (默认: {DEFAULT_OUTPUT_FILE})' ) parser.add_argument( '-n', '--top-n', dest='top_n', type=int, default=10, help='提取前N个帖子 (默认: 10)' ) parser.add_argument( '-c', '--max-concurrent', dest='max_concurrent', type=int, default=5, help='最大并发数 (默认: 5)' ) # 解析参数 args = parser.parse_args() # 检查文件是否存在 if not os.path.exists(args.context_file): print(f"❌ 错误: 文件不存在 - {args.context_file}") sys.exit(1) # 打印参数配置 print(f"\n📋 参数配置:") print(f" 输入文件: {args.context_file}") print(f" 输出文件: {args.output_file}") print(f" 提取数量: Top{args.top_n}") print(f" 最大并发: {args.max_concurrent}") print() # 运行主函数 asyncio.run(main( args.context_file, args.output_file, args.top_n, args.max_concurrent ))