image_identifier.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 图文识别脚本
  5. 主要功能:使用 Gemini API 进行图片OCR识别
  6. """
  7. import os
  8. import json
  9. import time
  10. import sys
  11. from typing import Dict, Any, List, Optional
  12. from dotenv import load_dotenv
  13. import google.generativeai as genai
  14. from PIL import Image
  15. import requests
  16. from io import BytesIO
  17. # 导入自定义模块
  18. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  19. class ImageIdentifier:
  20. def __init__(self):
  21. # 加载环境变量
  22. load_dotenv()
  23. # 初始化Gemini API
  24. api_key = os.getenv('GEMINI_API_KEY')
  25. if not api_key:
  26. raise ValueError("请在环境变量中设置 GEMINI_API_KEY")
  27. genai.configure(api_key=api_key)
  28. self.model = genai.GenerativeModel('gemini-2.5-flash')
  29. def download_image(self, image_url: str) -> Optional[Image.Image]:
  30. """下载图片并转换为PIL Image对象"""
  31. try:
  32. response = requests.get(image_url, timeout=10)
  33. response.raise_for_status()
  34. image = Image.open(BytesIO(response.content))
  35. return image
  36. except Exception as e:
  37. print(f"下载图片失败 {image_url}: {e}")
  38. return None
  39. def extract_image_urls(self, formatted_content: Dict[str, Any]) -> List[str]:
  40. """提取图片URL列表"""
  41. image_urls = []
  42. image_url_list = formatted_content.get('image_url_list', [])
  43. for img_data in image_url_list:
  44. if isinstance(img_data, dict) and 'image_url' in img_data:
  45. image_urls.append(img_data['image_url'])
  46. return image_urls
  47. def analyze_image_with_gemini(self, image: Image.Image) -> Dict[str, Any]:
  48. """使用Gemini API分析单张图片内容"""
  49. try:
  50. # 构建OCR提示词
  51. prompt = """
  52. #### 人设
  53. 你是一名图像文字理解专家,请对输入的文章图片进行精准的文字提取和结构化整理。
  54. #### 任务要求如下:
  55. 1. 仅提取图片中可见的文字内容,不需要改写、总结或推理隐藏信息。
  56. 2. 如果图片包含结构(如表格、图表、标题、段落等),请按结构输出。
  57. 3. 所有提取的内容需保持原始顺序和排版上下文的逻辑。
  58. 4. 不需要进行OCR校正,只需要原样提取图中文字。
  59. 5. 舍弃图片中和标题不相关的文字
  60. 6. 对于结构不明确或自由排列的文字,按照从上到下、从左到右的顺序依次提取。
  61. """
  62. response = self.model.generate_content([prompt, image])
  63. return {
  64. "text_content": response.text,
  65. "success": True
  66. }
  67. except Exception as e:
  68. print(f"Gemini API调用失败: {e}")
  69. return {
  70. "text_content": "",
  71. "success": False,
  72. "error": str(e)
  73. }
  74. def analyze_images_with_gemini(self, image_urls: List[str]) -> Dict[str, Any]:
  75. """使用Gemini API分析多张图片内容"""
  76. try:
  77. if not image_urls:
  78. return {"images_comprehension": [], "error": "没有图片需要分析"}
  79. print(f"正在使用Gemini API分析 {len(image_urls)} 张图片...")
  80. results = []
  81. for i, image_url in enumerate(image_urls):
  82. print(f"正在处理第 {i+1} 张图片: {image_url}")
  83. # 下载图片
  84. image = self.download_image(image_url)
  85. if image is None:
  86. results.append({
  87. "image_url": image_url,
  88. "text_content": "",
  89. "success": False,
  90. "error": "图片下载失败"
  91. })
  92. continue
  93. # 分析图片
  94. result = self.analyze_image_with_gemini(image)
  95. result["image_url"] = image_url
  96. results.append(result)
  97. # 添加延迟避免API限制
  98. time.sleep(1)
  99. return {
  100. "images_comprehension": results
  101. }
  102. except Exception as e:
  103. print(f"Gemini API批量调用失败: {e}")
  104. return {"images_comprehension": [], "error": f"Gemini API调用失败: {str(e)}"}
  105. def process_images(self, formatted_content: Dict[str, Any]) -> Dict[str, Any]:
  106. """处理图片识别的主函数"""
  107. print("开始图片OCR识别处理...")
  108. # 提取图片URL
  109. image_urls = self.extract_image_urls(formatted_content)
  110. print(f"提取到 {len(image_urls)} 张图片")
  111. if not image_urls:
  112. print("没有图片需要分析")
  113. return {"images_comprehension": [], "error": "没有图片需要分析"}
  114. # 分析图片
  115. result = self.analyze_images_with_gemini(image_urls)
  116. if result.get("images_comprehension"):
  117. successful_count = sum(1 for img in result['images_comprehension'] if img.get('success', False))
  118. print(f"图片OCR识别完成,成功分析 {successful_count}/{len(result['images_comprehension'])} 张图片")
  119. else:
  120. print("图片OCR识别失败")
  121. return result
  122. def main():
  123. """测试函数"""
  124. # 模拟数据
  125. test_content = {
  126. "image_url_list": [
  127. {
  128. "image_type": 2,
  129. "image_url": "http://rescdn.yishihui.com/pipeline/image/ea4f33e9-9e36-4124-aaec-138ea9bcadd9.jpg"
  130. }
  131. ]
  132. }
  133. try:
  134. identifier = ImageIdentifier()
  135. result = identifier.process_images(test_content)
  136. print(f"识别结果: {json.dumps(result, ensure_ascii=False, indent=2)}")
  137. except Exception as e:
  138. print(f"初始化失败: {e}")
  139. if __name__ == '__main__':
  140. main()