| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- """Gemini 原生图模 — REST `generateContent`(与官方文档一致,无 SDK)。
- 参考: https://ai.google.dev/gemini-api/docs/image-generation?hl=zh-cn#rest
- """
- from __future__ import annotations
- import os
- import re
- from typing import Any
- import httpx
- from dotenv import load_dotenv
# Populate os.environ from a local .env file (python-dotenv) before any of the
# GEMINI_* variables are read, both at import time (GEMINI_API_BASE below) and
# inside generate_content() (GEMINI_API_KEY, GEMINI_IMAGE_MODEL).
_ = load_dotenv()
- DEFAULT_MODEL = "gemini-3.1-flash-image-preview"
- GEMINI_API_BASE = os.environ.get(
- "GEMINI_API_BASE", "https://airouter.piaoquantv.com/v1beta"
- )
- _DATA_URL_RE = re.compile(r"^data:[^;]+;base64,(.+)$", re.I | re.S)
- def _strip_data_url(b64_or_data_url: str) -> str:
- s = b64_or_data_url.strip()
- m = _DATA_URL_RE.match(s)
- return m.group(1) if m else s
- import base64
- _BASE64_CHARS = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=\n\r\t ")
- def _looks_like_path_not_base64(s: str) -> bool:
- """判断字符串看起来像文件路径而不是 base64 数据。
- Base64 的字符集只有 [A-Za-z0-9+/=]。如果字符串包含 '.'、'-'(后者是 URL-safe
- base64 才有)或有明显的路径形态(以 / 开头、含 .png/.jpg 等扩展名),基本可以
- 判定为路径。
- """
- if not s:
- return False
- lower = s.lower()
- # 有图片扩展名 → 肯定是路径
- if any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp")):
- return True
- # 含有非 base64 字符 → 肯定是路径或别的东西
- if any(c not in _BASE64_CHARS for c in s[:200]):
- return True
- return False
- def _build_parts(
- prompt: str,
- images: list[dict[str, str]] | None,
- ) -> list[dict[str, Any]]:
- parts: list[dict[str, Any]] = [{"text": prompt}]
- if not images:
- return parts
- for img in images:
- mime = (img.get("mime_type") or img.get("mimeType") or "image/png").strip()
- raw = _strip_data_url(img.get("data") or "")
- # 支持自动下载 HTTP 外链并转换为 Base64
- if raw.startswith("http://") or raw.startswith("https://"):
- try:
- resp = httpx.get(raw, timeout=30.0)
- resp.raise_for_status()
- raw = base64.b64encode(resp.content).decode("utf-8")
- except Exception as e:
- raise ValueError(f"拉取网络图片失败 [{raw}]: {e}")
- elif os.path.isfile(raw):
- try:
- with open(raw, "rb") as f:
- raw = base64.b64encode(f.read()).decode("utf-8")
- except Exception as e:
- raise ValueError(f"读取本地图片文件失败 [{raw}]: {e}")
- elif _looks_like_path_not_base64(raw):
- # 防御:显然是文件路径但文件不存在,早报错比让 Gemini 抛神秘的 base64 解码错误强
- raise ValueError(
- f"images[].data 看起来是文件路径但文件不存在: {raw!r} "
- f"(cwd={os.getcwd()})。请传 HTTP URL 或绝对路径,"
- f"调用方应先把本地文件上传到 OSS 再传 CDN URL。"
- )
- if not raw:
- raise ValueError("images[].data 不能为空(Base64、本地路径、data URL 或 HTTP 链接)")
- parts.append({"inline_data": {"mime_type": mime, "data": raw}})
- return parts
- def _merge_generation_config(
- *,
- aspect_ratio: str | None,
- image_size: str | None,
- response_modalities: list[str] | None,
- ) -> dict[str, Any] | None:
- cfg: dict[str, Any] = {}
- if response_modalities:
- cfg["responseModalities"] = response_modalities
- img_cfg: dict[str, str] = {}
- if aspect_ratio:
- img_cfg["aspectRatio"] = aspect_ratio.strip()
- if image_size:
- img_cfg["imageSize"] = image_size.strip()
- if img_cfg:
- cfg["imageConfig"] = img_cfg
- return cfg or None
def _collect_outputs(data: dict[str, Any]) -> tuple[list[str], list[str]]:
    """Extract generated images (as data URLs) and text parts from a response.

    Walks ``candidates[].content.parts[]`` and silently skips entries that are
    missing, null or not objects (e.g. a candidate without ``content``).
    """
    images_out: list[str] = []
    texts: list[str] = []
    candidates = data.get("candidates")
    if not isinstance(candidates, list):
        candidates = []
    for cand in candidates:
        if not isinstance(cand, dict):
            continue
        content = cand.get("content")
        if not isinstance(content, dict):
            continue
        for part in content.get("parts") or []:
            if not isinstance(part, dict):
                continue
            if part.get("text"):
                texts.append(str(part["text"]))
            # The REST API answers in camelCase; accept snake_case too.
            inline = part.get("inlineData") or part.get("inline_data")
            if not isinstance(inline, dict):
                continue
            b64 = inline.get("data")
            if b64:
                mime = inline.get("mimeType") or inline.get("mime_type") or "image/png"
                images_out.append(f"data:{mime};base64,{b64}")
    return images_out, texts


