gemini_image_client.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. """Gemini 原生图模 — REST `generateContent`(与官方文档一致,无 SDK)。
  2. 参考: https://ai.google.dev/gemini-api/docs/image-generation?hl=zh-cn#rest
  3. """
  4. from __future__ import annotations
  5. import os
  6. import re
  7. from typing import Any
  8. import httpx
  9. from dotenv import load_dotenv
  10. _ = load_dotenv()
  11. DEFAULT_MODEL = "gemini-2.5-flash-image"
  12. GEMINI_API_BASE = os.environ.get(
  13. "GEMINI_API_BASE", "https://airouter.piaoquantv.com/v1beta"
  14. )
  15. _DATA_URL_RE = re.compile(r"^data:[^;]+;base64,(.+)$", re.I | re.S)
  16. def _strip_data_url(b64_or_data_url: str) -> str:
  17. s = b64_or_data_url.strip()
  18. m = _DATA_URL_RE.match(s)
  19. return m.group(1) if m else s
  20. def _build_parts(
  21. prompt: str,
  22. images: list[dict[str, str]] | None,
  23. ) -> list[dict[str, Any]]:
  24. parts: list[dict[str, Any]] = [{"text": prompt}]
  25. if not images:
  26. return parts
  27. for img in images:
  28. mime = (img.get("mime_type") or img.get("mimeType") or "image/png").strip()
  29. raw = _strip_data_url(img.get("data") or "")
  30. if not raw:
  31. raise ValueError("images[].data 不能为空(Base64 或 data URL)")
  32. parts.append({"inline_data": {"mime_type": mime, "data": raw}})
  33. return parts
  34. def _merge_generation_config(
  35. *,
  36. aspect_ratio: str | None,
  37. image_size: str | None,
  38. response_modalities: list[str] | None,
  39. ) -> dict[str, Any] | None:
  40. cfg: dict[str, Any] = {}
  41. if response_modalities:
  42. cfg["responseModalities"] = response_modalities
  43. img_cfg: dict[str, str] = {}
  44. if aspect_ratio:
  45. img_cfg["aspectRatio"] = aspect_ratio.strip()
  46. if image_size:
  47. img_cfg["imageSize"] = image_size.strip()
  48. if img_cfg:
  49. cfg["imageConfig"] = img_cfg
  50. return cfg or None
  51. def generate_content(
  52. *,
  53. prompt: str,
  54. model: str | None,
  55. aspect_ratio: str | None = None,
  56. image_size: str | None = None,
  57. response_modalities: list[str] | None = None,
  58. images: list[dict[str, str]] | None = None,
  59. ) -> dict[str, Any]:
  60. api_key = os.environ.get("GEMINI_API_KEY", "").strip()
  61. if not api_key:
  62. raise ValueError("缺少环境变量 GEMINI_API_KEY")
  63. resolved = (model or os.environ.get("GEMINI_IMAGE_MODEL") or DEFAULT_MODEL).strip()
  64. url = f"{GEMINI_API_BASE.rstrip('/')}/models/{resolved}:generateContent"
  65. body: dict[str, Any] = {
  66. "contents": [
  67. {
  68. "role": "user",
  69. "parts": _build_parts(prompt, images),
  70. }
  71. ],
  72. }
  73. gen_cfg = _merge_generation_config(
  74. aspect_ratio=aspect_ratio,
  75. image_size=image_size,
  76. response_modalities=response_modalities,
  77. )
  78. if gen_cfg:
  79. body["generationConfig"] = gen_cfg
  80. headers = {
  81. "x-goog-api-key": api_key,
  82. "Content-Type": "application/json",
  83. }
  84. with httpx.Client(timeout=300.0) as client:
  85. r = client.post(url, headers=headers, json=body)
  86. try:
  87. data = r.json()
  88. except Exception:
  89. r.raise_for_status()
  90. raise RuntimeError(r.text[:2000]) from None
  91. if r.status_code >= 400:
  92. err = data.get("error") if isinstance(data, dict) else None
  93. msg = err.get("message", str(data)) if isinstance(err, dict) else str(data)
  94. raise RuntimeError(f"Gemini HTTP {r.status_code}: {msg}")
  95. if not isinstance(data, dict):
  96. raise RuntimeError("响应不是 JSON 对象")
  97. if data.get("error"):
  98. raise RuntimeError(str(data["error"]))
  99. images_out: list[str] = []
  100. texts: list[str] = []
  101. for cand in data.get("candidates") or []:
  102. if not isinstance(cand, dict):
  103. continue
  104. for part in cand.get("content", {}).get("parts") or []:
  105. if not isinstance(part, dict):
  106. continue
  107. if part.get("text"):
  108. texts.append(str(part["text"]))
  109. inline = part.get("inlineData") or part.get("inline_data")
  110. if isinstance(inline, dict):
  111. b64 = inline.get("data")
  112. if b64:
  113. mime = (
  114. inline.get("mimeType")
  115. or inline.get("mime_type")
  116. or "image/png"
  117. )
  118. images_out.append(f"data:{mime};base64,{b64}")
  119. out: dict[str, Any] = {
  120. "images": images_out,
  121. "model": resolved,
  122. }
  123. if texts:
  124. out["text"] = "\n".join(texts)
  125. return out