- """
- 从 run_context_v3.json 中提取 topN 帖子并进行多模态解析和清洗
- 功能:
- 1. 读取 run_context_v3.json
- 2. 提取所有帖子,按 final_score 排序,取 topN
- 3. 使用 multimodal_extractor 进行图片内容解析
- 4. 自动进行数据清洗和结构化
- 5. 输出清洗后的 JSON 文件(默认不保留原始文件)
- 参数化配置:
- - top_n: 提取前N个帖子(默认10)
- - max_concurrent: 最大并发数(默认5)
- - keep_raw: 是否保留原始提取结果(默认False)
- """
import argparse
import asyncio
import json
import os
import sys
from typing import Optional

import requests

# Import project modules
from knowledge_search_traverse import Post
from multimodal_extractor import extract_all_posts
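
# Illustrative programmatic usage (paths are placeholders; the CLI at the bottom
# is the primary entry point):
#   asyncio.run(main("run_context_v3.json", "cleaned.json", top_n=10, max_concurrent=5))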

# ============================================================================
# Cleaning module - merged from clean_multimodal_data.py
# ============================================================================
MODEL_NAME = "google/gemini-2.5-flash"
API_TIMEOUT = 60  # API timeout in seconds
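
# NOTE: the cleaning prompt below is intentionally left in Chinese, since it is sent
# to the LLM to clean Chinese Xiaohongshu image text; {extract_text} is filled in
# via str.format() in call_llm_for_text_cleaning().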
- CLEAN_TEXT_PROMPT = """
- 请清洗以下图片文本,要求:
- 1. 去除品牌标识和装饰性文字(如"Blank Plan 计划留白"、"品牌诊断|战略定位|创意内容|VI设计|爆品传播"等)
- 2. 去除多余换行符,整理成连贯文本
- 3. **完整保留所有核心内容**,不要概括或删减
- 4. 保持原文表达和语气
- 5. 将内容整理成流畅的段落
- 图片文本:
- {extract_text}
- 请直接输出清洗后的文本(纯文本,不要任何格式标记)。
- """

async def call_llm_for_text_cleaning(extract_text: str) -> str:
    """
    Call the LLM to clean image text.

    Args:
        extract_text: raw text extracted from an image

    Returns:
        The cleaned text.
    """
    # Get the API key
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY environment variable not set")
    # Build the prompt
    prompt = CLEAN_TEXT_PROMPT.format(extract_text=extract_text)
    # Build the API request
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    # requests is blocking, so run the synchronous call in the default executor
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        lambda: requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=API_TIMEOUT
        )
    )
    # Check the response
    if response.status_code != 200:
        raise Exception(f"OpenRouter API error: {response.status_code} - {response.text[:200]}")
    # Parse the response
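    # The response is assumed to follow the OpenAI-compatible schema that OpenRouter
    # returns: {"choices": [{"message": {"content": "..."}}], ...}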
    result = response.json()
    cleaned_text = result["choices"][0]["message"]["content"].strip()
    return cleaned_text

async def clean_single_image_text(
    extract_text: str,
    semaphore: Optional[asyncio.Semaphore] = None
) -> str:
    """
    Clean the text of a single image.

    Args:
        extract_text: raw text
        semaphore: optional semaphore for concurrency control

    Returns:
        The cleaned text.
    """
    try:
        if semaphore:
            async with semaphore:
                cleaned = await call_llm_for_text_cleaning(extract_text)
        else:
            cleaned = await call_llm_for_text_cleaning(extract_text)
        return cleaned
    except Exception as e:
        print(f"  ⚠️ Cleaning failed, keeping original text: {str(e)[:100]}")
        # If cleaning fails, fall back to a lightly cleaned version (strip newlines)
        return extract_text.replace('\n', ' ').strip()

async def structure_post_content(
    post: dict,
    max_concurrent: int = 5
) -> dict:
    """
    Structure the content of a single post.

    Args:
        post: post data (including an images list)
        max_concurrent: maximum concurrency

    Returns:
        The post data with a content_structured field added.
    """
    images = post.get('images', [])
    if not images:
        # No images: return immediately with an empty structure
        post['content_structured'] = {
            "total_images": 0,
            "points": [],
            "formatted_text": ""
        }
        return post
    print(f"  🧹 Cleaning post: {post.get('note_id')} ({len(images)} images)")
    # Create a semaphore to cap concurrency
    semaphore = asyncio.Semaphore(max_concurrent)
    # Clean the text of all images concurrently
    tasks = []
    for img in images:
        extract_text = img.get('extract_text', '')
        if extract_text:
            task = clean_single_image_text(extract_text, semaphore)
        else:
            # Empty source text: complete immediately with an empty string
            task = asyncio.sleep(0, result='')
        tasks.append(task)
    cleaned_texts = await asyncio.gather(*tasks)
    # Build the structured points
    points = []
    for idx, (img, cleaned_text) in enumerate(zip(images, cleaned_texts)):
        # Store the cleaned text back onto the image entry
        img['extract_text_cleaned'] = cleaned_text
        # Add a point only if the cleaned text is non-empty
        if cleaned_text:
            points.append({
                "index": idx + 1,
                "source_image": idx,
                "content": cleaned_text
            })
    # Generate the formatted text
    formatted_text = "\n".join([
        f"{p['index']}. {p['content']}"
        for p in points
    ])
    # Build content_structured
    post['content_structured'] = {
        "total_images": len(images),
        "points": points,
        "formatted_text": formatted_text
    }
    print(f"  ✅ Cleaning done: {post.get('note_id')}")
    return post

async def clean_all_posts(
    posts: list[dict],
    max_concurrent: int = 5
) -> list[dict]:
    """
    Clean all posts in a batch.

    Args:
        posts: list of posts
        max_concurrent: maximum concurrency

    Returns:
        The list of cleaned posts.
    """
    print(f"\n  Cleaning {len(posts)} posts...")
    # Process posts sequentially (images within each post are cleaned concurrently)
    cleaned_posts = []
    for post in posts:
        cleaned_post = await structure_post_content(post, max_concurrent)
        cleaned_posts.append(cleaned_post)
    print(f"  Cleaning finished: {len(cleaned_posts)} posts")
    return cleaned_posts

async def clean_and_merge_to_context(
    context_file_path: str,
    extraction_file_path: str,
    max_concurrent: int = 5
) -> list[dict]:
    """
    Clean the data and merge it back into run_context_v3.json.

    Args:
        context_file_path: path to run_context_v3.json
        extraction_file_path: path to the temporary extraction results file
        max_concurrent: maximum concurrency

    Returns:
        The list of cleaned posts.
    """
    # Step 1: load the temporary extraction data
    print(f"\n  📂 Loading temporary extraction data: {extraction_file_path}")
    with open(extraction_file_path, 'r', encoding='utf-8') as f:
        extraction_data = json.load(f)
    posts = extraction_data.get('extraction_results', [])
    if not posts:
        print("  ⚠️ No posts found to clean")
        return []
    # Step 2: clean all posts with the LLM
    cleaned_posts = await clean_all_posts(posts, max_concurrent)
    # Step 3: read run_context_v3.json
    print(f"\n  📂 Reading run_context: {context_file_path}")
    with open(context_file_path, 'r', encoding='utf-8') as f:
        context_data = json.load(f)
    # Step 4: write the cleaned results into the multimodal_cleaned_posts field
    from datetime import datetime
    context_data['multimodal_cleaned_posts'] = {
        'total_posts': len(cleaned_posts),
        'posts': cleaned_posts,
        'extraction_time': datetime.now().isoformat(),
        'version': 'v1.0'
    }
    # Step 5: save back to run_context_v3.json
    print("\n  💾 Saving back to run_context_v3.json...")
    with open(context_file_path, 'w', encoding='utf-8') as f:
        json.dump(context_data, f, ensure_ascii=False, indent=2)
    print("  ✅ Cleaned results written to the multimodal_cleaned_posts field")
    return cleaned_posts

# ============================================================================
# Original functions
# ============================================================================
def load_run_context(json_path: str) -> dict:
    """Load the run_context_v3.json file."""
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_all_posts_from_context(context_data: dict) -> list[dict]:
    """Extract all posts from the context data (deduplicated by note_id, keeping the highest-scoring copy)."""
    # Deduplicate with a dict keyed by note_id
    posts_dict = {}
    # Iterate over all rounds
    for round_data in context_data.get('rounds', []):
        # Iterate over the search results
        for search_result in round_data.get('search_results', []):
            # Iterate over the post list
            for post in search_result.get('post_list', []):
                note_id = post.get('note_id')
                if not note_id:
                    continue
                # New post: add it directly
                if note_id not in posts_dict:
                    posts_dict[note_id] = post
                else:
                    # Duplicate: compare final_score and keep the higher-scoring copy
                    existing_score = posts_dict[note_id].get('final_score')
                    current_score = post.get('final_score')
                    # Replace if the current post scores higher, or the stored one has no score
                    if existing_score is None or (current_score is not None and current_score > existing_score):
                        posts_dict[note_id] = post
    # Return the deduplicated posts
    return list(posts_dict.values())

def filter_and_sort_topn(posts: list[dict], top_n: int = 10) -> list[dict]:
    """Filter and sort posts, returning the top N by final_score."""
    # Drop posts whose final_score is null
    valid_posts = [p for p in posts if p.get('final_score') is not None]
    # Sort by final_score in descending order
    sorted_posts = sorted(valid_posts, key=lambda x: x.get('final_score', 0), reverse=True)
    # Keep the top N
    return sorted_posts[:top_n]
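
# Illustrative example (not executed): for posts with final_scores [None, 7.2, 9.1, 8.5]
# and top_n=2, filter_and_sort_topn drops the unscored post and returns the posts
# scored 9.1 and 8.5, in that order.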

def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]:
    """Convert dicts into Post objects."""
    post_objects = []
    for post_dict in post_dicts:
        # Create a Post object; type defaults to "normal" because the source data lacks this field
        post = Post(
            note_id=post_dict.get('note_id', ''),
            note_url=post_dict.get('note_url', ''),
            title=post_dict.get('title', ''),
            body_text=post_dict.get('body_text', ''),
            type='normal',
            images=post_dict.get('images', []),
            video=post_dict.get('video', ''),
            interact_info=post_dict.get('interact_info', {}),
        )
        post_objects.append(post)
    return post_objects

def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]):
    """Save the multimodal extraction results to a JSON file."""
    # Build a note_id -> original post lookup so we can recover final_score
    posts_by_id = {p.get('note_id'): p for p in topn_posts}
    # Build the output data
    output_data = {
        'total_extracted': len(results),
        'extraction_results': []
    }
    # Walk through each extraction result
    for note_id, extraction in results.items():
        original_post = posts_by_id.get(note_id)
        # Build the result entry
        result_entry = {
            'note_id': extraction.note_id,
            'note_url': extraction.note_url,
            'title': extraction.title,
            'body_text': extraction.body_text,
            'type': extraction.type,
            'extraction_time': extraction.extraction_time,
            'final_score': original_post.get('final_score') if original_post else None,
            'images': [
                {
                    'image_index': img.image_index,
                    'original_url': img.original_url,
                    'description': img.description,
                    'extract_text': img.extract_text
                }
                for img in extraction.images
            ]
        }
        output_data['extraction_results'].append(result_entry)
    # Save to file
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"\n✅ Results saved to: {output_path}")

async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
               max_concurrent: int = 5, keep_raw: bool = False):
    """Main entry point.

    Args:
        context_file_path: path to run_context_v3.json
        output_file_path: output file path
        top_n: number of top posts to extract (default 10)
        max_concurrent: maximum concurrency (default 5)
        keep_raw: whether to also write a standalone copy of the cleaned results (default False)
    """
    print("=" * 80)
    print(f"Multimodal extraction - Top{top_n} posts")
    print("=" * 80)
    # 1. Load the data
    print(f"\n📂 Loading file: {context_file_path}")
    context_data = load_run_context(context_file_path)
    # 2. Extract all posts
    print("\n🔍 Extracting all posts...")
    all_posts = extract_all_posts_from_context(context_data)
    print(f"  Found {len(all_posts)} unique posts after deduplication")
    # 3. Filter and sort to get the top N
    print(f"\n📊 Selecting top{top_n} posts...")
    topn_posts = filter_and_sort_topn(all_posts, top_n)
    if len(topn_posts) == 0:
        print("  ⚠️ No valid posts found")
        return
    print(f"  Top{top_n} score range: {topn_posts[-1].get('final_score')} ~ {topn_posts[0].get('final_score')}")
    # Print the topN list
    print(f"\n  Top{top_n} posts:")
    for i, post in enumerate(topn_posts, 1):
        print(f"  {i}. [{post.get('final_score')}] {(post.get('title') or '')[:40]}... ({post.get('note_id')})")
    # 4. Convert to Post objects
    print("\n🔄 Converting to Post objects...")
    post_objects = convert_to_post_objects(topn_posts)
    print(f"  Converted {len(post_objects)} Post objects")
    # 5. Run the multimodal extraction
    print("\n🖼️ Starting multimodal image content extraction...")
    print(f"  (concurrency limit: {max_concurrent})")
    extraction_results = await extract_all_posts(
        post_objects,
        max_concurrent=max_concurrent
    )
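    # extract_all_posts is assumed to return a {note_id: extraction} mapping;
    # save_extraction_results below iterates it via .items().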
    # 6. Save the raw extraction results to a temporary file
    print("\n💾 Saving raw extraction results to a temporary file...")
    temp_output_path = output_file_path.replace('.json', '_temp_raw.json')
    save_extraction_results(extraction_results, temp_output_path, topn_posts)
    # 7. Clean the data and write it back into run_context_v3.json
    print("\n🧹 Starting data cleaning, writing back to run_context...")
    cleaned_posts = await clean_and_merge_to_context(
        context_file_path,   # write back to the original context file
        temp_output_path,    # read from the temporary file
        max_concurrent=max_concurrent
    )
    # 8. Optional: also save a standalone cleaned-results file (for easy inspection)
    if keep_raw:
        output_data = {
            'total_extracted': len(cleaned_posts),
            'extraction_results': cleaned_posts
        }
        print("\n💾 Saving standalone cleaned-results file...")
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)
        print(f"  ✅ Standalone cleaned results saved to: {output_file_path}")
    # 9. Remove the temporary file
    if os.path.exists(temp_output_path):
        os.remove(temp_output_path)
        print("\n🗑️ Temporary file removed")
    print(f"\n✅ Done! Cleaned results written to the multimodal_cleaned_posts field of {context_file_path}")
    print("\n" + "=" * 80)
    print("✅ Processing complete!")
    print("=" * 80)

if __name__ == "__main__":
    # Build the command-line argument parser
    parser = argparse.ArgumentParser(
        description='Extract the topN posts from run_context_v3.json and run multimodal extraction',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
  # Use the defaults (top10, concurrency 5, cleaned results only)
  python3 extract_topn_multimodal.py

  # Extract the top 20 posts
  python3 extract_topn_multimodal.py --top-n 20

  # Custom concurrency
  python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10

  # Also keep a standalone copy of the cleaned results
  python3 extract_topn_multimodal.py --keep-raw

  # Specify input and output files
  python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30
'''
    )
    # Default path configuration
    DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251119/004308_d3/run_context_v3.json"
    DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251119/004308_d3/multimodal_extraction_topn_cleaned.json"
    # Add the arguments
    parser.add_argument(
        '-i', '--input',
        dest='context_file',
        default=DEFAULT_CONTEXT_FILE,
        help=f'Path to the input run_context_v3.json file (default: {DEFAULT_CONTEXT_FILE})'
    )
    parser.add_argument(
        '-o', '--output',
        dest='output_file',
        default=DEFAULT_OUTPUT_FILE,
        help=f'Path to the output JSON file (default: {DEFAULT_OUTPUT_FILE})'
    )
    parser.add_argument(
        '-n', '--top-n',
        dest='top_n',
        type=int,
        default=10,
        help='Number of top posts to extract (default: 10)'
    )
    parser.add_argument(
        '-c', '--max-concurrent',
        dest='max_concurrent',
        type=int,
        default=5,
        help='Maximum concurrency (default: 5)'
    )
    parser.add_argument(
        '--keep-raw',
        dest='keep_raw',
        action='store_true',
        help='Also write a standalone copy of the cleaned results (by default only the context file is updated)'
    )
    # Parse the arguments
    args = parser.parse_args()
    # Check that the input file exists
    if not os.path.exists(args.context_file):
        print(f"❌ Error: file not found - {args.context_file}")
        sys.exit(1)
    # Print the configuration
    print("\n📋 Configuration:")
    print(f"  Input file: {args.context_file}")
    print(f"  Output file: {args.output_file}")
    print(f"  Posts to extract: Top{args.top_n}")
    print(f"  Max concurrency: {args.max_concurrent}")
    print(f"  Keep standalone copy: {'yes' if args.keep_raw else 'no'}")
    print()
    # Run the main function
    asyncio.run(main(
        args.context_file,
        args.output_file,
        args.top_n,
        args.max_concurrent,
        keep_raw=args.keep_raw
    ))
|