| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281 |
- """
- 从 run_context_v3.json 中提取 topN 帖子并进行多模态解析
- 功能:
- 1. 读取 run_context_v3.json
- 2. 提取所有帖子,按 final_score 排序,取 topN
- 3. 使用 multimodal_extractor 进行图片内容解析
- 4. 保存结果到独立的 JSON 文件
- 参数化配置:
- - top_n: 提取前N个帖子(默认10)
- - max_concurrent: 最大并发数(默认5)
- """
- import argparse
- import asyncio
- import json
- import os
- import sys
- from pathlib import Path
- from typing import Optional
- # 导入必要的模块
- from knowledge_search_traverse import Post
- from multimodal_extractor import extract_all_posts
def load_run_context(json_path: str) -> dict:
    """Read and parse the run_context_v3.json file at *json_path*."""
    return json.loads(Path(json_path).read_text(encoding='utf-8'))
def extract_all_posts_from_context(context_data: dict) -> list[dict]:
    """Collect every post across all rounds, deduplicated by note_id.

    When the same note_id shows up more than once, the copy with the
    higher final_score is kept; a scored copy always replaces an
    unscored one. Posts without a note_id are skipped entirely.
    """
    best_by_id: dict = {}
    # Flatten rounds -> search_results lazily; each result carries a post_list.
    results = (
        sr
        for rd in context_data.get('rounds', [])
        for sr in rd.get('search_results', [])
    )
    for result in results:
        for candidate in result.get('post_list', []):
            nid = candidate.get('note_id')
            if not nid:
                continue
            kept = best_by_id.get(nid)
            if kept is None:
                # First sighting of this note_id.
                best_by_id[nid] = candidate
                continue
            kept_score = kept.get('final_score')
            new_score = candidate.get('final_score')
            # Replace when the incumbent is unscored, or the newcomer scores higher.
            if kept_score is None or (new_score is not None and new_score > kept_score):
                best_by_id[nid] = candidate
    return list(best_by_id.values())
def filter_and_sort_topn(posts: list[dict], top_n: int = 10) -> list[dict]:
    """Return the ``top_n`` posts with the highest final_score, descending.

    Posts whose final_score is None are excluded before ranking.
    """
    scored = (p for p in posts if p.get('final_score') is not None)
    return sorted(scored, key=lambda p: p['final_score'], reverse=True)[:top_n]
def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]:
    """Build ``Post`` objects from the raw post dicts.

    The raw context data carries no ``type`` field, so every Post is
    tagged with the default type ``"normal"``.
    """
    return [
        Post(
            note_id=item.get('note_id', ''),
            note_url=item.get('note_url', ''),
            title=item.get('title', ''),
            body_text=item.get('body_text', ''),
            type='normal',  # source data lacks this field; assume a normal note
            images=item.get('images', []),
            video=item.get('video', ''),
            interact_info=item.get('interact_info', {}),
        )
        for item in post_dicts
    ]
def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]):
    """Persist multimodal extraction results as a JSON file.

    Args:
        results: mapping of note_id -> extraction object. Each object must
            expose ``note_id``, ``note_url``, ``title``, ``body_text``,
            ``type``, ``extraction_time`` and an ``images`` list whose items
            expose ``image_index``, ``original_url``, ``description`` and
            ``extract_text``.
        output_path: destination JSON file path (overwritten if present).
        topn_posts: original post dicts; used only to recover each post's
            ``final_score`` (None when the note_id cannot be found).
    """
    # Index the original posts once so the final_score lookup is O(1) per
    # result instead of a linear scan of topn_posts for every extraction.
    posts_by_id = {p.get('note_id'): p for p in topn_posts}

    output_data = {
        'total_extracted': len(results),
        'extraction_results': [],
    }
    for note_id, extraction in results.items():
        original_post = posts_by_id.get(note_id)
        output_data['extraction_results'].append({
            'note_id': extraction.note_id,
            'note_url': extraction.note_url,
            'title': extraction.title,
            'body_text': extraction.body_text,
            'type': extraction.type,
            'extraction_time': extraction.extraction_time,
            'final_score': original_post.get('final_score') if original_post else None,
            'images': [
                {
                    'image_index': img.image_index,
                    'original_url': img.original_url,
                    'description': img.description,
                    'extract_text': img.extract_text,
                }
                for img in extraction.images
            ],
        })

    # ensure_ascii=False keeps the CJK text human-readable in the output file.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"\n✅ 结果已保存到: {output_path}")
