| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- """
- 多模态图片内容提取模块
- 功能:
- 1. 对帖子的图片进行文字提取和语义描述
- 2. 支持一次性处理多张图片(最多10张)
- 3. 使用Gemini多模态模型(直接调用OpenRouter API)
- 4. 视频帖子自动跳过
- """
- import asyncio
- import json
- import os
- from datetime import datetime
- from typing import Optional
- from pydantic import BaseModel, Field
- import requests
- MODEL_NAME = "google/gemini-2.5-flash"
- MAX_IMAGES_PER_POST = 10 # 最大处理图片数
- MAX_CONCURRENT_EXTRACTIONS = 5 # 最大并发提取数
- API_TIMEOUT = 120 # API 超时时间(秒)
- # ============================================================================
- # 数据模型
- # ============================================================================
- class ImageExtraction(BaseModel):
- """单张图片的提取结果"""
- image_index: int = Field(..., description="图片索引")
- original_url: str = Field(..., description="原始图片URL")
- description: str = Field(..., description="图片的详细描述(200-500字)")
- extract_text: str = Field(..., description="提取的文字内容")
- class PostExtraction(BaseModel):
- """帖子的完整提取结果"""
- note_id: str
- note_url: str
- title: str
- body_text: str
- type: str
- extraction_time: str
- images: list[ImageExtraction] = Field(default_factory=list)
- # ============================================================================
- # Prompt 定义
- # ============================================================================
- ANALYSIS_PROMPT_TEMPLATE = """
- 你是一名专业的图像内容分析和文字提取专家。
- 请分析以下{num_images}张图片,这些图片来自标题为《{title}》的帖子。
- ## 任务
- 为**每张图片**分别提取两类信息:
- ### 1. description(图片描述)
- 对图片进行详细、全面的描述,包括但不限于:
- - **整体场景**:图片展示的主要场景或主题
- - **核心元素**:图片中的关键元素(人物、物体、文字、图表等)
- - **元素细节**:每个元素的特征(颜色、形状、位置、大小、状态等)
- - **空间布局**:元素的位置关系和排列方式
- - **视觉风格**:摄影风格、设计风格、色调、光线等
- - **文字内容**:如果有文字,简要说明文字的主题和作用
- - **情感氛围**:图片传达的情绪、氛围或意图
- **要求**:
- - 使用自然、流畅的语言
- - 从整体到局部、从主要到次要
- - 准确、具体、客观
- - 字数控制在200-500字之间
- ### 2. extract_text(文字提取)
- 精准提取图片中的所有可见文字内容。
- **要求**:
- 1. 仅提取可见文字,不改写、总结或推理
- 2. 如有结构(表格、图表、标题、段落),按结构输出
- 3. 保持原始顺序和排版逻辑
- 4. 不需要OCR校正,原样提取
- 5. 舍弃与标题不相关的文字
- 6. 结构不明确时,按从上到下、从左到右顺序提取
- 7. 如果图片无文字,输出空字符串""
- ## 输出要求
- 必须返回一个JSON对象,包含images数组,每个元素对应一张图片:
- {{
- "images": [
- {{
- "description": "第1张图片的详细描述...",
- "extract_text": "第1张图片提取的文字内容..."
- }},
- {{
- "description": "第2张图片的详细描述...",
- "extract_text": "第2张图片提取的文字内容..."
- }}
- ]
- }}
- ## 重要提示
- - images数组的顺序必须与输入图片顺序一致
- - 每张图片都必须有对应的结果
- - 如果某张图片无文字,extract_text设为空字符串""
- - 如果某张图片无法分析,description简要说明原因
- """.strip()
- # ============================================================================
- # 核心提取函数
- # ============================================================================
- async def extract_post_images(
- post, # Post对象
- semaphore: Optional[asyncio.Semaphore] = None
- ) -> Optional[PostExtraction]:
- """
- 提取单个帖子的图片内容
- Args:
- post: Post对象(包含images列表)
- semaphore: 可选的信号量用于并发控制
- Returns:
- PostExtraction对象,提取失败返回None
- """
- # 视频帖子跳过
- if post.type == "video":
- print(f" ⊗ 跳过视频帖子: {post.note_id}")
- return None
- # 没有图片跳过
- if not post.images or len(post.images) == 0:
- print(f" ⊗ 帖子无图片: {post.note_id}")
- return None
- # 限制图片数量
- image_urls = post.images[:MAX_IMAGES_PER_POST]
- image_count = len(image_urls)
- print(f" 🖼️ 开始提取图片内容: {post.note_id} ({image_count}张图片)")
- try:
- # 如果有信号量,使用它进行并发控制
- if semaphore:
- async with semaphore:
- result = await _extract_images(image_urls, post)
- else:
- result = await _extract_images(image_urls, post)
- print(f" ✅ 提取完成: {post.note_id}")
- return result
- except Exception as e:
- print(f" ❌ 提取失败: {post.note_id} - {str(e)[:100]}")
- return None
- async def _extract_images(image_urls: list[str], post) -> PostExtraction:
- """
- 实际执行图片提取的内部函数 - 直接调用OpenRouter API
- """
- # 获取API密钥
- api_key = os.getenv("OPENROUTER_API_KEY")
- if not api_key:
- raise ValueError("OPENROUTER_API_KEY environment variable not set")
- # 构建提示文本
- prompt_text = ANALYSIS_PROMPT_TEMPLATE.format(
- num_images=len(image_urls),
- title=post.title
- )
- # 构建消息内容:文本 + 多张图片
- content = [{"type": "text", "text": prompt_text}]
- for url in image_urls:
- content.append({
- "type": "image_url",
- "image_url": {"url": url}
- })
- # 构建API请求
- payload = {
- "model": MODEL_NAME,
- "messages": [{"role": "user", "content": content}],
- "response_format": {"type": "json_object"}
- }
- headers = {
- "Authorization": f"Bearer {api_key}",
- "Content-Type": "application/json"
- }
- # 在异步上下文中执行同步请求
- loop = asyncio.get_event_loop()
- response = await loop.run_in_executor(
- None,
- lambda: requests.post(
- "https://openrouter.ai/api/v1/chat/completions",
- headers=headers,
- json=payload,
- timeout=API_TIMEOUT
- )
- )
- # 检查响应
- if response.status_code != 200:
- raise Exception(f"OpenRouter API error: {response.status_code} - {response.text[:200]}")
- # 解析响应
- result = response.json()
- content_text = result["choices"][0]["message"]["content"]
- # 去除Markdown代码块标记(Gemini即使设置了json_object也会返回带```json标记的内容)
- content_text = content_text.strip()
- if content_text.startswith("```json"):
- content_text = content_text[7:]
- elif content_text.startswith("```"):
- content_text = content_text[3:]
- if content_text.endswith("```"):
- content_text = content_text[:-3]
- content_text = content_text.strip()
- analysis_data = json.loads(content_text)
- # 构建PostExtraction
- extraction = PostExtraction(
- note_id=post.note_id,
- note_url=post.note_url,
- title=post.title,
- body_text=post.body_text,
- type=post.type,
- extraction_time=datetime.now().isoformat(),
- images=[]
- )
- # 解析每张图片的结果
- for idx, img_result in enumerate(analysis_data.get("images", [])):
- if idx >= len(image_urls):
- break # 防止结果数量不匹配
- extraction.images.append(ImageExtraction(
- image_index=idx,
- original_url=image_urls[idx],
- description=img_result.get("description", ""),
- extract_text=img_result.get("extract_text", "")
- ))
- return extraction
- async def extract_all_posts(
- posts: list, # list[Post]
- max_concurrent: int = MAX_CONCURRENT_EXTRACTIONS
- ) -> dict[str, PostExtraction]:
- """
- 批量提取多个帖子的图片内容(带并发控制)
- Args:
- posts: Post对象列表
- max_concurrent: 最大并发数
- Returns:
- dict: {note_id: PostExtraction}
- """
- semaphore = asyncio.Semaphore(max_concurrent)
- print(f"\n开始批量提取 {len(posts)} 个帖子的图片内容(并发限制: {max_concurrent})...")
- tasks = [extract_post_images(post, semaphore) for post in posts]
- results = await asyncio.gather(*tasks)
- # 构建字典(过滤None)
- extraction_dict = {}
- success_count = 0
- for extraction in results:
- if extraction is not None:
- extraction_dict[extraction.note_id] = extraction
- success_count += 1
- print(f"批量提取完成: 成功 {success_count}/{len(posts)}")
- return extraction_dict
|