extract_top10_multimodal.py

  1. """
  2. 从 run_context_v3.json 中提取 top10 帖子并进行多模态解析
  3. 功能:
  4. 1. 读取 run_context_v3.json
  5. 2. 提取所有帖子,按 final_score 排序,取 top10
  6. 3. 使用 multimodal_extractor 进行图片内容解析
  7. 4. 保存结果到独立的 JSON 文件
  8. """
import asyncio
import json
import os
import sys

# Local modules providing the Post model and the multimodal image extractor
from knowledge_search_traverse import Post
from multimodal_extractor import extract_all_posts
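
# Expected shape of run_context_v3.json, as implied by the traversal in
# extract_all_posts_from_context() below. This is a sketch, not an authoritative schema;
# values are placeholders and fields this script does not read are omitted.
#
# {
#   "rounds": [
#     {
#       "search_results": [
#         {
#           "post_list": [
#             {
#               "note_id": "...", "note_url": "...", "title": "...", "body_text": "...",
#               "images": [], "video": "", "interact_info": {},
#               "final_score": null    <- posts with a null score are dropped
#             }
#           ]
#         }
#       ]
#     }
#   ]
# }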


def load_run_context(json_path: str) -> dict:
    """Load the run_context_v3.json file."""
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def extract_all_posts_from_context(context_data: dict) -> list[dict]:
    """Collect every post from the context data."""
    all_posts = []
    # Walk every round
    for round_data in context_data.get('rounds', []):
        # Walk every search result in the round
        for search_result in round_data.get('search_results', []):
            # Collect every post in the result's post list
            for post in search_result.get('post_list', []):
                all_posts.append(post)
    return all_posts


def filter_and_sort_top10(posts: list[dict]) -> list[dict]:
    """Filter and sort posts, returning the top 10 by final_score."""
    # Drop posts whose final_score is null
    valid_posts = [p for p in posts if p.get('final_score') is not None]
    # Sort by final_score in descending order
    sorted_posts = sorted(valid_posts, key=lambda x: x.get('final_score', 0), reverse=True)
    # Keep the first 10
    return sorted_posts[:10]
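
# Quick sanity check of the selection logic above (hypothetical scores):
#   >>> filter_and_sort_top10([{'final_score': 0.2}, {'final_score': None}, {'final_score': 0.9}])
#   [{'final_score': 0.9}, {'final_score': 0.2}]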


def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]:
    """Convert raw post dicts into Post objects."""
    post_objects = []
    for post_dict in post_dicts:
        # Build the Post object; type defaults to "normal" because the source data lacks this field
        post = Post(
            note_id=post_dict.get('note_id', ''),
            note_url=post_dict.get('note_url', ''),
            title=post_dict.get('title', ''),
            body_text=post_dict.get('body_text', ''),
            type='normal',
            images=post_dict.get('images', []),
            video=post_dict.get('video', ''),
            interact_info=post_dict.get('interact_info', {}),
        )
        post_objects.append(post)
    return post_objects


def save_extraction_results(results: dict, output_path: str, top10_posts: list[dict]):
    """Save the multimodal extraction results to a JSON file."""
    # Build the output payload
    output_data = {
        'total_extracted': len(results),
        'extraction_results': []
    }
    # Walk every extraction result
    for note_id, extraction in results.items():
        # Find the matching original post data
        original_post = None
        for post in top10_posts:
            if post.get('note_id') == note_id:
                original_post = post
                break
        # Build the result entry
        result_entry = {
            'note_id': extraction.note_id,
            'note_url': extraction.note_url,
            'title': extraction.title,
            'body_text': extraction.body_text,
            'type': extraction.type,
            'extraction_time': extraction.extraction_time,
            'final_score': original_post.get('final_score') if original_post else None,
            'images': [
                {
                    'image_index': img.image_index,
                    'original_url': img.original_url,
                    'description': img.description,
                    'extract_text': img.extract_text
                }
                for img in extraction.images
            ]
        }
        output_data['extraction_results'].append(result_entry)
    # Write to file
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"\n✅ Results saved to: {output_path}")
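
# Illustrative shape of the output file (a sketch based on the fields assembled above;
# the per-image fields come from the extraction objects returned by extract_all_posts):
#
# {
#   "total_extracted": 10,
#   "extraction_results": [
#     {
#       "note_id": "...", "note_url": "...", "title": "...", "body_text": "...",
#       "type": "normal", "extraction_time": "...", "final_score": ...,
#       "images": [
#         {"image_index": ..., "original_url": "...", "description": "...", "extract_text": "..."}
#       ]
#     }
#   ]
# }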


async def main(context_file_path: str, output_file_path: str):
    """Entry point: load, select the top 10, extract, and save."""
    print("=" * 80)
    print("Multimodal extraction - top-10 posts")
    print("=" * 80)
    # 1. Load the data
    print(f"\n📂 Loading file: {context_file_path}")
    context_data = load_run_context(context_file_path)
    # 2. Collect all posts
    print("\n🔍 Collecting all posts...")
    all_posts = extract_all_posts_from_context(context_data)
    print(f"  Found {len(all_posts)} posts")
    # 3. Filter and sort to get the top 10
    print("\n📊 Selecting the top-10 posts...")
    top10_posts = filter_and_sort_top10(all_posts)
    if not top10_posts:
        # Guard against an empty selection (e.g. every post had a null final_score)
        print("❌ No posts with a final_score were found; nothing to extract")
        return
    print(f"  Top-10 score range: {top10_posts[-1].get('final_score')} ~ {top10_posts[0].get('final_score')}")
    # Print the top-10 list
    print("\n  Top-10 posts:")
    for i, post in enumerate(top10_posts, 1):
        title = (post.get('title') or '')[:40]
        print(f"  {i}. [{post.get('final_score')}] {title}... ({post.get('note_id')})")
    # 4. Convert to Post objects
    print("\n🔄 Converting to Post objects...")
    post_objects = convert_to_post_objects(top10_posts)
    print(f"  Converted {len(post_objects)} Post objects")
    # 5. Run the multimodal extraction
    print("\n🖼️ Starting multimodal image-content extraction...")
    print("  (concurrency limit: 5, up to 10 images per post)")
    extraction_results = await extract_all_posts(post_objects, max_concurrent=5)
    # 6. Save the results
    print("\n💾 Saving extraction results...")
    save_extraction_results(extraction_results, output_file_path, top10_posts)
    print("\n" + "=" * 80)
    print("✅ Done!")
    print("=" * 80)


if __name__ == "__main__":
    # Default path configuration
    DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/run_context_v3.json"
    DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/multimodal_extraction_top10.json"
    # Command-line arguments override the defaults
    context_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CONTEXT_FILE
    output_file = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_OUTPUT_FILE
    # Make sure the input file exists
    if not os.path.exists(context_file):
        print(f"❌ Error: file not found - {context_file}")
        sys.exit(1)
    # Run the main coroutine
    asyncio.run(main(context_file, output_file))
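
# Example invocations:
#   python extract_top10_multimodal.py                                        (use the default paths above)
#   python extract_top10_multimodal.py <path/to/run_context_v3.json> <path/to/output.json>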