extract_topn_multimodal.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. """
  2. 从 run_context_v3.json 中提取 topN 帖子并进行多模态解析
  3. 功能:
  4. 1. 读取 run_context_v3.json
  5. 2. 提取所有帖子,按 final_score 排序,取 topN
  6. 3. 使用 multimodal_extractor 进行图片内容解析
  7. 4. 保存结果到独立的 JSON 文件
  8. 参数化配置:
  9. - top_n: 提取前N个帖子(默认10)
  10. - max_concurrent: 最大并发数(默认5)
  11. """
  12. import argparse
  13. import asyncio
  14. import json
  15. import os
  16. import sys
  17. from pathlib import Path
  18. from typing import Optional
  19. # 导入必要的模块
  20. from knowledge_search_traverse import Post
  21. from multimodal_extractor import extract_all_posts
  22. def load_run_context(json_path: str) -> dict:
  23. """加载 run_context_v3.json 文件"""
  24. with open(json_path, 'r', encoding='utf-8') as f:
  25. return json.load(f)
  26. def extract_all_posts_from_context(context_data: dict) -> list[dict]:
  27. """从 context 数据中提取所有帖子(按note_id去重,保留得分最高的)"""
  28. # 使用字典进行去重,key为note_id
  29. posts_dict = {}
  30. # 遍历所有轮次
  31. for round_data in context_data.get('rounds', []):
  32. # 遍历搜索结果
  33. for search_result in round_data.get('search_results', []):
  34. # 遍历帖子列表
  35. for post in search_result.get('post_list', []):
  36. note_id = post.get('note_id')
  37. if not note_id:
  38. continue
  39. # 如果是新帖子,直接添加
  40. if note_id not in posts_dict:
  41. posts_dict[note_id] = post
  42. else:
  43. # 如果已存在,比较final_score,保留得分更高的
  44. existing_score = posts_dict[note_id].get('final_score')
  45. current_score = post.get('final_score')
  46. # 如果当前帖子的分数更高,或者现有帖子没有分数,则替换
  47. if existing_score is None or (current_score is not None and current_score > existing_score):
  48. posts_dict[note_id] = post
  49. # 返回去重后的帖子列表
  50. return list(posts_dict.values())
  51. def filter_and_sort_topn(posts: list[dict], top_n: int = 10) -> list[dict]:
  52. """过滤并排序,获取 final_score topN 的帖子"""
  53. # 过滤掉 final_score 为 null 的帖子
  54. valid_posts = [p for p in posts if p.get('final_score') is not None]
  55. # 按 final_score 降序排序
  56. sorted_posts = sorted(valid_posts, key=lambda x: x.get('final_score', 0), reverse=True)
  57. # 取前N个
  58. topn = sorted_posts[:top_n]
  59. return topn
  60. def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]:
  61. """将字典数据转换为 Post 对象"""
  62. post_objects = []
  63. for post_dict in post_dicts:
  64. # 创建 Post 对象,设置默认 type="normal"
  65. post = Post(
  66. note_id=post_dict.get('note_id', ''),
  67. note_url=post_dict.get('note_url', ''),
  68. title=post_dict.get('title', ''),
  69. body_text=post_dict.get('body_text', ''),
  70. type='normal', # 默认值,因为原数据缺少此字段
  71. images=post_dict.get('images', []),
  72. video=post_dict.get('video', ''),
  73. interact_info=post_dict.get('interact_info', {}),
  74. )
  75. post_objects.append(post)
  76. return post_objects
  77. def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]):
  78. """保存多模态解析结果到 JSON 文件"""
  79. # 构建输出数据
  80. output_data = {
  81. 'total_extracted': len(results),
  82. 'extraction_results': []
  83. }
  84. # 遍历每个解析结果
  85. for note_id, extraction in results.items():
  86. # 找到对应的原始帖子数据
  87. original_post = None
  88. for post in topn_posts:
  89. if post.get('note_id') == note_id:
  90. original_post = post
  91. break
  92. # 构建结果条目
  93. result_entry = {
  94. 'note_id': extraction.note_id,
  95. 'note_url': extraction.note_url,
  96. 'title': extraction.title,
  97. 'body_text': extraction.body_text,
  98. 'type': extraction.type,
  99. 'extraction_time': extraction.extraction_time,
  100. 'final_score': original_post.get('final_score') if original_post else None,
  101. 'images': [
  102. {
  103. 'image_index': img.image_index,
  104. 'original_url': img.original_url,
  105. 'description': img.description,
  106. 'extract_text': img.extract_text
  107. }
  108. for img in extraction.images
  109. ]
  110. }
  111. output_data['extraction_results'].append(result_entry)
  112. # 保存到文件
  113. with open(output_path, 'w', encoding='utf-8') as f:
  114. json.dump(output_data, f, ensure_ascii=False, indent=2)
  115. print(f"\n✅ 结果已保存到: {output_path}")
  116. async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
  117. max_concurrent: int = 5):
  118. """主函数
  119. Args:
  120. context_file_path: run_context_v3.json 文件路径
  121. output_file_path: 输出文件路径
  122. top_n: 提取前N个帖子(默认10)
  123. max_concurrent: 最大并发数(默认5)
  124. """
  125. print("=" * 80)
  126. print(f"多模态解析 - Top{top_n} 帖子")
  127. print("=" * 80)
  128. # 1. 加载数据
  129. print(f"\n📂 加载文件: {context_file_path}")
  130. context_data = load_run_context(context_file_path)
  131. # 2. 提取所有帖子
  132. print(f"\n🔍 提取所有帖子...")
  133. all_posts = extract_all_posts_from_context(context_data)
  134. print(f" 去重后共找到 {len(all_posts)} 个唯一帖子")
  135. # 3. 过滤并排序获取 topN
  136. print(f"\n📊 筛选 top{top_n} 帖子...")
  137. topn_posts = filter_and_sort_topn(all_posts, top_n)
  138. if len(topn_posts) == 0:
  139. print(" ⚠️ 没有找到有效的帖子")
  140. return
  141. print(f" Top{top_n} 帖子得分范围: {topn_posts[-1].get('final_score')} ~ {topn_posts[0].get('final_score')}")
  142. # 打印 topN 列表
  143. print(f"\n Top{top_n} 帖子列表:")
  144. for i, post in enumerate(topn_posts, 1):
  145. print(f" {i}. [{post.get('final_score')}] {post.get('title')[:40]}... ({post.get('note_id')})")
  146. # 4. 转换为 Post 对象
  147. print(f"\n🔄 转换为 Post 对象...")
  148. post_objects = convert_to_post_objects(topn_posts)
  149. print(f" 成功转换 {len(post_objects)} 个 Post 对象")
  150. # 5. 进行多模态解析
  151. print(f"\n🖼️ 开始多模态图片内容解析...")
  152. print(f" (并发限制: {max_concurrent})")
  153. extraction_results = await extract_all_posts(
  154. post_objects,
  155. max_concurrent=max_concurrent
  156. )
  157. # 6. 保存结果
  158. print(f"\n💾 保存解析结果...")
  159. save_extraction_results(extraction_results, output_file_path, topn_posts)
  160. print("\n" + "=" * 80)
  161. print("✅ 处理完成!")
  162. print("=" * 80)
  163. if __name__ == "__main__":
  164. # 创建命令行参数解析器
  165. parser = argparse.ArgumentParser(
  166. description='从 run_context_v3.json 中提取 topN 帖子并进行多模态解析',
  167. formatter_class=argparse.RawDescriptionHelpFormatter,
  168. epilog='''
  169. 示例用法:
  170. # 使用默认参数 (top10, 并发5)
  171. python3 extract_topn_multimodal.py
  172. # 提取前20个帖子
  173. python3 extract_topn_multimodal.py --top-n 20
  174. # 自定义并发数
  175. python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10
  176. # 指定输入输出文件
  177. python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30
  178. '''
  179. )
  180. # 默认路径配置
  181. DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/run_context_v3.json"
  182. DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/multimodal_extraction_topn.json"
  183. # 添加参数
  184. parser.add_argument(
  185. '-i', '--input',
  186. dest='context_file',
  187. default=DEFAULT_CONTEXT_FILE,
  188. help=f'输入的 run_context_v3.json 文件路径 (默认: {DEFAULT_CONTEXT_FILE})'
  189. )
  190. parser.add_argument(
  191. '-o', '--output',
  192. dest='output_file',
  193. default=DEFAULT_OUTPUT_FILE,
  194. help=f'输出的 JSON 文件路径 (默认: {DEFAULT_OUTPUT_FILE})'
  195. )
  196. parser.add_argument(
  197. '-n', '--top-n',
  198. dest='top_n',
  199. type=int,
  200. default=10,
  201. help='提取前N个帖子 (默认: 10)'
  202. )
  203. parser.add_argument(
  204. '-c', '--max-concurrent',
  205. dest='max_concurrent',
  206. type=int,
  207. default=5,
  208. help='最大并发数 (默认: 5)'
  209. )
  210. # 解析参数
  211. args = parser.parse_args()
  212. # 检查文件是否存在
  213. if not os.path.exists(args.context_file):
  214. print(f"❌ 错误: 文件不存在 - {args.context_file}")
  215. sys.exit(1)
  216. # 打印参数配置
  217. print(f"\n📋 参数配置:")
  218. print(f" 输入文件: {args.context_file}")
  219. print(f" 输出文件: {args.output_file}")
  220. print(f" 提取数量: Top{args.top_n}")
  221. print(f" 最大并发: {args.max_concurrent}")
  222. print()
  223. # 运行主函数
  224. asyncio.run(main(
  225. args.context_file,
  226. args.output_file,
  227. args.top_n,
  228. args.max_concurrent
  229. ))