# gemini_image_client.py
"""Gemini native image models via the REST ``generateContent`` endpoint.

Requests follow the official REST documentation; no SDK is used.
Reference: https://ai.google.dev/gemini-api/docs/image-generation?hl=zh-cn#rest
"""
from __future__ import annotations

import mimetypes
import os
import re
from typing import Any

import httpx
from dotenv import load_dotenv
  10. _ = load_dotenv()
  11. DEFAULT_MODEL = "gemini-3.1-flash-image-preview"
  12. GEMINI_API_BASE = os.environ.get(
  13. "GEMINI_API_BASE", "https://airouter.piaoquantv.com/v1beta"
  14. )
  15. _DATA_URL_RE = re.compile(r"^data:[^;]+;base64,(.+)$", re.I | re.S)
  16. def _strip_data_url(b64_or_data_url: str) -> str:
  17. s = b64_or_data_url.strip()
  18. m = _DATA_URL_RE.match(s)
  19. return m.group(1) if m else s
  20. import base64
  21. _BASE64_CHARS = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=\n\r\t ")
  22. def _looks_like_path_not_base64(s: str) -> bool:
  23. """判断字符串看起来像文件路径而不是 base64 数据。
  24. Base64 的字符集只有 [A-Za-z0-9+/=]。如果字符串包含 '.'、'-'(后者是 URL-safe
  25. base64 才有)或有明显的路径形态(以 / 开头、含 .png/.jpg 等扩展名),基本可以
  26. 判定为路径。
  27. """
  28. if not s:
  29. return False
  30. lower = s.lower()
  31. # 有图片扩展名 → 肯定是路径
  32. if any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp")):
  33. return True
  34. # 含有非 base64 字符 → 肯定是路径或别的东西
  35. if any(c not in _BASE64_CHARS for c in s[:200]):
  36. return True
  37. return False
  38. def _build_parts(
  39. prompt: str,
  40. images: list[dict[str, str]] | None,
  41. ) -> list[dict[str, Any]]:
  42. parts: list[dict[str, Any]] = [{"text": prompt}]
  43. if not images:
  44. return parts
  45. for img in images:
  46. mime = (img.get("mime_type") or img.get("mimeType") or "image/png").strip()
  47. raw = _strip_data_url(img.get("data") or "")
  48. # 支持自动下载 HTTP 外链并转换为 Base64
  49. if raw.startswith("http://") or raw.startswith("https://"):
  50. try:
  51. resp = httpx.get(raw, timeout=30.0)
  52. resp.raise_for_status()
  53. raw = base64.b64encode(resp.content).decode("utf-8")
  54. except Exception as e:
  55. raise ValueError(f"拉取网络图片失败 [{raw}]: {e}")
  56. elif os.path.isfile(raw):
  57. try:
  58. with open(raw, "rb") as f:
  59. raw = base64.b64encode(f.read()).decode("utf-8")
  60. except Exception as e:
  61. raise ValueError(f"读取本地图片文件失败 [{raw}]: {e}")
  62. elif _looks_like_path_not_base64(raw):
  63. # 防御:显然是文件路径但文件不存在,早报错比让 Gemini 抛神秘的 base64 解码错误强
  64. raise ValueError(
  65. f"images[].data 看起来是文件路径但文件不存在: {raw!r} "
  66. f"(cwd={os.getcwd()})。请传 HTTP URL 或绝对路径,"
  67. f"调用方应先把本地文件上传到 OSS 再传 CDN URL。"
  68. )
  69. if not raw:
  70. raise ValueError("images[].data 不能为空(Base64、本地路径、data URL 或 HTTP 链接)")
  71. parts.append({"inline_data": {"mime_type": mime, "data": raw}})
  72. return parts
  73. def _merge_generation_config(
  74. *,
  75. aspect_ratio: str | None,
  76. image_size: str | None,
  77. response_modalities: list[str] | None,
  78. ) -> dict[str, Any] | None:
  79. cfg: dict[str, Any] = {}
  80. if response_modalities:
  81. cfg["responseModalities"] = response_modalities
  82. img_cfg: dict[str, str] = {}
  83. if aspect_ratio:
  84. img_cfg["aspectRatio"] = aspect_ratio.strip()
  85. if image_size:
  86. img_cfg["imageSize"] = image_size.strip()
  87. if img_cfg:
  88. cfg["imageConfig"] = img_cfg
  89. return cfg or None
  90. def generate_content(
  91. *,
  92. prompt: str,
  93. model: str | None,
  94. aspect_ratio: str | None = None,
  95. image_size: str | None = None,
  96. response_modalities: list[str] | None = None,
  97. images: list[dict[str, str]] | None = None,
  98. ) -> dict[str, Any]:
  99. api_key = os.environ.get("GEMINI_API_KEY", "").strip()
  100. if not api_key:
  101. raise ValueError("缺少环境变量 GEMINI_API_KEY")
  102. resolved = (model or os.environ.get("GEMINI_IMAGE_MODEL") or DEFAULT_MODEL).strip()
  103. url = f"{GEMINI_API_BASE.rstrip('/')}/models/{resolved}:generateContent"
  104. body: dict[str, Any] = {
  105. "contents": [
  106. {
  107. "role": "user",
  108. "parts": _build_parts(prompt, images),
  109. }
  110. ],
  111. }
  112. gen_cfg = _merge_generation_config(
  113. aspect_ratio=aspect_ratio,
  114. image_size=image_size,
  115. response_modalities=response_modalities,
  116. )
  117. if gen_cfg:
  118. body["generationConfig"] = gen_cfg
  119. headers = {
  120. "x-goog-api-key": api_key,
  121. "Content-Type": "application/json",
  122. }
  123. with httpx.Client(timeout=300.0) as client:
  124. r = client.post(url, headers=headers, json=body)
  125. try:
  126. data = r.json()
  127. except Exception:
  128. r.raise_for_status()
  129. raise RuntimeError(r.text[:2000]) from None
  130. if r.status_code >= 400:
  131. err = data.get("error") if isinstance(data, dict) else None
  132. msg = err.get("message", str(data)) if isinstance(err, dict) else str(data)
  133. raise RuntimeError(f"Gemini HTTP {r.status_code}: {msg}")
  134. if not isinstance(data, dict):
  135. raise RuntimeError("响应不是 JSON 对象")
  136. if data.get("error"):
  137. raise RuntimeError(str(data["error"]))
  138. images_out: list[str] = []
  139. texts: list[str] = []
  140. for cand in data.get("candidates") or []:
  141. if not isinstance(cand, dict):
  142. continue
  143. for part in cand.get("content", {}).get("parts") or []:
  144. if not isinstance(part, dict):
  145. continue
  146. if part.get("text"):
  147. texts.append(str(part["text"]))
  148. inline = part.get("inlineData") or part.get("inline_data")
  149. if isinstance(inline, dict):
  150. b64 = inline.get("data")
  151. if b64:
  152. mime = (
  153. inline.get("mimeType")
  154. or inline.get("mime_type")
  155. or "image/png"
  156. )
  157. images_out.append(f"data:{mime};base64,{b64}")
  158. out: dict[str, Any] = {
  159. "images": images_out,
  160. "model": resolved,
  161. }
  162. if texts:
  163. out["text"] = "\n".join(texts)
  164. return out