image_identifier.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 图文识别脚本
  5. 主要功能:使用 Gemini API 进行图片OCR识别
  6. """
  7. import os
  8. import json
  9. import time
  10. import sys
  11. from typing import Dict, Any, List, Optional
  12. from dotenv import load_dotenv
  13. import google.generativeai as genai
  14. from PIL import Image
  15. import requests
  16. from io import BytesIO
  17. from concurrent.futures import ThreadPoolExecutor, as_completed
  18. from utils.logging_config import get_logger
  19. # 导入自定义模块
  20. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  21. from llm.openrouter import OpenRouterProcessor, OpenRouterModel
  22. # 创建 logger
  23. logger = get_logger('ImageIdentifier')
# OCR prompt sent to Gemini together with each image (it is the text part of
# every generate_content request made by ImageIdentifier). It asks the model to
# return only the text visible in the image, preserving structure and order.
# NOTE(review): rule 5 (drop text unrelated to the title) asks the model to
# filter content, which sits uneasily with rules 1 and 4 (extract verbatim,
# no correction) — confirm this is intended.
prompt = """
#### 人设
你是一名图像文字理解专家,请对输入的文章图片进行精准的文字提取和结构化整理。
#### 任务要求如下:
1. 仅提取图片中可见的文字内容,不需要改写、总结或推理隐藏信息。
2. 如果图片包含结构(如表格、图表、标题、段落等),请按结构输出。
3. 所有提取的内容需保持原始顺序和排版上下文的逻辑。
4. 不需要进行OCR校正,只需要原样提取图中文字。
5. 舍弃图片中和标题不相关的文字
6. 对于结构不明确或自由排列的文字,按照从上到下、从左到右的顺序依次提取。
#### 输出格式
1. 仅输出提取的文字即可,不需要其他说明性的文字
"""
  38. class ImageIdentifier:
  39. def __init__(self):
  40. # 加载环境变量
  41. load_dotenv()
  42. # 延迟配置Gemini,在真正使用时再设置
  43. self._configured = False
  44. self.model = None
  45. def _ensure_configured(self):
  46. """确保Gemini已配置"""
  47. if not self._configured:
  48. self.api_key = os.getenv('GEMINI_API_KEY')
  49. if not self.api_key:
  50. raise ValueError("请在环境变量中设置 GEMINI_API_KEY")
  51. genai.configure(api_key=self.api_key)
  52. # 创建模型时设置安全设置,避免内容被过滤
  53. self.model = genai.GenerativeModel(
  54. 'gemini-2.5-flash',
  55. safety_settings={
  56. genai.types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
  57. genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
  58. genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: genai.types.HarmBlockThreshold.BLOCK_NONE,
  59. genai.types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: genai.types.HarmBlockThreshold.BLOCK_NONE,
  60. }
  61. )
  62. self._configured = True
  63. def download_image(self, image_url: str) -> Optional[Image.Image]:
  64. """下载图片并转换为PIL Image对象"""
  65. try:
  66. response = requests.get(image_url, timeout=10)
  67. response.raise_for_status()
  68. image = Image.open(BytesIO(response.content))
  69. return image
  70. except Exception as e:
  71. print(f"下载图片失败 {image_url}: {e}")
  72. return None
  73. def extract_image_urls(self, formatted_content: Dict[str, Any]) -> List[str]:
  74. """提取图片URL列表"""
  75. image_urls = []
  76. image_url_list = formatted_content.get('image_url_list', [])
  77. for img_data in image_url_list:
  78. if isinstance(img_data, dict) and 'image_url' in img_data:
  79. image_urls.append(img_data['image_url'])
  80. return image_urls
  81. def analyze_images_with_gemini(self, image_urls: List[str]) -> Dict[str, Any]:
  82. """使用 Gemini 并发(最多5条)提取图片文字(仅内容提取)"""
  83. try:
  84. if not image_urls:
  85. return {"images_comprehension": [], "error": "没有图片需要分析"}
  86. # 系统提示:严格限制为"仅提取文字,不做分析" [[memory:7272937]]
  87. system_prompt = prompt
  88. # 保持输入顺序
  89. results: List[Dict[str, Any]] = [{} for _ in range(len(image_urls))]
  90. def analyze_image_job(idx_and_url) -> Dict[str, Any]:
  91. idx, url = idx_and_url
  92. try:
  93. # 下载图片
  94. image = self.download_image(url)
  95. if image is None:
  96. return {"idx": idx, "url": url, "content": "", "success": False, "error": "图片下载失败"}
  97. # 使用 Gemini 直接分析图片
  98. self._ensure_configured()
  99. logger.info(f"配置Gemini: {self.api_key}")
  100. response = self.model.generate_content([system_prompt, image])
  101. # 检查响应状态
  102. if response.candidates and len(response.candidates) > 0:
  103. candidate = response.candidates[0]
  104. if candidate.finish_reason == 1: # SAFETY
  105. logger.warning(f"图片 {url} 被安全过滤器阻止")
  106. return {"idx": idx, "url": url, "content": "", "success": False, "error": "内容被安全过滤器阻止"}
  107. elif candidate.finish_reason == 2: # RECITATION
  108. logger.warning(f"图片 {url} 被引用过滤器阻止")
  109. return {"idx": idx, "url": url, "content": "", "success": False, "error": "内容被引用过滤器阻止"}
  110. elif candidate.finish_reason == 3: # OTHER
  111. logger.warning(f"图片 {url} 被其他原因阻止")
  112. return {"idx": idx, "url": url, "content": "", "success": False, "error": "内容被其他原因阻止"}
  113. # 尝试获取文本内容
  114. try:
  115. if response.text:
  116. return {"idx": idx, "url": url, "content": response.text, "success": True}
  117. else:
  118. return {"idx": idx, "url": url, "content": "", "success": False, "error": "识别失败或无内容返回"}
  119. except Exception as text_error:
  120. logger.error(f"获取响应文本失败: {text_error}")
  121. return {"idx": idx, "url": url, "content": "", "success": False, "error": f"获取响应文本失败: {str(text_error)}"}
  122. except Exception as e:
  123. return {"idx": idx, "url": url, "content": "", "success": False, "error": str(e)}
  124. # 并发最多5条
  125. with ThreadPoolExecutor(max_workers=5) as executor:
  126. future_to_index = {}
  127. for idx, url in enumerate(image_urls):
  128. future = executor.submit(analyze_image_job, (idx, url))
  129. future_to_index[future] = idx
  130. for future in as_completed(list(future_to_index.keys())):
  131. result = future.result()
  132. idx = result["idx"]
  133. results[idx] = {
  134. "url": result["url"],
  135. "content": result["content"],
  136. "success": result["success"]
  137. }
  138. if not result["success"]:
  139. results[idx]["error"] = result["error"]
  140. return {"images_comprehension": results}
  141. except Exception as e:
  142. print(f"Gemini 并发调用失败: {e}")
  143. return {"images_comprehension": [], "error": f"Gemini API 调用失败: {str(e)}"}
  144. def process_images(self, formatted_content: Dict[str, Any]) -> Dict[str, Any]:
  145. """处理图片识别的主函数"""
  146. # 提取图片URL
  147. image_urls = self.extract_image_urls(formatted_content)
  148. if not image_urls:
  149. print("没有图片需要分析")
  150. return {"images_comprehension": [], "error": "没有图片需要分析"}
  151. # 分析图片
  152. result = self.analyze_images_with_gemini(image_urls)
  153. if result.get("images_comprehension"):
  154. successful_count = sum(1 for img in result['images_comprehension'] if img.get('success', False))
  155. else:
  156. print("图片OCR识别失败")
  157. return result
  158. def main():
  159. """测试函数"""
  160. # 模拟数据
  161. test_content = {
  162. "image_url_list": [
  163. {
  164. "image_type": 2,
  165. "image_url": "http://rescdn.yishihui.com/pipeline/image/ea4f33e9-9e36-4124-aaec-138ea9bcadd9.jpg"
  166. },
  167. {
  168. "image_type": 2,
  169. "image_url": "http://rescdn.yishihui.com/pipeline/image/ea4f33e9-9e36-4124-aaec-138ea9bcadd9.jpg"
  170. }
  171. ]
  172. }
  173. try:
  174. identifier = ImageIdentifier()
  175. result = identifier.process_images(test_content)
  176. print(f"识别结果: {json.dumps(result, ensure_ascii=False, indent=2)}")
  177. except Exception as e:
  178. print(f"初始化失败: {e}")
  179. if __name__ == '__main__':
  180. main()