@@ -1,15 +1,17 @@
"""
-Extract the topN posts from run_context_v3.json and run multimodal parsing on them
+Extract the topN posts from run_context_v3.json, then run multimodal parsing and cleaning on them

Features:
1. Read run_context_v3.json
2. Collect all posts, sort them by final_score, and take the topN
3. Parse image content with multimodal_extractor
-4. Save the results to a standalone JSON file
+4. Automatically clean and structure the extracted data
+5. Write the cleaned JSON output (the raw file is not kept by default)

Configuration parameters:
- top_n: number of posts to extract (default 10)
- max_concurrent: maximum concurrency (default 5)
+- keep_raw: whether to keep the raw extraction results (default False)
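+
+For reference, the cleaned data written back to run_context_v3.json has roughly this
+shape (illustrative values; the field is built in clean_and_merge_to_context below):
+
+    "multimodal_cleaned_posts": {
+        "total_posts": 10,
+        "posts": [...],
+        "extraction_time": "2025-11-18T19:43:51",
+        "version": "v1.0"
+    }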
"""

import argparse
@@ -19,12 +21,283 @@ import os
import sys
from pathlib import Path
from typing import Optional
+from datetime import datetime
+
+import requests

# Import the required modules
from knowledge_search_traverse import Post
from multimodal_extractor import extract_all_posts

+
+
+# ============================================================================
+# Cleaning module - merged from clean_multimodal_data.py
+# ============================================================================
+
+MODEL_NAME = "google/gemini-2.5-flash"
+API_TIMEOUT = 60  # API timeout in seconds
+
+CLEAN_TEXT_PROMPT = """
+Please clean the following image text. Requirements:
+
+1. Remove brand markers and decorative text (such as "Blank Plan 计划留白", "品牌诊断|战略定位|创意内容|VI设计|爆品传播")
+2. Remove redundant line breaks and merge the text into coherent prose
+3. **Keep all core content in full**; do not summarize or trim it
+4. Preserve the original wording and tone
+5. Arrange the content into fluent paragraphs
+
+Image text:
+{extract_text}
+
+Output the cleaned text directly (plain text, without any formatting markers).
+"""
+
+
+async def call_llm_for_text_cleaning(extract_text: str) -> str:
+    """
+    Clean a piece of text by calling the LLM
+
+    Args:
+        extract_text: raw image text
+
+    Returns:
+        The cleaned text
+    """
+    # Fetch the API key
+    api_key = os.getenv("OPENROUTER_API_KEY")
+    if not api_key:
+        raise ValueError("OPENROUTER_API_KEY environment variable not set")
+
+    # Build the prompt
+    prompt = CLEAN_TEXT_PROMPT.format(extract_text=extract_text)
+
+    # Build the API request
+    payload = {
+        "model": MODEL_NAME,
+        "messages": [
+            {
+                "role": "user",
+                "content": prompt
+            }
+        ]
+    }
+
+    headers = {
+        "Authorization": f"Bearer {api_key}",
+        "Content-Type": "application/json"
+    }
+
+    # Run the blocking HTTP request in a worker thread so it does not block the event loop
+    loop = asyncio.get_running_loop()
+    response = await loop.run_in_executor(
+        None,
+        lambda: requests.post(
+            "https://openrouter.ai/api/v1/chat/completions",
+            headers=headers,
+            json=payload,
+            timeout=API_TIMEOUT
+        )
+    )
+
+    # Check the response
+    if response.status_code != 200:
+        raise RuntimeError(f"OpenRouter API error: {response.status_code} - {response.text[:200]}")
+
+    # Parse the response
+    result = response.json()
+    cleaned_text = result["choices"][0]["message"]["content"].strip()
+
+    return cleaned_text
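+
+# Quick manual check (hypothetical invocation), assuming OPENROUTER_API_KEY is set
+# and this file is importable as extract_topn_multimodal:
+#   python3 -c "import asyncio, extract_topn_multimodal as m; \
+#     print(asyncio.run(m.call_llm_for_text_cleaning('raw image text')))"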
+
+
+async def clean_single_image_text(
+    extract_text: str,
+    semaphore: Optional[asyncio.Semaphore] = None
+) -> str:
+    """
+    Clean the text of a single image
+
+    Args:
+        extract_text: raw text
+        semaphore: semaphore for concurrency control
+
+    Returns:
+        The cleaned text
+    """
+    try:
+        if semaphore:
+            async with semaphore:
+                cleaned = await call_llm_for_text_cleaning(extract_text)
+        else:
+            cleaned = await call_llm_for_text_cleaning(extract_text)
+
+        return cleaned
+
+    except Exception as e:
+        print(f"  ⚠️ Cleaning failed, keeping the original text: {str(e)[:100]}")
+        # On failure, fall back to a minimally cleaned version (line breaks stripped)
+        return extract_text.replace('\n', ' ').strip()
+
+
+async def structure_post_content(
+    post: dict,
+    max_concurrent: int = 5
+) -> dict:
+    """
+    Structure the content of a single post
+
+    Args:
+        post: post data (contains an images list)
+        max_concurrent: maximum concurrency
+
+    Returns:
+        The post data with a content_structured field added
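+
+    Example of the added field (illustrative values, mirroring the construction below):
+
+        post['content_structured'] = {
+            "total_images": 2,
+            "points": [{"index": 1, "source_image": 0, "content": "..."}],
+            "formatted_text": "1. ..."
+        }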
+    """
+    images = post.get('images', [])
+
+    if not images:
+        # No images: attach an empty structure and return
+        post['content_structured'] = {
+            "total_images": 0,
+            "points": [],
+            "formatted_text": ""
+        }
+        return post
+
+    print(f"  🧹 Cleaning post: {post.get('note_id')} ({len(images)} images)")
+
+    # Create a semaphore to bound concurrency
+    semaphore = asyncio.Semaphore(max_concurrent)
+
+    # Clean the text of all images concurrently
+    tasks = []
+    for img in images:
+        extract_text = img.get('extract_text', '')
+        if extract_text:
+            task = clean_single_image_text(extract_text, semaphore)
+        else:
+            # Empty source text: use a no-op awaitable that yields an empty string,
+            # so the results stay index-aligned with the images
+            task = asyncio.sleep(0, result='')
+        tasks.append(task)
+
+    cleaned_texts = await asyncio.gather(*tasks)
+
+    # Build the structured points
+    points = []
+    for idx, (img, cleaned_text) in enumerate(zip(images, cleaned_texts)):
+        # Store the cleaned text back on the image record
+        img['extract_text_cleaned'] = cleaned_text
+
+        # Add a point (only if the cleaned text is non-empty)
+        if cleaned_text:
+            points.append({
+                "index": idx + 1,
+                "source_image": idx,
+                "content": cleaned_text
+            })
+
+    # Generate the formatted text
+    formatted_text = "\n".join([
+        f"{p['index']}. {p['content']}"
+        for p in points
+    ])
+
+    # Build content_structured
+    post['content_structured'] = {
+        "total_images": len(images),
+        "points": points,
+        "formatted_text": formatted_text
+    }
+
+    print(f"  ✅ Cleaning complete: {post.get('note_id')}")
+
+    return post
+
+
+async def clean_all_posts(
+    posts: list[dict],
+    max_concurrent: int = 5
+) -> list[dict]:
+    """
+    Clean all posts in a batch
+
+    Args:
+        posts: list of posts
+        max_concurrent: maximum concurrency
+
+    Returns:
+        The list of cleaned posts
+    """
+    print(f"\n  Cleaning {len(posts)} posts...")
+
+    # Process posts sequentially (the images within each post are cleaned concurrently)
+    cleaned_posts = []
+    for post in posts:
+        cleaned_post = await structure_post_content(post, max_concurrent)
+        cleaned_posts.append(cleaned_post)
+
+    print(f"  Cleaning finished: {len(cleaned_posts)} posts")
+
+    return cleaned_posts
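+
+# Design note: posts are handled one at a time on purpose. Combined with the
+# per-post semaphore, this caps in-flight API requests at max_concurrent and keeps
+# the log output grouped per post; gathering all posts concurrently would multiply
+# the request rate by the number of posts.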
+
+
+async def clean_and_merge_to_context(
+    context_file_path: str,
+    extraction_file_path: str,
+    max_concurrent: int = 5
+) -> list[dict]:
+    """
+    Clean the data and merge it into run_context_v3.json
+
+    Args:
+        context_file_path: path of the run_context_v3.json file
+        extraction_file_path: path of the temporary extraction results file
+        max_concurrent: maximum concurrency
+
+    Returns:
+        The list of cleaned posts
+    """
+    # Step 1: load the temporary extraction data
+    print(f"\n  📂 Loading temporary extraction data: {extraction_file_path}")
+    with open(extraction_file_path, 'r', encoding='utf-8') as f:
+        extraction_data = json.load(f)
+
+    posts = extraction_data.get('extraction_results', [])
+
+    if not posts:
+        print("  ⚠️ No posts found to clean")
+        return []
+
+    # Step 2: clean all posts with the LLM
+    cleaned_posts = await clean_all_posts(posts, max_concurrent)
+
+    # Step 3: read run_context_v3.json
+    print(f"\n  📂 Reading run_context: {context_file_path}")
+    with open(context_file_path, 'r', encoding='utf-8') as f:
+        context_data = json.load(f)
+
+    # Step 4: write the cleaned results into the multimodal_cleaned_posts field
+    context_data['multimodal_cleaned_posts'] = {
+        'total_posts': len(cleaned_posts),
+        'posts': cleaned_posts,
+        'extraction_time': datetime.now().isoformat(),
+        'version': 'v1.0'
+    }
+
+    # Step 5: save back to run_context_v3.json
+    print("\n  💾 Saving back to run_context_v3.json...")
+    with open(context_file_path, 'w', encoding='utf-8') as f:
+        json.dump(context_data, f, ensure_ascii=False, indent=2)
+
+    print("  ✅ Cleaned results written to the multimodal_cleaned_posts field")
+
+    return cleaned_posts
+
+
+# ============================================================================
+# Original functions
+# ============================================================================
+
+
def load_run_context(json_path: str) -> dict:
    """Load the run_context_v3.json file"""
    with open(json_path, 'r', encoding='utf-8') as f:
@@ -144,7 +417,7 @@ def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]


async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
-               max_concurrent: int = 5):
+               max_concurrent: int = 5, keep_raw: bool = False):
    """Main entry point

    Args:
@@ -152,6 +425,7 @@ async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
        output_file_path: path of the output file
        top_n: number of posts to extract (default 10)
        max_concurrent: maximum concurrency (default 5)
+        keep_raw: whether to keep the raw extraction results file (default False)
    """
    print("=" * 80)
    print(f"Multimodal extraction - Top{top_n} posts")
@@ -194,9 +468,36 @@ async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
        max_concurrent=max_concurrent
    )

-    # 6. Save the results
-    print("\n💾 Saving extraction results...")
-    save_extraction_results(extraction_results, output_file_path, topn_posts)
+    # 6. Save the raw extraction results to a temporary file
+    print("\n💾 Saving raw extraction results to a temporary file...")
+    temp_output_path = output_file_path.replace('.json', '_temp_raw.json')
+    save_extraction_results(extraction_results, temp_output_path, topn_posts)
+
+    # 7. Clean the data and write it back to run_context_v3.json
+    print("\n🧹 Cleaning data and writing back to run_context...")
+    cleaned_posts = await clean_and_merge_to_context(
+        context_file_path,   # write back to the original context file
+        temp_output_path,    # read from the temporary file
+        max_concurrent=max_concurrent
+    )
+
+    # 8. Save a standalone copy of the cleaned results (for easy inspection)
+    output_data = {
+        'total_extracted': len(cleaned_posts),
+        'extraction_results': cleaned_posts
+    }
+    print("\n💾 Saving standalone cleaned-results file...")
+    with open(output_file_path, 'w', encoding='utf-8') as f:
+        json.dump(output_data, f, ensure_ascii=False, indent=2)
+    print(f"  ✅ Standalone cleaned results saved to: {output_file_path}")
+
+    # 9. Keep the raw extraction file only when --keep-raw is set
+    if keep_raw:
+        raw_output_path = output_file_path.replace('.json', '_raw.json')
+        os.replace(temp_output_path, raw_output_path)
+        print(f"\n📦 Raw extraction results kept at: {raw_output_path}")
+    elif os.path.exists(temp_output_path):
+        os.remove(temp_output_path)
+        print("\n🗑️ Temporary file removed")
+
+    print(f"\n✅ Done! Cleaned results written to the multimodal_cleaned_posts field of {context_file_path}")

    print("\n" + "=" * 80)
    print("✅ Processing complete!")
@@ -210,7 +511,7 @@ if __name__ == "__main__":
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
-  # Use the default settings (top 10, concurrency 5)
+  # Use the default settings (top 10, concurrency 5, cleaned output only)
  python3 extract_topn_multimodal.py

  # Extract the top 20 posts
@@ -219,14 +520,17 @@ if __name__ == "__main__":
  # Custom concurrency
  python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10

+  # Keep the raw extraction results (writes a *_raw.json file)
+  python3 extract_topn_multimodal.py --keep-raw
+
  # Specify the input and output files
  python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30
'''
)

    # Default path configuration
-    DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/run_context_v3.json"
-    DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/multimodal_extraction_topn.json"
+    DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251118/194351_e3/run_context_v3.json"
+    DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251118/194351_e3/multimodal_extraction_topn_cleaned.json"

    # Add arguments
    parser.add_argument(
@@ -255,6 +559,12 @@ if __name__ == "__main__":
        default=5,
        help='Maximum concurrency (default: 5)'
    )
+    parser.add_argument(
+        '--keep-raw',
+        dest='keep_raw',
+        action='store_true',
+        help='Keep the raw extraction results file (by default only the cleaned results are kept)'
+    )

    # Parse arguments
    args = parser.parse_args()
@@ -270,6 +580,7 @@
    print(f"  Output file: {args.output_file}")
    print(f"  Posts to extract: Top{args.top_n}")
    print(f"  Max concurrency: {args.max_concurrent}")
+    print(f"  Keep raw file: {'yes' if args.keep_raw else 'no'}")
    print()

    # Run the main function
@@ -277,5 +588,6 @@
        args.context_file,
        args.output_file,
        args.top_n,
-        args.max_concurrent
+        args.max_concurrent,
+        keep_raw=args.keep_raw
    ))