async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
               max_concurrent: int = 5):
    """Pipeline entry point: load the run context, pick the topN posts,
    run multimodal image extraction, and save the results.

    Args:
        context_file_path: path to run_context_v3.json
        output_file_path: path of the output JSON file
        top_n: number of top posts to extract (default 10)
        max_concurrent: max concurrent extraction tasks (default 5)
    """
    print("=" * 80)
    print(f"多模态解析 - Top{top_n} 帖子")
    print("=" * 80)

    # 1. Load the input data.
    print(f"\n📂 加载文件: {context_file_path}")
    context_data = load_run_context(context_file_path)

    # 2. Collect all posts (deduplicated by note_id).
    print(f"\n🔍 提取所有帖子...")
    all_posts = extract_all_posts_from_context(context_data)
    print(f" 去重后共找到 {len(all_posts)} 个唯一帖子")

    # 3. Keep only scored posts and take the top N.
    print(f"\n📊 筛选 top{top_n} 帖子...")
    topn_posts = filter_and_sort_topn(all_posts, top_n)
    if len(topn_posts) == 0:
        print(" ⚠️ 没有找到有效的帖子")
        return
    print(f" Top{top_n} 帖子得分范围: {topn_posts[-1].get('final_score')} ~ {topn_posts[0].get('final_score')}")

    # Print the selected posts. Guard the title with `or ''` so a missing or
    # None title cannot raise TypeError on the `[:40]` slice (bug fix).
    print(f"\n Top{top_n} 帖子列表:")
    for i, post in enumerate(topn_posts, 1):
        print(f" {i}. [{post.get('final_score')}] {(post.get('title') or '')[:40]}... ({post.get('note_id')})")

    # 4. Convert the raw dicts into Post objects.
    print(f"\n🔄 转换为 Post 对象...")
    post_objects = convert_to_post_objects(topn_posts)
    print(f" 成功转换 {len(post_objects)} 个 Post 对象")

    # 5. Run the multimodal image extraction (async, bounded concurrency).
    print(f"\n🖼️ 开始多模态图片内容解析...")
    print(f" (并发限制: {max_concurrent})")
    extraction_results = await extract_all_posts(
        post_objects,
        max_concurrent=max_concurrent
    )

    # 6. Persist the results.
    print(f"\n💾 保存解析结果...")
    save_extraction_results(extraction_results, output_file_path, topn_posts)

    print("\n" + "=" * 80)
    print("✅ 处理完成!")
    print("=" * 80)
- if __name__ == "__main__":
- # 创建命令行参数解析器
- parser = argparse.ArgumentParser(
- description='从 run_context_v3.json 中提取 topN 帖子并进行多模态解析',
- formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog='''
- 示例用法:
- # 使用默认参数 (top10, 并发5)
- python3 extract_topn_multimodal.py
- # 提取前20个帖子
- python3 extract_topn_multimodal.py --top-n 20
- # 自定义并发数
- python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10
- # 指定输入输出文件
- python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30
- '''
- )
- # 默认路径配置
- DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/run_context_v3.json"
- DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/multimodal_extraction_topn.json"
- # 添加参数
- parser.add_argument(
- '-i', '--input',
- dest='context_file',
- default=DEFAULT_CONTEXT_FILE,
- help=f'输入的 run_context_v3.json 文件路径 (默认: {DEFAULT_CONTEXT_FILE})'
- )
- parser.add_argument(
- '-o', '--output',
- dest='output_file',
- default=DEFAULT_OUTPUT_FILE,
- help=f'输出的 JSON 文件路径 (默认: {DEFAULT_OUTPUT_FILE})'
- )
- parser.add_argument(
- '-n', '--top-n',
- dest='top_n',
- type=int,
- default=10,
- help='提取前N个帖子 (默认: 10)'
- )
- parser.add_argument(
- '-c', '--max-concurrent',
- dest='max_concurrent',
- type=int,
- default=5,
- help='最大并发数 (默认: 5)'
- )
- # 解析参数
- args = parser.parse_args()
- # 检查文件是否存在
- if not os.path.exists(args.context_file):
- print(f"❌ 错误: 文件不存在 - {args.context_file}")
- sys.exit(1)
- # 打印参数配置
- print(f"\n📋 参数配置:")
- print(f" 输入文件: {args.context_file}")
- print(f" 输出文件: {args.output_file}")
- print(f" 提取数量: Top{args.top_n}")
- print(f" 最大并发: {args.max_concurrent}")
- print()
- # 运行主函数
- asyncio.run(main(
- args.context_file,
- args.output_file,
- args.top_n,
- args.max_concurrent
- ))
|