def generate_content(
    *,
    prompt: str,
    model: str | None = None,
    aspect_ratio: str | None = None,
    image_size: str | None = None,
    response_modalities: list[str] | None = None,
    images: list[dict[str, str]] | None = None,
    timeout: float = 300.0,
) -> dict[str, Any]:
    """Call ``models/{model}:generateContent`` over REST and collect the output.

    Args:
        prompt: Text instruction, sent as the first part of the user turn.
        model: Model name; blank/None falls back to $GEMINI_IMAGE_MODEL, then
            DEFAULT_MODEL.
        aspect_ratio: Optional ``imageConfig.aspectRatio``.
        image_size: Optional ``imageConfig.imageSize``.
        response_modalities: Optional ``responseModalities`` list.
        images: Optional input images, each ``{"data": ..., "mime_type": ...}``
            where ``data`` is base64, a data URL, an HTTP(S) URL or a local
            file path.
        timeout: HTTP timeout in seconds for the generateContent request.

    Returns:
        ``{"images": [data URLs], "model": resolved model name}``, plus
        ``"text"`` (newline-joined text parts) when the response had any.

    Raises:
        ValueError: GEMINI_API_KEY is missing or an input image is invalid.
        RuntimeError: HTTP error with a JSON body, a non-JSON 2xx body, a JSON
            body that is not an object, or an ``error`` field in the response.
        httpx.HTTPStatusError: HTTP error whose body is not JSON.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not api_key:
        raise ValueError("缺少环境变量 GEMINI_API_KEY")
    # Strip before falling back, so a blank argument or env value cannot
    # produce a ".../models/:generateContent" URL.
    resolved = (
        (model or "").strip()
        or os.environ.get("GEMINI_IMAGE_MODEL", "").strip()
        or DEFAULT_MODEL.strip()
    )
    url = f"{GEMINI_API_BASE.rstrip('/')}/models/{resolved}:generateContent"
    body: dict[str, Any] = {
        "contents": [
            {
                "role": "user",
                "parts": _build_parts(prompt, images),
            }
        ],
    }
    gen_cfg = _merge_generation_config(
        aspect_ratio=aspect_ratio,
        image_size=image_size,
        response_modalities=response_modalities,
    )
    if gen_cfg:
        body["generationConfig"] = gen_cfg
    headers = {
        "x-goog-api-key": api_key,
        "Content-Type": "application/json",
    }
    with httpx.Client(timeout=timeout) as client:
        r = client.post(url, headers=headers, json=body)
    try:
        data = r.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page from a proxy): surface the
        # HTTP status if it is an error, otherwise the raw (truncated) text.
        r.raise_for_status()
        raise RuntimeError(r.text[:2000]) from None
    if r.status_code >= 400:
        err = data.get("error") if isinstance(data, dict) else None
        msg = err.get("message", str(data)) if isinstance(err, dict) else str(data)
        raise RuntimeError(f"Gemini HTTP {r.status_code}: {msg}")
    if not isinstance(data, dict):
        raise RuntimeError("响应不是 JSON 对象")
    if data.get("error"):
        raise RuntimeError(str(data["error"]))
    images_out, texts = _collect_outputs(data)
    out: dict[str, Any] = {
        "images": images_out,
        "model": resolved,
    }
    if texts:
        out["text"] = "\n".join(texts)
    return out
|