- """
- 从 run_context_v3.json 中提取 top10 帖子并进行多模态解析
- 功能:
- 1. 读取 run_context_v3.json
- 2. 提取所有帖子,按 final_score 排序,取 top10
- 3. 使用 multimodal_extractor 进行图片内容解析
- 4. 保存结果到独立的 JSON 文件
- """
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Optional

# Required project modules
from knowledge_search_traverse import Post
from multimodal_extractor import extract_all_posts

def load_run_context(json_path: str) -> dict:
    """Load the run_context_v3.json file."""
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_all_posts_from_context(context_data: dict) -> list[dict]:
    """Collect every post found in the context data."""
    all_posts = []
    # Walk every search round
    for round_data in context_data.get('rounds', []):
        # Walk every search result in the round
        for search_result in round_data.get('search_results', []):
            # Walk the posts in the result
            for post in search_result.get('post_list', []):
                all_posts.append(post)
    return all_posts

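# The traversal above assumes run_context_v3.json is shaped roughly like the sketch
# below (inferred from the keys used in extract_all_posts_from_context; not a schema):
#
#   {
#     "rounds": [
#       {
#         "search_results": [
#           {"post_list": [{"note_id": "...", "title": "...", "final_score": ..., ...}]}
#         ]
#       }
#     ]
#   }
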
def filter_and_sort_top10(posts: list[dict]) -> list[dict]:
    """Filter and sort posts, returning the 10 with the highest final_score."""
    # Drop posts whose final_score is null
    valid_posts = [p for p in posts if p.get('final_score') is not None]
    # Sort by final_score in descending order
    sorted_posts = sorted(valid_posts, key=lambda x: x.get('final_score', 0), reverse=True)
    # Keep the first 10
    top10 = sorted_posts[:10]
    return top10

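# A minimal worked example of the filtering/sorting behaviour (hypothetical scores,
# other post fields omitted for brevity):
#
#   filter_and_sort_top10([
#       {'final_score': 0.6}, {'final_score': None}, {'final_score': 0.9},
#   ])
#   # -> [{'final_score': 0.9}, {'final_score': 0.6}]
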
def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]:
    """Convert raw post dicts into Post objects."""
    post_objects = []
    for post_dict in post_dicts:
        # Build a Post object; type defaults to "normal" because the source data lacks this field
        post = Post(
            note_id=post_dict.get('note_id', ''),
            note_url=post_dict.get('note_url', ''),
            title=post_dict.get('title', ''),
            body_text=post_dict.get('body_text', ''),
            type='normal',
            images=post_dict.get('images', []),
            video=post_dict.get('video', ''),
            interact_info=post_dict.get('interact_info', {}),
        )
        post_objects.append(post)
    return post_objects

def save_extraction_results(results: dict, output_path: str, top10_posts: list[dict]):
    """Save the multimodal extraction results to a JSON file."""
    # Build the output payload
    output_data = {
        'total_extracted': len(results),
        'extraction_results': []
    }
    # Walk each extraction result
    for note_id, extraction in results.items():
        # Look up the matching original post to recover its final_score
        original_post = None
        for post in top10_posts:
            if post.get('note_id') == note_id:
                original_post = post
                break
        # Build the result entry
        result_entry = {
            'note_id': extraction.note_id,
            'note_url': extraction.note_url,
            'title': extraction.title,
            'body_text': extraction.body_text,
            'type': extraction.type,
            'extraction_time': extraction.extraction_time,
            'final_score': original_post.get('final_score') if original_post else None,
            'images': [
                {
                    'image_index': img.image_index,
                    'original_url': img.original_url,
                    'description': img.description,
                    'extract_text': img.extract_text
                }
                for img in extraction.images
            ]
        }
        output_data['extraction_results'].append(result_entry)
    # Write to file
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"\n✅ Results saved to: {output_path}")

async def main(context_file_path: str, output_file_path: str):
    """Entry point: load, select, extract, and save."""
    print("=" * 80)
    print("Multimodal extraction - top-10 posts")
    print("=" * 80)

    # 1. Load data
    print(f"\n📂 Loading file: {context_file_path}")
    context_data = load_run_context(context_file_path)

    # 2. Collect all posts
    print("\n🔍 Collecting all posts...")
    all_posts = extract_all_posts_from_context(context_data)
    print(f"   Found {len(all_posts)} posts")

    # 3. Filter and sort to get the top 10
    print("\n📊 Selecting the top-10 posts...")
    top10_posts = filter_and_sort_top10(all_posts)
- print(f" Top10 帖子得分范围: {top10_posts[-1].get('final_score')} ~ {top10_posts[0].get('final_score')}")
- # 打印 top10 列表
- print("\n Top10 帖子列表:")
- for i, post in enumerate(top10_posts, 1):
- print(f" {i}. [{post.get('final_score')}] {post.get('title')[:40]}... ({post.get('note_id')})")
    # 4. Convert to Post objects
    print("\n🔄 Converting to Post objects...")
    post_objects = convert_to_post_objects(top10_posts)
    print(f"   Converted {len(post_objects)} Post objects")

    # 5. Run multimodal extraction
    print("\n🖼️ Starting multimodal image-content extraction...")
    print("   (concurrency limit: 5, at most 10 images per post)")
    extraction_results = await extract_all_posts(post_objects, max_concurrent=5)
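    # Note: based on how save_extraction_results consumes it below, extract_all_posts is
    # assumed to return a dict mapping note_id -> extraction object exposing fields such
    # as .title, .extraction_time, and .images (this is an inference, not a documented API).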

    # 6. Save results
    print("\n💾 Saving extraction results...")
    save_extraction_results(extraction_results, output_file_path, top10_posts)

    print("\n" + "=" * 80)
    print("✅ Done!")
    print("=" * 80)

if __name__ == "__main__":
    # Default path configuration
    DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/run_context_v3.json"
    DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/multimodal_extraction_top10.json"

    # Command-line arguments can override the defaults
    context_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CONTEXT_FILE
    output_file = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_OUTPUT_FILE

    # Make sure the input file exists
    if not os.path.exists(context_file):
        print(f"❌ Error: file not found - {context_file}")
        sys.exit(1)

    # Run the async entry point
    asyncio.run(main(context_file, output_file))