gemini_client.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. """
  2. 一个使用 google-genai SDK 向 Gemini API 发送请求的客户端。
  3. 该脚本提供了一个与Gemini模型交互的函数,支持文本提示和多种文件输入
  4. (例如本地文件路径、文件字节流、网络URL)。
  5. **依赖库:**
  6. - google-genai: Google官方最新的GenAI Python SDK。
  7. 安装命令: pip install google-genai
  8. - filetype: 用于从文件内容识别MIME类型。
  9. 安装命令: pip install filetype
  10. """
  11. import os
  12. import sys
  13. import logging
  14. import urllib.request
  15. from typing import List, Union, Optional, Dict, Any
  16. from concurrent.futures import ThreadPoolExecutor, as_completed
  17. from google import genai
  18. from google.genai import types
  19. from google.genai.errors import APIError
  20. import filetype
  21. current_dir = os.path.dirname(os.path.abspath(__file__))
  22. root_dir = os.path.dirname(current_dir)
  23. sys.path.insert(0, root_dir)
  24. from utils import llm_account_helper
  25. logging.basicConfig(
  26. level=logging.INFO,
  27. format='%(asctime)s - %(levelname)s - %(message)s'
  28. )
  29. def generate_text(
  30. prompt: str,
  31. model: str = 'gemini-2.5-flash',
  32. files_input: Optional[List[Union[str, bytes]]] = None
  33. ) -> str:
  34. """
  35. 向Gemini API发送一个包含提示和可选文件(本地路径、URL或字节流)的请求。
  36. Args:
  37. prompt (str): 必需的,给模型的文本提示。
  38. model (str): 可选的,要使用的模型名称。默认为 'gemini-2.5-flash'。
  39. files_input (Optional[List[Union[str, bytes]]]): 可选的,请求中要包含的文件列表。
  40. 列表中的每一项可以是:
  41. - 字符串(str): 本地文件路径,或者以 http://, https:// 开头的网络URL。
  42. - 字节(bytes): 内存中的文件内容。
  43. Returns:
  44. str: 从Gemini模型返回的文本响应。
  45. """
  46. api_key = llm_account_helper.get_api_key('gemini')
  47. client = genai.Client(api_key=api_key)
  48. uploaded_files_to_delete = []
  49. try:
  50. contents = []
  51. if files_input:
  52. logging.info(f"正在处理 {len(files_input)} 个文件输入...")
  53. for file_item in files_input:
  54. if isinstance(file_item, str):
  55. # --- 情况 A: 处理网络 URL ---
  56. if file_item.lower().startswith(('http://', 'https://')):
  57. logging.info(f"检测到 URL,正在下载: {file_item}")
  58. try:
  59. # 使用标准库 urllib 下载。添加 User-Agent 以避免部分服务器拒绝请求。
  60. req = urllib.request.Request(
  61. file_item,
  62. headers={'User-Agent': 'Mozilla/5.0 (Compatible; GeminiClient/1.0)'}
  63. )
  64. with urllib.request.urlopen(req) as response:
  65. file_data = response.read()
  66. # 下载成功后,复用下方的字节流处理逻辑
  67. mime_type = filetype.guess_mime(file_data)
  68. logging.info(f"URL 下载完成 (MIME: {mime_type}),已添加到请求。")
  69. contents.append(types.Part.from_bytes(data=file_data, mime_type=mime_type))
  70. except Exception as e:
  71. logging.error(f"下载 URL 失败: {e}")
  72. raise ValueError(f"无法处理 URL '{file_item}': {e}")
  73. continue # 处理完 URL 后跳过当前循环,避免进入本地文件判断
  74. # --- 情况 B: 处理本地文件路径 ---
  75. if not os.path.exists(file_item):
  76. raise FileNotFoundError(f"本地文件 '{file_item}' 不存在。")
  77. logging.info(f"正在上传本地文件: {file_item}")
  78. # 使用 File API 上传本地文件 (适合大文件)
  79. uploaded_file = client.files.upload(file=file_item)
  80. contents.append(uploaded_file)
  81. uploaded_files_to_delete.append(uploaded_file)
  82. elif isinstance(file_item, bytes):
  83. # --- 情况 C: 处理内存字节流 ---
  84. mime_type = filetype.guess_mime(file_item)
  85. logging.info(f"正在处理内存字节流 (MIME: {mime_type})")
  86. # 直接将小文件数据内嵌到请求中
  87. contents.append(types.Part.from_bytes(data=file_item, mime_type=mime_type))
  88. else:
  89. raise ValueError(
  90. f"不支持的输入类型: {type(file_item)}。仅支持本地路径(str)、URL(str)或字节流(bytes)。"
  91. )
  92. contents.append(prompt)
  93. logging.info(f"正在向模型 '{model}' 发送请求...")
  94. response = client.models.generate_content(
  95. model=model,
  96. contents=contents
  97. )
  98. return response.text
  99. except APIError as e:
  100. logging.error(f"Gemini API 调用错误: {e}")
  101. raise
  102. except Exception as e:
  103. logging.error(f"执行过程中发生未知错误: {e}")
  104. raise
  105. finally:
  106. # 清理通过 File API 上传的文件
  107. if uploaded_files_to_delete:
  108. logging.info(f"正在清理 {len(uploaded_files_to_delete)} 个已上传的临时文件...")
  109. for f in uploaded_files_to_delete:
  110. try:
  111. client.files.delete(name=f.name)
  112. except Exception as e:
  113. logging.warning(f"清理文件 {f.name} 失败: {e}")
  114. def concurrent_generate_text(
  115. prompt_file_pairs: List[Dict[str, Union[str, Optional[List[Union[str, bytes]]]]]],
  116. model: str = 'gemini-2.5-flash',
  117. max_workers: int = 10
  118. ) -> List[Dict[str, Any]]:
  119. """
  120. 并发执行多个Gemini请求,每个请求对应一个prompt和文件数组。
  121. Args:
  122. prompt_file_pairs (List[Dict[str, Union[str, Optional[List[Union[str, bytes]]]]]]):
  123. 包含多个prompt和对应文件列表的字典对象列表,每个字典应包含'prompt'和'files'键
  124. model (str): 要使用的模型名称。默认为 'gemini-2.5-flash'
  125. max_workers (int): 最大并发线程数。默认为 10
  126. Returns:
  127. List[Dict[str, Any]]: 每个元素是包含成功失败状态、失败原因、返回数据的字典对象
  128. 结果数组中元素的位置与输入的prompt_file_pairs位置一一对应
  129. """
  130. results = [None] * len(prompt_file_pairs) # 预先分配结果列表,保持与输入顺序一致
  131. def process_single_request(pair_idx: int, prompt: str, files: Optional[List[Union[str, bytes]]]) -> Dict[str, Any]:
  132. try:
  133. response_text = generate_text(prompt=prompt, model=model, files_input=files)
  134. return {
  135. "success": True,
  136. "data": response_text,
  137. "error_message": None
  138. }
  139. except Exception as e:
  140. return {
  141. "success": False,
  142. "data": None,
  143. "error_message": str(e)
  144. }
  145. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  146. # 提交所有任务
  147. future_to_index = {
  148. executor.submit(process_single_request, i, prompt_file_pairs[i]['prompt'], prompt_file_pairs[i].get('files')): i
  149. for i in range(len(prompt_file_pairs))
  150. }
  151. # 收集结果并按原始顺序放置
  152. for future in as_completed(future_to_index):
  153. idx = future_to_index[future]
  154. try:
  155. result = future.result()
  156. results[idx] = result
  157. except Exception as e:
  158. # 即使future执行出现异常,也确保在正确位置设置结果
  159. results[idx] = {
  160. "success": False,
  161. "data": None,
  162. "error_message": str(e)
  163. }
  164. return results
  165. if __name__ == '__main__':
  166. print("--- Gemini 请求客户端演示 (已支持 URL) ---")
  167. # try:
  168. # # 示例: 使用网络图片 URL
  169. # print("\n--- 运行示例: 多模态提示 (使用网络 URL) ---")
  170. # image_url = "https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"
  171. # prompt = "这张图片里的 logo 是什么公司的?"
  172. #
  173. # print(f"[提示]: {prompt}")
  174. # print(f"[输入 URL]: {image_url}")
  175. #
  176. # response = generate_text(
  177. # prompt=prompt,
  178. # files_input=[image_url]
  179. # )
  180. # print(f"[Gemini 回复]:\n{response}")
  181. #
  182. # except Exception as e:
  183. # print(f"\n演示运行失败: {e}")
  184. # 示例: 使用并发请求
  185. print("\n--- 运行示例: 并发请求 ---")
  186. prompt_file_pairs = [
  187. {
  188. "prompt": "图片上是什么?",
  189. "files": ["https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"] # 第一个请求不带文件
  190. },
  191. {
  192. "prompt": "图片中是谁?",
  193. "files": ["https://inews.gtimg.com/om_bt/O6k8mse8MT9ki8fba5c7RK1j1xLFqT-FFZ9RirryqjENkAA/641"] # 第二个请求不带文件
  194. }
  195. ]
  196. try:
  197. concurrent_results = concurrent_generate_text(
  198. prompt_file_pairs=prompt_file_pairs,
  199. model='gemini-2.5-flash'
  200. )
  201. for i, result in enumerate(concurrent_results):
  202. print(f"请求 {i+1}:")
  203. if result['success']:
  204. print(f" 成功: {result['data'][:50]}...") # 只显示前50个字符
  205. else:
  206. print(f" 失败: {result['error_message']}")
  207. except Exception as e:
  208. print(f"\n并发请求演示失败: {e}")