multimodal_extractor.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. """
  2. 多模态图片内容提取模块
  3. 功能:
  4. 1. 对帖子的图片进行文字提取和语义描述
  5. 2. 支持一次性处理多张图片(最多10张)
  6. 3. 使用Gemini多模态模型(直接调用OpenRouter API)
  7. 4. 视频帖子自动跳过
  8. """
  9. import asyncio
  10. import json
  11. import os
  12. from datetime import datetime
  13. from typing import Optional
  14. from pydantic import BaseModel, Field
  15. import requests
  16. MODEL_NAME = "google/gemini-2.5-flash"
  17. MAX_IMAGES_PER_POST = 10 # 最大处理图片数
  18. MAX_CONCURRENT_EXTRACTIONS = 5 # 最大并发提取数
  19. API_TIMEOUT = 120 # API 超时时间(秒)
  20. # ============================================================================
  21. # 数据模型
  22. # ============================================================================
  23. class ImageExtraction(BaseModel):
  24. """单张图片的提取结果"""
  25. image_index: int = Field(..., description="图片索引")
  26. original_url: str = Field(..., description="原始图片URL")
  27. description: str = Field(..., description="图片的详细描述(200-500字)")
  28. extract_text: str = Field(..., description="提取的文字内容")
  29. class PostExtraction(BaseModel):
  30. """帖子的完整提取结果"""
  31. note_id: str
  32. note_url: str
  33. title: str
  34. body_text: str
  35. type: str
  36. extraction_time: str
  37. images: list[ImageExtraction] = Field(default_factory=list)
  38. # ============================================================================
  39. # Prompt 定义
  40. # ============================================================================
  41. ANALYSIS_PROMPT_TEMPLATE = """
  42. 你是一名专业的图像内容分析和文字提取专家。
  43. 请分析以下{num_images}张图片,这些图片来自标题为《{title}》的帖子。
  44. ## 任务
  45. 为**每张图片**分别提取两类信息:
  46. ### 1. description(图片描述)
  47. 对图片进行详细、全面的描述,包括但不限于:
  48. - **整体场景**:图片展示的主要场景或主题
  49. - **核心元素**:图片中的关键元素(人物、物体、文字、图表等)
  50. - **元素细节**:每个元素的特征(颜色、形状、位置、大小、状态等)
  51. - **空间布局**:元素的位置关系和排列方式
  52. - **视觉风格**:摄影风格、设计风格、色调、光线等
  53. - **文字内容**:如果有文字,简要说明文字的主题和作用
  54. - **情感氛围**:图片传达的情绪、氛围或意图
  55. **要求**:
  56. - 使用自然、流畅的语言
  57. - 从整体到局部、从主要到次要
  58. - 准确、具体、客观
  59. - 字数控制在200-500字之间
  60. ### 2. extract_text(文字提取)
  61. 精准提取图片中的所有可见文字内容。
  62. **要求**:
  63. 1. 仅提取可见文字,不改写、总结或推理
  64. 2. 如有结构(表格、图表、标题、段落),按结构输出
  65. 3. 保持原始顺序和排版逻辑
  66. 4. 不需要OCR校正,原样提取
  67. 5. 舍弃与标题不相关的文字
  68. 6. 结构不明确时,按从上到下、从左到右顺序提取
  69. 7. 如果图片无文字,输出空字符串""
  70. ## 输出要求
  71. 必须返回一个JSON对象,包含images数组,每个元素对应一张图片:
  72. {{
  73. "images": [
  74. {{
  75. "description": "第1张图片的详细描述...",
  76. "extract_text": "第1张图片提取的文字内容..."
  77. }},
  78. {{
  79. "description": "第2张图片的详细描述...",
  80. "extract_text": "第2张图片提取的文字内容..."
  81. }}
  82. ]
  83. }}
  84. ## 重要提示
  85. - images数组的顺序必须与输入图片顺序一致
  86. - 每张图片都必须有对应的结果
  87. - 如果某张图片无文字,extract_text设为空字符串""
  88. - 如果某张图片无法分析,description简要说明原因
  89. """.strip()
  90. # ============================================================================
  91. # 核心提取函数
  92. # ============================================================================
  93. async def extract_post_images(
  94. post, # Post对象
  95. semaphore: Optional[asyncio.Semaphore] = None
  96. ) -> Optional[PostExtraction]:
  97. """
  98. 提取单个帖子的图片内容
  99. Args:
  100. post: Post对象(包含images列表)
  101. semaphore: 可选的信号量用于并发控制
  102. Returns:
  103. PostExtraction对象,提取失败返回None
  104. """
  105. # 视频帖子跳过
  106. if post.type == "video":
  107. print(f" ⊗ 跳过视频帖子: {post.note_id}")
  108. return None
  109. # 没有图片跳过
  110. if not post.images or len(post.images) == 0:
  111. print(f" ⊗ 帖子无图片: {post.note_id}")
  112. return None
  113. # 限制图片数量
  114. image_urls = post.images[:MAX_IMAGES_PER_POST]
  115. image_count = len(image_urls)
  116. print(f" 🖼️ 开始提取图片内容: {post.note_id} ({image_count}张图片)")
  117. try:
  118. # 如果有信号量,使用它进行并发控制
  119. if semaphore:
  120. async with semaphore:
  121. result = await _extract_images(image_urls, post)
  122. else:
  123. result = await _extract_images(image_urls, post)
  124. print(f" ✅ 提取完成: {post.note_id}")
  125. return result
  126. except Exception as e:
  127. print(f" ❌ 提取失败: {post.note_id} - {str(e)[:100]}")
  128. return None
  129. async def _extract_images(image_urls: list[str], post) -> PostExtraction:
  130. """
  131. 实际执行图片提取的内部函数 - 直接调用OpenRouter API
  132. """
  133. # 获取API密钥
  134. api_key = os.getenv("OPENROUTER_API_KEY")
  135. if not api_key:
  136. raise ValueError("OPENROUTER_API_KEY environment variable not set")
  137. # 构建提示文本
  138. prompt_text = ANALYSIS_PROMPT_TEMPLATE.format(
  139. num_images=len(image_urls),
  140. title=post.title
  141. )
  142. # 构建消息内容:文本 + 多张图片
  143. content = [{"type": "text", "text": prompt_text}]
  144. for url in image_urls:
  145. content.append({
  146. "type": "image_url",
  147. "image_url": {"url": url}
  148. })
  149. # 构建API请求
  150. payload = {
  151. "model": MODEL_NAME,
  152. "messages": [{"role": "user", "content": content}],
  153. "response_format": {"type": "json_object"}
  154. }
  155. headers = {
  156. "Authorization": f"Bearer {api_key}",
  157. "Content-Type": "application/json"
  158. }
  159. # 在异步上下文中执行同步请求
  160. loop = asyncio.get_event_loop()
  161. response = await loop.run_in_executor(
  162. None,
  163. lambda: requests.post(
  164. "https://openrouter.ai/api/v1/chat/completions",
  165. headers=headers,
  166. json=payload,
  167. timeout=API_TIMEOUT
  168. )
  169. )
  170. # 检查响应
  171. if response.status_code != 200:
  172. raise Exception(f"OpenRouter API error: {response.status_code} - {response.text[:200]}")
  173. # 解析响应
  174. result = response.json()
  175. content_text = result["choices"][0]["message"]["content"]
  176. # 去除Markdown代码块标记(Gemini即使设置了json_object也会返回带```json标记的内容)
  177. content_text = content_text.strip()
  178. if content_text.startswith("```json"):
  179. content_text = content_text[7:]
  180. elif content_text.startswith("```"):
  181. content_text = content_text[3:]
  182. if content_text.endswith("```"):
  183. content_text = content_text[:-3]
  184. content_text = content_text.strip()
  185. analysis_data = json.loads(content_text)
  186. # 构建PostExtraction
  187. extraction = PostExtraction(
  188. note_id=post.note_id,
  189. note_url=post.note_url,
  190. title=post.title,
  191. body_text=post.body_text,
  192. type=post.type,
  193. extraction_time=datetime.now().isoformat(),
  194. images=[]
  195. )
  196. # 解析每张图片的结果
  197. for idx, img_result in enumerate(analysis_data.get("images", [])):
  198. if idx >= len(image_urls):
  199. break # 防止结果数量不匹配
  200. extraction.images.append(ImageExtraction(
  201. image_index=idx,
  202. original_url=image_urls[idx],
  203. description=img_result.get("description", ""),
  204. extract_text=img_result.get("extract_text", "")
  205. ))
  206. return extraction
  207. async def extract_all_posts(
  208. posts: list, # list[Post]
  209. max_concurrent: int = MAX_CONCURRENT_EXTRACTIONS
  210. ) -> dict[str, PostExtraction]:
  211. """
  212. 批量提取多个帖子的图片内容(带并发控制)
  213. Args:
  214. posts: Post对象列表
  215. max_concurrent: 最大并发数
  216. Returns:
  217. dict: {note_id: PostExtraction}
  218. """
  219. semaphore = asyncio.Semaphore(max_concurrent)
  220. print(f"\n开始批量提取 {len(posts)} 个帖子的图片内容(并发限制: {max_concurrent})...")
  221. tasks = [extract_post_images(post, semaphore) for post in posts]
  222. results = await asyncio.gather(*tasks)
  223. # 构建字典(过滤None)
  224. extraction_dict = {}
  225. success_count = 0
  226. for extraction in results:
  227. if extraction is not None:
  228. extraction_dict[extraction.note_id] = extraction
  229. success_count += 1
  230. print(f"批量提取完成: 成功 {success_count}/{len(posts)}")
  231. return extraction_dict