| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
"""Gemini native image generation via the REST `generateContent` endpoint (no SDK).

Request and response shapes follow the official documentation:
https://ai.google.dev/gemini-api/docs/image-generation?hl=zh-cn#rest
"""
- from __future__ import annotations
- import os
- import re
- from typing import Any
- import httpx
- from dotenv import load_dotenv
# Load a local .env file (if any) into the process environment at import time,
# so the os.environ lookups below and in generate_content can see its values.
_ = load_dotenv()

# Model id used when neither the `model` argument nor GEMINI_IMAGE_MODEL is set.
DEFAULT_MODEL = "gemini-2.5-flash-image"

# Base URL of the REST API, up to and including the version segment.
# Read once at import time. NOTE(review): the built-in default is a non-Google
# host (presumably an API router/proxy) — confirm this is intended.
GEMINI_API_BASE = os.getenv("GEMINI_API_BASE", "https://airouter.piaoquantv.com/v1beta")
- _DATA_URL_RE = re.compile(r"^data:[^;]+;base64,(.+)$", re.I | re.S)
- def _strip_data_url(b64_or_data_url: str) -> str:
- s = b64_or_data_url.strip()
- m = _DATA_URL_RE.match(s)
- return m.group(1) if m else s
- def _build_parts(
- prompt: str,
- images: list[dict[str, str]] | None,
- ) -> list[dict[str, Any]]:
- parts: list[dict[str, Any]] = [{"text": prompt}]
- if not images:
- return parts
- for img in images:
- mime = (img.get("mime_type") or img.get("mimeType") or "image/png").strip()
- raw = _strip_data_url(img.get("data") or "")
- if not raw:
- raise ValueError("images[].data 不能为空(Base64 或 data URL)")
- parts.append({"inline_data": {"mime_type": mime, "data": raw}})
- return parts
- def _merge_generation_config(
- *,
- aspect_ratio: str | None,
- image_size: str | None,
- response_modalities: list[str] | None,
- ) -> dict[str, Any] | None:
- cfg: dict[str, Any] = {}
- if response_modalities:
- cfg["responseModalities"] = response_modalities
- img_cfg: dict[str, str] = {}
- if aspect_ratio:
- img_cfg["aspectRatio"] = aspect_ratio.strip()
- if image_size:
- img_cfg["imageSize"] = image_size.strip()
- if img_cfg:
- cfg["imageConfig"] = img_cfg
- return cfg or None
def _decode_response(r: httpx.Response) -> dict[str, Any]:
    """Parse *r*'s JSON body into a dict, raising on HTTP or API-level errors."""
    try:
        data = r.json()
    except ValueError:
        # Body is not JSON: surface the HTTP status first if it is an error,
        # otherwise report (a truncated copy of) the raw body.
        r.raise_for_status()
        raise RuntimeError(r.text[:2000]) from None
    if r.status_code >= 400:
        err = data.get("error") if isinstance(data, dict) else None
        msg = err.get("message", str(data)) if isinstance(err, dict) else str(data)
        raise RuntimeError(f"Gemini HTTP {r.status_code}: {msg}")
    if not isinstance(data, dict):
        raise RuntimeError("响应不是 JSON 对象")
    if data.get("error"):
        raise RuntimeError(str(data["error"]))
    return data


def _collect_outputs(data: dict[str, Any]) -> tuple[list[str], list[str]]:
    """Return ``(image_data_urls, texts)`` gathered from every candidate's parts."""
    images_out: list[str] = []
    texts: list[str] = []
    for cand in data.get("candidates") or []:
        if not isinstance(cand, dict):
            continue
        # `content` may be absent or null on a candidate; skip it rather than
        # crashing with AttributeError.
        content = cand.get("content")
        if not isinstance(content, dict):
            continue
        for part in content.get("parts") or []:
            if not isinstance(part, dict):
                continue
            if part.get("text"):
                texts.append(str(part["text"]))
            # Accept both camelCase and snake_case field names.
            inline = part.get("inlineData") or part.get("inline_data")
            if isinstance(inline, dict):
                b64 = inline.get("data")
                if b64:
                    mime = (
                        inline.get("mimeType")
                        or inline.get("mime_type")
                        or "image/png"
                    )
                    images_out.append(f"data:{mime};base64,{b64}")
    return images_out, texts


def generate_content(
    *,
    prompt: str,
    model: str | None = None,
    aspect_ratio: str | None = None,
    image_size: str | None = None,
    response_modalities: list[str] | None = None,
    images: list[dict[str, str]] | None = None,
    timeout: float = 300.0,
) -> dict[str, Any]:
    """Call ``models/{model}:generateContent`` and collect generated images/text.

    Args:
        prompt: Text prompt, sent as the first part of a single user turn.
        model: Model id. If blank or None, falls back to the GEMINI_IMAGE_MODEL
            environment variable, then to DEFAULT_MODEL.
        aspect_ratio: Forwarded as ``generationConfig.imageConfig.aspectRatio``.
        image_size: Forwarded as ``generationConfig.imageConfig.imageSize``.
        response_modalities: Forwarded as ``generationConfig.responseModalities``.
        images: Optional input images; each dict has ``data`` (Base64 or a
            data URL) and optionally ``mime_type`` / ``mimeType``.
        timeout: HTTP timeout in seconds passed to ``httpx.Client``.

    Returns:
        ``{"images": [data URL, ...], "model": resolved_model_id}``, plus a
        ``"text"`` key (text parts joined by newlines) when any text was returned.

    Raises:
        ValueError: GEMINI_API_KEY is unset or blank, or an image has empty data.
        RuntimeError: Error status with a JSON body, non-JSON body with a
            success status, non-object JSON, or a top-level ``error`` field.
        httpx.HTTPStatusError: Non-JSON body with an error status.
        Transport errors raised by httpx propagate unchanged.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not api_key:
        raise ValueError("缺少环境变量 GEMINI_API_KEY")
    # Strip each candidate before falling back, so a blank argument or env
    # value never yields an empty model id (".../models/:generateContent").
    resolved = (
        (model or "").strip()
        or os.environ.get("GEMINI_IMAGE_MODEL", "").strip()
        or DEFAULT_MODEL
    )
    url = f"{GEMINI_API_BASE.rstrip('/')}/models/{resolved}:generateContent"

    body: dict[str, Any] = {
        "contents": [
            {
                "role": "user",
                "parts": _build_parts(prompt, images),
            }
        ],
    }
    gen_cfg = _merge_generation_config(
        aspect_ratio=aspect_ratio,
        image_size=image_size,
        response_modalities=response_modalities,
    )
    if gen_cfg:
        body["generationConfig"] = gen_cfg

    headers = {
        "x-goog-api-key": api_key,
        "Content-Type": "application/json",
    }
    with httpx.Client(timeout=timeout) as client:
        r = client.post(url, headers=headers, json=body)

    data = _decode_response(r)
    images_out, texts = _collect_outputs(data)

    out: dict[str, Any] = {
        "images": images_out,
        "model": resolved,
    }
    if texts:
        out["text"] = "\n".join(texts)
    return out
|