  1. """
  2. 从 run_context_v3.json 中提取 topN 帖子并进行多模态解析和清洗
  3. 功能:
  4. 1. 读取 run_context_v3.json
  5. 2. 提取所有帖子,按 final_score 排序,取 topN
  6. 3. 使用 multimodal_extractor 进行图片内容解析
  7. 4. 自动进行数据清洗和结构化
  8. 5. 输出清洗后的 JSON 文件(默认不保留原始文件)
  9. 参数化配置:
  10. - top_n: 提取前N个帖子(默认10)
  11. - max_concurrent: 最大并发数(默认5)
  12. - keep_raw: 是否保留原始提取结果(默认False)
  13. """
import argparse
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Optional

import requests

# Project-local modules
from knowledge_search_traverse import Post
from multimodal_extractor import extract_all_posts

# ============================================================================
# Cleaning module - merged from clean_multimodal_data.py
# ============================================================================
MODEL_NAME = "google/gemini-2.5-flash"
API_TIMEOUT = 60  # API timeout in seconds

CLEAN_TEXT_PROMPT = """
Please clean the following image text. Requirements:
1. Remove brand marks and decorative text (e.g. "Blank Plan 计划留白", "品牌诊断|战略定位|创意内容|VI设计|爆品传播")
2. Remove redundant line breaks and merge the text into a coherent whole
3. **Keep all core content in full**; do not summarize or trim it
4. Preserve the original wording and tone
5. Organize the content into fluent paragraphs

Image text:
{extract_text}

Output the cleaned text directly (plain text, no formatting markup).
"""

async def call_llm_for_text_cleaning(extract_text: str) -> str:
    """
    Call the LLM to clean a block of image text.

    Args:
        extract_text: raw text extracted from an image

    Returns:
        The cleaned text.
    """
    # Fetch the API key
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY environment variable not set")

    # Build the prompt
    prompt = CLEAN_TEXT_PROMPT.format(extract_text=extract_text)

    # Build the API request
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    # Run the blocking request in an executor so the event loop is not blocked
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        lambda: requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=API_TIMEOUT
        )
    )

    # Check the response
    if response.status_code != 200:
        raise Exception(f"OpenRouter API error: {response.status_code} - {response.text[:200]}")

    # Parse the response
    result = response.json()
    cleaned_text = result["choices"][0]["message"]["content"].strip()
    return cleaned_text

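# Standalone usage sketch (assumes OPENROUTER_API_KEY is exported in the shell;
# the sample text is a placeholder):
#   cleaned = asyncio.run(call_llm_for_text_cleaning("raw OCR text ..."))
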
async def clean_single_image_text(
    extract_text: str,
    semaphore: Optional[asyncio.Semaphore] = None
) -> str:
    """
    Clean the text of a single image.

    Args:
        extract_text: raw text
        semaphore: semaphore used to limit concurrency

    Returns:
        The cleaned text.
    """
    try:
        if semaphore:
            async with semaphore:
                cleaned = await call_llm_for_text_cleaning(extract_text)
        else:
            cleaned = await call_llm_for_text_cleaning(extract_text)
        return cleaned
    except Exception as e:
        print(f" ⚠️ Cleaning failed, keeping original text: {str(e)[:100]}")
        # On failure, fall back to a minimally cleaned version (strip newlines)
        return extract_text.replace('\n', ' ').strip()

async def structure_post_content(
    post: dict,
    max_concurrent: int = 5
) -> dict:
    """
    Structure the content of a single post.

    Args:
        post: post data (including an images list)
        max_concurrent: maximum concurrency

    Returns:
        The post data with an added content_structured field.
    """
    images = post.get('images', [])
    if not images:
        # No images: return immediately with an empty structure
        post['content_structured'] = {
            "total_images": 0,
            "points": [],
            "formatted_text": ""
        }
        return post

    print(f" 🧹 Cleaning post: {post.get('note_id')} ({len(images)} images)")

    # Semaphore to bound concurrency
    semaphore = asyncio.Semaphore(max_concurrent)

    # Clean the text of all images concurrently
    tasks = []
    for img in images:
        extract_text = img.get('extract_text', '')
        if extract_text:
            task = clean_single_image_text(extract_text, semaphore)
        else:
            # Empty source text: resolve directly to an empty string
            task = asyncio.sleep(0, result='')
        tasks.append(task)
    cleaned_texts = await asyncio.gather(*tasks)

    # Build the structured points
    points = []
    for idx, (img, cleaned_text) in enumerate(zip(images, cleaned_texts)):
        # Store the cleaned text back on the image entry
        img['extract_text_cleaned'] = cleaned_text
        # Add a point only if the cleaned text is non-empty
        if cleaned_text:
            points.append({
                "index": idx + 1,
                "source_image": idx,
                "content": cleaned_text
            })

    # Render the formatted text
    formatted_text = "\n".join([
        f"{p['index']}. {p['content']}"
        for p in points
    ])

    # Build content_structured
    post['content_structured'] = {
        "total_images": len(images),
        "points": points,
        "formatted_text": formatted_text
    }
    print(f" ✅ Cleaning done: {post.get('note_id')}")
    return post

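# Illustrative shape of the content_structured field produced above (values are placeholders):
#   {
#     "total_images": 3,
#     "points": [{"index": 1, "source_image": 0, "content": "cleaned text ..."}],
#     "formatted_text": "1. cleaned text ..."
#   }
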
async def clean_all_posts(
    posts: list[dict],
    max_concurrent: int = 5
) -> list[dict]:
    """
    Clean all posts in a batch.

    Args:
        posts: list of posts
        max_concurrent: maximum concurrency

    Returns:
        The cleaned list of posts.
    """
    print(f"\n Cleaning {len(posts)} posts...")

    # Process posts sequentially (images within each post are cleaned concurrently)
    cleaned_posts = []
    for post in posts:
        cleaned_post = await structure_post_content(post, max_concurrent)
        cleaned_posts.append(cleaned_post)

    print(f" Cleaning finished: {len(cleaned_posts)} posts")
    return cleaned_posts

async def clean_and_merge_to_context(
    context_file_path: str,
    extraction_file_path: str,
    max_concurrent: int = 5
) -> list[dict]:
    """
    Clean the extracted data and merge it back into run_context_v3.json.

    Args:
        context_file_path: path to run_context_v3.json
        extraction_file_path: path to the temporary extraction results file
        max_concurrent: maximum concurrency

    Returns:
        The cleaned list of posts.
    """
    # Step 1: load the temporary extraction data
    print(f"\n 📂 Loading temporary extraction data: {extraction_file_path}")
    with open(extraction_file_path, 'r', encoding='utf-8') as f:
        extraction_data = json.load(f)
    posts = extraction_data.get('extraction_results', [])
    if not posts:
        print(" ⚠️ No posts found to clean")
        return []

    # Step 2: clean all posts with the LLM
    cleaned_posts = await clean_all_posts(posts, max_concurrent)

    # Step 3: read run_context_v3.json
    print(f"\n 📂 Reading run_context: {context_file_path}")
    with open(context_file_path, 'r', encoding='utf-8') as f:
        context_data = json.load(f)

    # Step 4: write the cleaned results into the multimodal_cleaned_posts field
    from datetime import datetime
    context_data['multimodal_cleaned_posts'] = {
        'total_posts': len(cleaned_posts),
        'posts': cleaned_posts,
        'extraction_time': datetime.now().isoformat(),
        'version': 'v1.0'
    }

    # Step 5: save back to run_context_v3.json
    print(f"\n 💾 Saving back to run_context_v3.json...")
    with open(context_file_path, 'w', encoding='utf-8') as f:
        json.dump(context_data, f, ensure_ascii=False, indent=2)
    print(f" ✅ Cleaned results written to the multimodal_cleaned_posts field")
    return cleaned_posts

# ============================================================================
# Original functions
# ============================================================================
def load_run_context(json_path: str) -> dict:
    """Load the run_context_v3.json file."""
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_all_posts_from_context(context_data: dict) -> list[dict]:
    """Collect all posts from the context data (deduplicated by note_id, keeping the highest-scoring one)."""
    # Deduplicate with a dict keyed by note_id
    posts_dict = {}
    # Iterate over all rounds
    for round_data in context_data.get('rounds', []):
        # Iterate over the search results
        for search_result in round_data.get('search_results', []):
            # Iterate over the post list
            for post in search_result.get('post_list', []):
                note_id = post.get('note_id')
                if not note_id:
                    continue
                # New post: add it directly
                if note_id not in posts_dict:
                    posts_dict[note_id] = post
                else:
                    # Already seen: keep whichever post has the higher final_score
                    existing_score = posts_dict[note_id].get('final_score')
                    current_score = post.get('final_score')
                    # Replace if the current post scores higher, or the existing one has no score
                    if existing_score is None or (current_score is not None and current_score > existing_score):
                        posts_dict[note_id] = post
    # Return the deduplicated list of posts
    return list(posts_dict.values())

def filter_and_sort_topn(posts: list[dict], top_n: int = 10) -> list[dict]:
    """Filter and sort posts, returning the top N by final_score."""
    # Drop posts whose final_score is null
    valid_posts = [p for p in posts if p.get('final_score') is not None]
    # Sort by final_score in descending order
    sorted_posts = sorted(valid_posts, key=lambda x: x.get('final_score', 0), reverse=True)
    # Take the first N
    topn = sorted_posts[:top_n]
    return topn

def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]:
    """Convert post dictionaries into Post objects."""
    post_objects = []
    for post_dict in post_dicts:
        # Build a Post object with a default type="normal"
        post = Post(
            note_id=post_dict.get('note_id', ''),
            note_url=post_dict.get('note_url', ''),
            title=post_dict.get('title', ''),
            body_text=post_dict.get('body_text', ''),
            type='normal',  # default, since the source data lacks this field
            images=post_dict.get('images', []),
            video=post_dict.get('video', ''),
            interact_info=post_dict.get('interact_info', {}),
        )
        post_objects.append(post)
    return post_objects

def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]):
    """Save the multimodal extraction results to a JSON file."""
    # Build the output payload
    output_data = {
        'total_extracted': len(results),
        'extraction_results': []
    }
    # Iterate over every extraction result
    for note_id, extraction in results.items():
        # Find the matching original post data
        original_post = None
        for post in topn_posts:
            if post.get('note_id') == note_id:
                original_post = post
                break
        # Build the result entry
        result_entry = {
            'note_id': extraction.note_id,
            'note_url': extraction.note_url,
            'title': extraction.title,
            'body_text': extraction.body_text,
            'type': extraction.type,
            'extraction_time': extraction.extraction_time,
            'final_score': original_post.get('final_score') if original_post else None,
            'images': [
                {
                    'image_index': img.image_index,
                    'original_url': img.original_url,
                    'description': img.description,
                    'extract_text': img.extract_text
                }
                for img in extraction.images
            ]
        }
        output_data['extraction_results'].append(result_entry)
    # Write to file
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"\n✅ Results saved to: {output_path}")

async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
               max_concurrent: int = 5, keep_raw: bool = False):
    """Main entry point.

    Args:
        context_file_path: path to run_context_v3.json
        output_file_path: output file path
        top_n: number of posts to extract (default 10)
        max_concurrent: maximum concurrency (default 5)
        keep_raw: whether to keep the raw extraction results file (default False)
    """
    print("=" * 80)
    print(f"Multimodal extraction - top {top_n} posts")
    print("=" * 80)

    # 1. Load the data
    print(f"\n📂 Loading file: {context_file_path}")
    context_data = load_run_context(context_file_path)

    # 2. Collect all posts
    print("\n🔍 Collecting all posts...")
    all_posts = extract_all_posts_from_context(context_data)
    print(f" Found {len(all_posts)} unique posts after deduplication")

    # 3. Filter and sort to get the top N
    print(f"\n📊 Selecting the top {top_n} posts...")
    topn_posts = filter_and_sort_topn(all_posts, top_n)
    if len(topn_posts) == 0:
        print(" ⚠️ No valid posts found")
        return
    print(f" Top {top_n} score range: {topn_posts[-1].get('final_score')} ~ {topn_posts[0].get('final_score')}")

    # Print the top-N list
    print(f"\n Top {top_n} posts:")
    for i, post in enumerate(topn_posts, 1):
        print(f" {i}. [{post.get('final_score')}] {(post.get('title') or '')[:40]}... ({post.get('note_id')})")

    # 4. Convert to Post objects
    print("\n🔄 Converting to Post objects...")
    post_objects = convert_to_post_objects(topn_posts)
    print(f" Converted {len(post_objects)} Post objects")

    # 5. Run the multimodal extraction
    print("\n🖼️ Starting multimodal image content extraction...")
    print(f" (concurrency limit: {max_concurrent})")
    extraction_results = await extract_all_posts(
        post_objects,
        max_concurrent=max_concurrent
    )

    # 6. Save the raw extraction results to a temporary file
    print("\n💾 Saving raw extraction results to a temporary file...")
    temp_output_path = output_file_path.replace('.json', '_temp_raw.json')
    save_extraction_results(extraction_results, temp_output_path, topn_posts)

    # 7. Clean the data and write it back into run_context_v3.json
    print("\n🧹 Cleaning data and writing back to run_context...")
    cleaned_posts = await clean_and_merge_to_context(
        context_file_path,  # write back into the original context file
        temp_output_path,   # read from the temporary file
        max_concurrent=max_concurrent
    )

    # 8. Optionally save a standalone copy of the cleaned results (for easier inspection)
    if keep_raw:
        output_data = {
            'total_extracted': len(cleaned_posts),
            'extraction_results': cleaned_posts
        }
        print("\n💾 Saving standalone cleaned results file...")
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)
        print(f" ✅ Standalone cleaned results saved to: {output_file_path}")

    # 9. Remove the temporary file
    if os.path.exists(temp_output_path):
        os.remove(temp_output_path)
        print("\n🗑️ Temporary file removed")

    print(f"\n✅ Done! Cleaned results written to the multimodal_cleaned_posts field of {context_file_path}")
    print("\n" + "=" * 80)
    print("✅ Processing complete!")
    print("=" * 80)

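# Programmatic usage sketch (paths below are placeholders, not real files):
#   asyncio.run(main("path/to/run_context_v3.json", "path/to/output.json",
#                    top_n=10, max_concurrent=5, keep_raw=False))
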
if __name__ == "__main__":
    # Build the command-line argument parser
    parser = argparse.ArgumentParser(
        description='Extract the top-N posts from run_context_v3.json and run multimodal extraction',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
  # Use the defaults (top 10, concurrency 5, only the cleaned results are written)
  python3 extract_topn_multimodal.py

  # Extract the top 20 posts
  python3 extract_topn_multimodal.py --top-n 20

  # Custom concurrency
  python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10

  # Keep the raw extraction results (writes a *_raw.json file)
  python3 extract_topn_multimodal.py --keep-raw

  # Specify the input and output files
  python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30
'''
    )

    # Default paths
    DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251119/004308_d3/run_context_v3.json"
    DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251119/004308_d3/multimodal_extraction_topn_cleaned.json"

    # Arguments
    parser.add_argument(
        '-i', '--input',
        dest='context_file',
        default=DEFAULT_CONTEXT_FILE,
        help=f'Path to the input run_context_v3.json file (default: {DEFAULT_CONTEXT_FILE})'
    )
    parser.add_argument(
        '-o', '--output',
        dest='output_file',
        default=DEFAULT_OUTPUT_FILE,
        help=f'Path to the output JSON file (default: {DEFAULT_OUTPUT_FILE})'
    )
    parser.add_argument(
        '-n', '--top-n',
        dest='top_n',
        type=int,
        default=10,
        help='Number of top posts to extract (default: 10)'
    )
    parser.add_argument(
        '-c', '--max-concurrent',
        dest='max_concurrent',
        type=int,
        default=5,
        help='Maximum concurrency (default: 5)'
    )
    parser.add_argument(
        '--keep-raw',
        dest='keep_raw',
        action='store_true',
        help='Keep the raw extraction results file (by default only the cleaned results are kept)'
    )

    # Parse the arguments
    args = parser.parse_args()

    # Check that the input file exists
    if not os.path.exists(args.context_file):
        print(f"❌ Error: file not found - {args.context_file}")
        sys.exit(1)

    # Print the configuration
    print("\n📋 Configuration:")
    print(f" Input file: {args.context_file}")
    print(f" Output file: {args.output_file}")
    print(f" Top N: {args.top_n}")
    print(f" Max concurrency: {args.max_concurrent}")
    print(f" Keep raw: {'yes' if args.keep_raw else 'no'}")
    print()

    # Run the main coroutine
    asyncio.run(main(
        args.context_file,
        args.output_file,
        args.top_n,
        args.max_concurrent,
        keep_raw=args.keep_raw
    ))