|
|
@@ -0,0 +1,281 @@
|
|
|
+"""
|
|
|
+从 run_context_v3.json 中提取 topN 帖子并进行多模态解析
|
|
|
+
|
|
|
+功能:
|
|
|
+1. 读取 run_context_v3.json
|
|
|
+2. 提取所有帖子,按 final_score 排序,取 topN
|
|
|
+3. 使用 multimodal_extractor 进行图片内容解析
|
|
|
+4. 保存结果到独立的 JSON 文件
|
|
|
+
|
|
|
+参数化配置:
|
|
|
+- top_n: 提取前N个帖子(默认10)
|
|
|
+- max_concurrent: 最大并发数(默认5)
|
|
|
+"""
|
|
|
+
|
|
|
+import argparse
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+import os
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+from typing import Optional
|
|
|
+
|
|
|
+# 导入必要的模块
|
|
|
+from knowledge_search_traverse import Post
|
|
|
+from multimodal_extractor import extract_all_posts
|
|
|
+
|
|
|
+
|
|
|
+def load_run_context(json_path: str) -> dict:
|
|
|
+ """加载 run_context_v3.json 文件"""
|
|
|
+ with open(json_path, 'r', encoding='utf-8') as f:
|
|
|
+ return json.load(f)
|
|
|
+
|
|
|
+
|
|
|
+def extract_all_posts_from_context(context_data: dict) -> list[dict]:
|
|
|
+ """从 context 数据中提取所有帖子(按note_id去重,保留得分最高的)"""
|
|
|
+ # 使用字典进行去重,key为note_id
|
|
|
+ posts_dict = {}
|
|
|
+
|
|
|
+ # 遍历所有轮次
|
|
|
+ for round_data in context_data.get('rounds', []):
|
|
|
+ # 遍历搜索结果
|
|
|
+ for search_result in round_data.get('search_results', []):
|
|
|
+ # 遍历帖子列表
|
|
|
+ for post in search_result.get('post_list', []):
|
|
|
+ note_id = post.get('note_id')
|
|
|
+ if not note_id:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 如果是新帖子,直接添加
|
|
|
+ if note_id not in posts_dict:
|
|
|
+ posts_dict[note_id] = post
|
|
|
+ else:
|
|
|
+ # 如果已存在,比较final_score,保留得分更高的
|
|
|
+ existing_score = posts_dict[note_id].get('final_score')
|
|
|
+ current_score = post.get('final_score')
|
|
|
+
|
|
|
+ # 如果当前帖子的分数更高,或者现有帖子没有分数,则替换
|
|
|
+ if existing_score is None or (current_score is not None and current_score > existing_score):
|
|
|
+ posts_dict[note_id] = post
|
|
|
+
|
|
|
+ # 返回去重后的帖子列表
|
|
|
+ return list(posts_dict.values())
|
|
|
+
|
|
|
+
|
|
|
+def filter_and_sort_topn(posts: list[dict], top_n: int = 10) -> list[dict]:
|
|
|
+ """过滤并排序,获取 final_score topN 的帖子"""
|
|
|
+ # 过滤掉 final_score 为 null 的帖子
|
|
|
+ valid_posts = [p for p in posts if p.get('final_score') is not None]
|
|
|
+
|
|
|
+ # 按 final_score 降序排序
|
|
|
+ sorted_posts = sorted(valid_posts, key=lambda x: x.get('final_score', 0), reverse=True)
|
|
|
+
|
|
|
+ # 取前N个
|
|
|
+ topn = sorted_posts[:top_n]
|
|
|
+
|
|
|
+ return topn
|
|
|
+
|
|
|
+
|
|
|
+def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]:
|
|
|
+ """将字典数据转换为 Post 对象"""
|
|
|
+ post_objects = []
|
|
|
+
|
|
|
+ for post_dict in post_dicts:
|
|
|
+ # 创建 Post 对象,设置默认 type="normal"
|
|
|
+ post = Post(
|
|
|
+ note_id=post_dict.get('note_id', ''),
|
|
|
+ note_url=post_dict.get('note_url', ''),
|
|
|
+ title=post_dict.get('title', ''),
|
|
|
+ body_text=post_dict.get('body_text', ''),
|
|
|
+ type='normal', # 默认值,因为原数据缺少此字段
|
|
|
+ images=post_dict.get('images', []),
|
|
|
+ video=post_dict.get('video', ''),
|
|
|
+ interact_info=post_dict.get('interact_info', {}),
|
|
|
+ )
|
|
|
+ post_objects.append(post)
|
|
|
+
|
|
|
+ return post_objects
|
|
|
+
|
|
|
+
|
|
|
+def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]):
|
|
|
+ """保存多模态解析结果到 JSON 文件"""
|
|
|
+ # 构建输出数据
|
|
|
+ output_data = {
|
|
|
+ 'total_extracted': len(results),
|
|
|
+ 'extraction_results': []
|
|
|
+ }
|
|
|
+
|
|
|
+ # 遍历每个解析结果
|
|
|
+ for note_id, extraction in results.items():
|
|
|
+ # 找到对应的原始帖子数据
|
|
|
+ original_post = None
|
|
|
+ for post in topn_posts:
|
|
|
+ if post.get('note_id') == note_id:
|
|
|
+ original_post = post
|
|
|
+ break
|
|
|
+
|
|
|
+ # 构建结果条目
|
|
|
+ result_entry = {
|
|
|
+ 'note_id': extraction.note_id,
|
|
|
+ 'note_url': extraction.note_url,
|
|
|
+ 'title': extraction.title,
|
|
|
+ 'body_text': extraction.body_text,
|
|
|
+ 'type': extraction.type,
|
|
|
+ 'extraction_time': extraction.extraction_time,
|
|
|
+ 'final_score': original_post.get('final_score') if original_post else None,
|
|
|
+ 'images': [
|
|
|
+ {
|
|
|
+ 'image_index': img.image_index,
|
|
|
+ 'original_url': img.original_url,
|
|
|
+ 'description': img.description,
|
|
|
+ 'extract_text': img.extract_text
|
|
|
+ }
|
|
|
+ for img in extraction.images
|
|
|
+ ]
|
|
|
+ }
|
|
|
+
|
|
|
+ output_data['extraction_results'].append(result_entry)
|
|
|
+
|
|
|
+ # 保存到文件
|
|
|
+ with open(output_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(output_data, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ print(f"\n✅ 结果已保存到: {output_path}")
|
|
|
+
|
|
|
+
|
|
|
+async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
|
|
|
+ max_concurrent: int = 5):
|
|
|
+ """主函数
|
|
|
+
|
|
|
+ Args:
|
|
|
+ context_file_path: run_context_v3.json 文件路径
|
|
|
+ output_file_path: 输出文件路径
|
|
|
+ top_n: 提取前N个帖子(默认10)
|
|
|
+ max_concurrent: 最大并发数(默认5)
|
|
|
+ """
|
|
|
+ print("=" * 80)
|
|
|
+ print(f"多模态解析 - Top{top_n} 帖子")
|
|
|
+ print("=" * 80)
|
|
|
+
|
|
|
+ # 1. 加载数据
|
|
|
+ print(f"\n📂 加载文件: {context_file_path}")
|
|
|
+ context_data = load_run_context(context_file_path)
|
|
|
+
|
|
|
+ # 2. 提取所有帖子
|
|
|
+ print(f"\n🔍 提取所有帖子...")
|
|
|
+ all_posts = extract_all_posts_from_context(context_data)
|
|
|
+ print(f" 去重后共找到 {len(all_posts)} 个唯一帖子")
|
|
|
+
|
|
|
+ # 3. 过滤并排序获取 topN
|
|
|
+ print(f"\n📊 筛选 top{top_n} 帖子...")
|
|
|
+ topn_posts = filter_and_sort_topn(all_posts, top_n)
|
|
|
+
|
|
|
+ if len(topn_posts) == 0:
|
|
|
+ print(" ⚠️ 没有找到有效的帖子")
|
|
|
+ return
|
|
|
+
|
|
|
+ print(f" Top{top_n} 帖子得分范围: {topn_posts[-1].get('final_score')} ~ {topn_posts[0].get('final_score')}")
|
|
|
+
|
|
|
+ # 打印 topN 列表
|
|
|
+ print(f"\n Top{top_n} 帖子列表:")
|
|
|
+ for i, post in enumerate(topn_posts, 1):
|
|
|
+ print(f" {i}. [{post.get('final_score')}] {post.get('title')[:40]}... ({post.get('note_id')})")
|
|
|
+
|
|
|
+ # 4. 转换为 Post 对象
|
|
|
+ print(f"\n🔄 转换为 Post 对象...")
|
|
|
+ post_objects = convert_to_post_objects(topn_posts)
|
|
|
+ print(f" 成功转换 {len(post_objects)} 个 Post 对象")
|
|
|
+
|
|
|
+ # 5. 进行多模态解析
|
|
|
+ print(f"\n🖼️ 开始多模态图片内容解析...")
|
|
|
+ print(f" (并发限制: {max_concurrent})")
|
|
|
+ extraction_results = await extract_all_posts(
|
|
|
+ post_objects,
|
|
|
+ max_concurrent=max_concurrent
|
|
|
+ )
|
|
|
+
|
|
|
+ # 6. 保存结果
|
|
|
+ print(f"\n💾 保存解析结果...")
|
|
|
+ save_extraction_results(extraction_results, output_file_path, topn_posts)
|
|
|
+
|
|
|
+ print("\n" + "=" * 80)
|
|
|
+ print("✅ 处理完成!")
|
|
|
+ print("=" * 80)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ # 创建命令行参数解析器
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
+ description='从 run_context_v3.json 中提取 topN 帖子并进行多模态解析',
|
|
|
+ formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
|
+ epilog='''
|
|
|
+示例用法:
|
|
|
+ # 使用默认参数 (top10, 并发5)
|
|
|
+ python3 extract_topn_multimodal.py
|
|
|
+
|
|
|
+ # 提取前20个帖子
|
|
|
+ python3 extract_topn_multimodal.py --top-n 20
|
|
|
+
|
|
|
+ # 自定义并发数
|
|
|
+ python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10
|
|
|
+
|
|
|
+ # 指定输入输出文件
|
|
|
+ python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30
|
|
|
+ '''
|
|
|
+ )
|
|
|
+
|
|
|
+ # 默认路径配置
|
|
|
+ DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/run_context_v3.json"
|
|
|
+ DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/multimodal_extraction_topn.json"
|
|
|
+
|
|
|
+ # 添加参数
|
|
|
+ parser.add_argument(
|
|
|
+ '-i', '--input',
|
|
|
+ dest='context_file',
|
|
|
+ default=DEFAULT_CONTEXT_FILE,
|
|
|
+ help=f'输入的 run_context_v3.json 文件路径 (默认: {DEFAULT_CONTEXT_FILE})'
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '-o', '--output',
|
|
|
+ dest='output_file',
|
|
|
+ default=DEFAULT_OUTPUT_FILE,
|
|
|
+ help=f'输出的 JSON 文件路径 (默认: {DEFAULT_OUTPUT_FILE})'
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '-n', '--top-n',
|
|
|
+ dest='top_n',
|
|
|
+ type=int,
|
|
|
+ default=10,
|
|
|
+ help='提取前N个帖子 (默认: 10)'
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '-c', '--max-concurrent',
|
|
|
+ dest='max_concurrent',
|
|
|
+ type=int,
|
|
|
+ default=5,
|
|
|
+ help='最大并发数 (默认: 5)'
|
|
|
+ )
|
|
|
+
|
|
|
+ # 解析参数
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ # 检查文件是否存在
|
|
|
+ if not os.path.exists(args.context_file):
|
|
|
+ print(f"❌ 错误: 文件不存在 - {args.context_file}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ # 打印参数配置
|
|
|
+ print(f"\n📋 参数配置:")
|
|
|
+ print(f" 输入文件: {args.context_file}")
|
|
|
+ print(f" 输出文件: {args.output_file}")
|
|
|
+ print(f" 提取数量: Top{args.top_n}")
|
|
|
+ print(f" 最大并发: {args.max_concurrent}")
|
|
|
+ print()
|
|
|
+
|
|
|
+ # 运行主函数
|
|
|
+ asyncio.run(main(
|
|
|
+ args.context_file,
|
|
|
+ args.output_file,
|
|
|
+ args.top_n,
|
|
|
+ args.max_concurrent
|
|
|
+ ))
|