""" 从 run_context_v3.json 中提取 topN 帖子并进行多模态解析和清洗 功能: 1. 读取 run_context_v3.json 2. 提取所有帖子,按 final_score 排序,取 topN 3. 使用 multimodal_extractor 进行图片内容解析 4. 自动进行数据清洗和结构化 5. 输出清洗后的 JSON 文件(默认不保留原始文件) 参数化配置: - top_n: 提取前N个帖子(默认10) - max_concurrent: 最大并发数(默认5) - keep_raw: 是否保留原始提取结果(默认False) """ import argparse import asyncio import json import os import sys from pathlib import Path from typing import Optional import requests # 导入必要的模块 from knowledge_search_traverse import Post from multimodal_extractor import extract_all_posts # ============================================================================ # 清洗模块 - 整合自 clean_multimodal_data.py # ============================================================================ MODEL_NAME = "google/gemini-2.5-flash" API_TIMEOUT = 60 # API 超时时间(秒) CLEAN_TEXT_PROMPT = """ 请清洗以下图片文本,要求: 1. 去除品牌标识和装饰性文字(如"Blank Plan 计划留白"、"品牌诊断|战略定位|创意内容|VI设计|爆品传播"等) 2. 去除多余换行符,整理成连贯文本 3. **完整保留所有核心内容**,不要概括或删减 4. 保持原文表达和语气 5. 将内容整理成流畅的段落 图片文本: {extract_text} 请直接输出清洗后的文本(纯文本,不要任何格式标记)。 """ async def call_llm_for_text_cleaning(extract_text: str) -> str: """ 调用LLM清洗文本 Args: extract_text: 原始图片文本 Returns: 清洗后的文本 """ # 获取API密钥 api_key = os.getenv("OPENROUTER_API_KEY") if not api_key: raise ValueError("OPENROUTER_API_KEY environment variable not set") # 构建prompt prompt = CLEAN_TEXT_PROMPT.format(extract_text=extract_text) # 构建API请求 payload = { "model": MODEL_NAME, "messages": [ { "role": "user", "content": prompt } ] } headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } # 在异步上下文中执行同步请求 loop = asyncio.get_event_loop() response = await loop.run_in_executor( None, lambda: requests.post( "https://openrouter.ai/api/v1/chat/completions", headers=headers, json=payload, timeout=API_TIMEOUT ) ) # 检查响应 if response.status_code != 200: raise Exception(f"OpenRouter API error: {response.status_code} - {response.text[:200]}") # 解析响应 result = response.json() cleaned_text = result["choices"][0]["message"]["content"].strip() return cleaned_text async def clean_single_image_text( extract_text: str, semaphore: Optional[asyncio.Semaphore] = None ) -> str: """ 清洗单张图片的文本 Args: extract_text: 原始文本 semaphore: 并发控制信号量 Returns: 清洗后的文本 """ try: if semaphore: async with semaphore: cleaned = await call_llm_for_text_cleaning(extract_text) else: cleaned = await call_llm_for_text_cleaning(extract_text) return cleaned except Exception as e: print(f" ⚠️ 清洗失败,保留原文: {str(e)[:100]}") # 如果清洗失败,返回简单清理的版本(去换行) return extract_text.replace('\n', ' ').strip() async def structure_post_content( post: dict, max_concurrent: int = 5 ) -> dict: """ 结构化整理单个帖子的内容 Args: post: 帖子数据(包含images列表) max_concurrent: 最大并发数 Returns: 添加了 content_structured 字段的帖子数据 """ images = post.get('images', []) if not images: # 如果没有图片,直接返回 post['content_structured'] = { "total_images": 0, "points": [], "formatted_text": "" } return post print(f" 🧹 清洗帖子: {post.get('note_id')} ({len(images)}张图片)") # 创建信号量控制并发 semaphore = asyncio.Semaphore(max_concurrent) # 并发清洗所有图片的文本 tasks = [] for img in images: extract_text = img.get('extract_text', '') if extract_text: task = clean_single_image_text(extract_text, semaphore) else: # 如果原始文本为空,直接返回空字符串 task = asyncio.sleep(0, result='') tasks.append(task) cleaned_texts = await asyncio.gather(*tasks) # 构建结构化points points = [] for idx, (img, cleaned_text) in enumerate(zip(images, cleaned_texts)): # 保存清洗后的文本到图片信息中 img['extract_text_cleaned'] = cleaned_text # 添加到points(如果清洗后文本不为空) if cleaned_text: points.append({ "index": idx + 1, "source_image": idx, "content": cleaned_text }) # 生成格式化文本 
formatted_text = "\n".join([ f"{p['index']}. {p['content']}" for p in points ]) # 构建content_structured post['content_structured'] = { "total_images": len(images), "points": points, "formatted_text": formatted_text } print(f" ✅ 清洗完成: {post.get('note_id')}") return post async def clean_all_posts( posts: list[dict], max_concurrent: int = 5 ) -> list[dict]: """ 批量清洗所有帖子 Args: posts: 帖子列表 max_concurrent: 最大并发数 Returns: 清洗后的帖子列表 """ print(f"\n 开始清洗 {len(posts)} 个帖子...") # 顺序处理每个帖子(但每个帖子内部的图片是并发处理的) cleaned_posts = [] for post in posts: cleaned_post = await structure_post_content(post, max_concurrent) cleaned_posts.append(cleaned_post) print(f" 清洗完成: {len(cleaned_posts)} 个帖子") return cleaned_posts async def clean_and_merge_to_context( context_file_path: str, extraction_file_path: str, max_concurrent: int = 5 ) -> list[dict]: """ 清洗数据并合并到 run_context_v3.json Args: context_file_path: run_context_v3.json 文件路径 extraction_file_path: 临时提取结果文件路径 max_concurrent: 最大并发数 Returns: 清洗后的帖子列表 """ # 步骤1: 加载临时提取数据 print(f"\n 📂 加载临时提取数据: {extraction_file_path}") with open(extraction_file_path, 'r', encoding='utf-8') as f: extraction_data = json.load(f) posts = extraction_data.get('extraction_results', []) if not posts: print(" ⚠️ 没有找到需要清洗的帖子") return [] # 步骤2: LLM清洗所有帖子 cleaned_posts = await clean_all_posts(posts, max_concurrent) # 步骤3: 读取 run_context_v3.json print(f"\n 📂 读取 run_context: {context_file_path}") with open(context_file_path, 'r', encoding='utf-8') as f: context_data = json.load(f) # 步骤4: 将清洗结果写入 multimodal_cleaned_posts 字段 from datetime import datetime context_data['multimodal_cleaned_posts'] = { 'total_posts': len(cleaned_posts), 'posts': cleaned_posts, 'extraction_time': datetime.now().isoformat(), 'version': 'v1.0' } # 步骤5: 保存回 run_context_v3.json print(f"\n 💾 保存回 run_context_v3.json...") with open(context_file_path, 'w', encoding='utf-8') as f: json.dump(context_data, f, ensure_ascii=False, indent=2) print(f" ✅ 清洗结果已写入 multimodal_cleaned_posts 字段") return cleaned_posts # ============================================================================ # 原有函数 # ============================================================================ def load_run_context(json_path: str) -> dict: """加载 run_context_v3.json 文件""" with open(json_path, 'r', encoding='utf-8') as f: return json.load(f) def extract_all_posts_from_context(context_data: dict) -> list[dict]: """从 context 数据中提取所有帖子(按note_id去重,保留得分最高的)""" # 使用字典进行去重,key为note_id posts_dict = {} # 遍历所有轮次 for round_data in context_data.get('rounds', []): # 遍历搜索结果 for search_result in round_data.get('search_results', []): # 遍历帖子列表 for post in search_result.get('post_list', []): note_id = post.get('note_id') if not note_id: continue # 如果是新帖子,直接添加 if note_id not in posts_dict: posts_dict[note_id] = post else: # 如果已存在,比较final_score,保留得分更高的 existing_score = posts_dict[note_id].get('final_score') current_score = post.get('final_score') # 如果当前帖子的分数更高,或者现有帖子没有分数,则替换 if existing_score is None or (current_score is not None and current_score > existing_score): posts_dict[note_id] = post # 返回去重后的帖子列表 return list(posts_dict.values()) def filter_and_sort_topn(posts: list[dict], top_n: int = 10) -> list[dict]: """过滤并排序,获取 final_score topN 的帖子""" # 过滤掉 final_score 为 null 的帖子 valid_posts = [p for p in posts if p.get('final_score') is not None] # 按 final_score 降序排序 sorted_posts = sorted(valid_posts, key=lambda x: x.get('final_score', 0), reverse=True) # 取前N个 topn = sorted_posts[:top_n] return topn def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]: """将字典数据转换为 Post 
对象""" post_objects = [] for post_dict in post_dicts: # 创建 Post 对象,设置默认 type="normal" post = Post( note_id=post_dict.get('note_id', ''), note_url=post_dict.get('note_url', ''), title=post_dict.get('title', ''), body_text=post_dict.get('body_text', ''), type='normal', # 默认值,因为原数据缺少此字段 images=post_dict.get('images', []), video=post_dict.get('video', ''), interact_info=post_dict.get('interact_info', {}), ) post_objects.append(post) return post_objects def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]): """保存多模态解析结果到 JSON 文件""" # 构建输出数据 output_data = { 'total_extracted': len(results), 'extraction_results': [] } # 遍历每个解析结果 for note_id, extraction in results.items(): # 找到对应的原始帖子数据 original_post = None for post in topn_posts: if post.get('note_id') == note_id: original_post = post break # 构建结果条目 result_entry = { 'note_id': extraction.note_id, 'note_url': extraction.note_url, 'title': extraction.title, 'body_text': extraction.body_text, 'type': extraction.type, 'extraction_time': extraction.extraction_time, 'final_score': original_post.get('final_score') if original_post else None, 'images': [ { 'image_index': img.image_index, 'original_url': img.original_url, 'description': img.description, 'extract_text': img.extract_text } for img in extraction.images ] } output_data['extraction_results'].append(result_entry) # 保存到文件 with open(output_path, 'w', encoding='utf-8') as f: json.dump(output_data, f, ensure_ascii=False, indent=2) print(f"\n✅ 结果已保存到: {output_path}") async def main(context_file_path: str, output_file_path: str, top_n: int = 10, max_concurrent: int = 5, keep_raw: bool = False): """主函数 Args: context_file_path: run_context_v3.json 文件路径 output_file_path: 输出文件路径 top_n: 提取前N个帖子(默认10) max_concurrent: 最大并发数(默认5) keep_raw: 是否保留原始提取结果文件(默认False) """ print("=" * 80) print(f"多模态解析 - Top{top_n} 帖子") print("=" * 80) # 1. 加载数据 print(f"\n📂 加载文件: {context_file_path}") context_data = load_run_context(context_file_path) # 2. 提取所有帖子 print(f"\n🔍 提取所有帖子...") all_posts = extract_all_posts_from_context(context_data) print(f" 去重后共找到 {len(all_posts)} 个唯一帖子") # 3. 过滤并排序获取 topN print(f"\n📊 筛选 top{top_n} 帖子...") topn_posts = filter_and_sort_topn(all_posts, top_n) if len(topn_posts) == 0: print(" ⚠️ 没有找到有效的帖子") return print(f" Top{top_n} 帖子得分范围: {topn_posts[-1].get('final_score')} ~ {topn_posts[0].get('final_score')}") # 打印 topN 列表 print(f"\n Top{top_n} 帖子列表:") for i, post in enumerate(topn_posts, 1): print(f" {i}. [{post.get('final_score')}] {post.get('title')[:40]}... ({post.get('note_id')})") # 4. 转换为 Post 对象 print(f"\n🔄 转换为 Post 对象...") post_objects = convert_to_post_objects(topn_posts) print(f" 成功转换 {len(post_objects)} 个 Post 对象") # 5. 进行多模态解析 print(f"\n🖼️ 开始多模态图片内容解析...") print(f" (并发限制: {max_concurrent})") extraction_results = await extract_all_posts( post_objects, max_concurrent=max_concurrent ) # 6. 保存原始提取结果到临时文件 print(f"\n💾 保存原始提取结果到临时文件...") temp_output_path = output_file_path.replace('.json', '_temp_raw.json') save_extraction_results(extraction_results, temp_output_path, topn_posts) # 7. 数据清洗并写回到 run_context_v3.json print(f"\n🧹 开始数据清洗并写回到 run_context...") cleaned_posts = await clean_and_merge_to_context( context_file_path, # 写回到原始context文件 temp_output_path, # 从临时文件读取 max_concurrent=max_concurrent ) # 8. 
async def main(
    context_file_path: str,
    output_file_path: str,
    top_n: int = 10,
    max_concurrent: int = 5,
    keep_raw: bool = False
):
    """Main entry point.

    Args:
        context_file_path: path to run_context_v3.json
        output_file_path: output file path
        top_n: number of top posts to extract (default 10)
        max_concurrent: maximum concurrency (default 5)
        keep_raw: whether to also save a standalone cleaned-results file (default False)
    """
    print("=" * 80)
    print(f"Multimodal extraction - top {top_n} posts")
    print("=" * 80)

    # 1. Load the data
    print(f"\n📂 Loading file: {context_file_path}")
    context_data = load_run_context(context_file_path)

    # 2. Extract all posts
    print(f"\n🔍 Extracting all posts...")
    all_posts = extract_all_posts_from_context(context_data)
    print(f"  Found {len(all_posts)} unique posts after deduplication")

    # 3. Filter and sort to get the top N
    print(f"\n📊 Selecting the top {top_n} posts...")
    topn_posts = filter_and_sort_topn(all_posts, top_n)

    if len(topn_posts) == 0:
        print("  ⚠️ No valid posts found")
        return

    print(f"  Top {top_n} score range: {topn_posts[-1].get('final_score')} ~ {topn_posts[0].get('final_score')}")

    # Print the top-N list
    print(f"\n  Top {top_n} posts:")
    for i, post in enumerate(topn_posts, 1):
        print(f"  {i}. [{post.get('final_score')}] {(post.get('title') or '')[:40]}... ({post.get('note_id')})")

    # 4. Convert to Post objects
    print(f"\n🔄 Converting to Post objects...")
    post_objects = convert_to_post_objects(topn_posts)
    print(f"  Converted {len(post_objects)} Post objects")

    # 5. Run the multimodal extraction
    print(f"\n🖼️ Starting multimodal image content extraction...")
    print(f"  (concurrency limit: {max_concurrent})")
    extraction_results = await extract_all_posts(
        post_objects,
        max_concurrent=max_concurrent
    )

    # 6. Save the raw extraction results to a temporary file
    print(f"\n💾 Saving raw extraction results to a temporary file...")
    temp_output_path = output_file_path.replace('.json', '_temp_raw.json')
    save_extraction_results(extraction_results, temp_output_path, topn_posts)

    # 7. Clean the data and write it back into run_context_v3.json
    print(f"\n🧹 Cleaning data and writing back to run_context...")
    cleaned_posts = await clean_and_merge_to_context(
        context_file_path,   # write back into the original context file
        temp_output_path,    # read from the temporary file
        max_concurrent=max_concurrent
    )

    # 8. Optional: also save a standalone cleaned-results file (for easier inspection)
    if keep_raw:
        output_data = {
            'total_extracted': len(cleaned_posts),
            'extraction_results': cleaned_posts
        }
        print(f"\n💾 Saving standalone cleaned-results file...")
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)
        print(f"  ✅ Standalone cleaned results saved to: {output_file_path}")

    # 9. Remove the temporary file
    if os.path.exists(temp_output_path):
        os.remove(temp_output_path)
        print(f"\n🗑️ Temporary file removed")

    print(f"\n✅ Done! Cleaned results written to the multimodal_cleaned_posts field of {context_file_path}")

    print("\n" + "=" * 80)
    print("✅ Processing complete!")
    print("=" * 80)


if __name__ == "__main__":
    # Build the command-line argument parser
    parser = argparse.ArgumentParser(
        description='Extract the top-N posts from run_context_v3.json and run multimodal extraction on them',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
  # Use the defaults (top 10, concurrency 5, cleaned results only)
  python3 extract_topn_multimodal.py

  # Extract the top 20 posts
  python3 extract_topn_multimodal.py --top-n 20

  # Custom concurrency
  python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10

  # Also save a standalone cleaned-results file
  python3 extract_topn_multimodal.py --keep-raw

  # Specify input and output files
  python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30
        '''
    )

    # Default paths
    DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251119/004308_d3/run_context_v3.json"
    DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251119/004308_d3/multimodal_extraction_topn_cleaned.json"

    # Add the arguments
    parser.add_argument(
        '-i', '--input',
        dest='context_file',
        default=DEFAULT_CONTEXT_FILE,
        help=f'Path to the input run_context_v3.json file (default: {DEFAULT_CONTEXT_FILE})'
    )
    parser.add_argument(
        '-o', '--output',
        dest='output_file',
        default=DEFAULT_OUTPUT_FILE,
        help=f'Path to the output JSON file (default: {DEFAULT_OUTPUT_FILE})'
    )
    parser.add_argument(
        '-n', '--top-n',
        dest='top_n',
        type=int,
        default=10,
        help='Number of top posts to extract (default: 10)'
    )
    parser.add_argument(
        '-c', '--max-concurrent',
        dest='max_concurrent',
        type=int,
        default=5,
        help='Maximum concurrency (default: 5)'
    )
    parser.add_argument(
        '--keep-raw',
        dest='keep_raw',
        action='store_true',
        help='Also save a standalone cleaned-results JSON to the output path (default: only write back into run_context_v3.json)'
    )

    # Parse the arguments
    args = parser.parse_args()

    # Check that the input file exists
    if not os.path.exists(args.context_file):
        print(f"❌ Error: file does not exist - {args.context_file}")
        sys.exit(1)

    # Print the configuration
    print(f"\n📋 Configuration:")
    print(f"  Input file: {args.context_file}")
    print(f"  Output file: {args.output_file}")
    print(f"  Posts to extract: top {args.top_n}")
    print(f"  Max concurrency: {args.max_concurrent}")
    print(f"  Standalone output: {'yes' if args.keep_raw else 'no'}")
    print()

    # Run the main coroutine
    asyncio.run(main(
        args.context_file,
        args.output_file,
        args.top_n,
        args.max_concurrent,
        keep_raw=args.keep_raw
    ))