main.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """nano_banana — 本地 HTTP 封装 Gemini 原生图模(REST generateContent)。
  2. 环境变量:
  3. GEMINI_API_KEY 必填,对应文档中的 x-goog-api-key
  4. GEMINI_IMAGE_MODEL 可选,未在请求体指定 model 时使用,默认 gemini-2.5-flash-image
  5. GEMINI_API_BASE 可选,默认 https://generativelanguage.googleapis.com/v1beta
  6. 接口:
  7. GET /health
  8. POST /generate 文生图 / 图+文生图,字段与 registry input_schema 对齐
  9. 文档: https://ai.google.dev/gemini-api/docs/image-generation?hl=zh-cn#rest
  10. """
from __future__ import annotations

import argparse
import logging

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from gemini_image_client import generate_content
  17. app = FastAPI(title="Nano Banana — Gemini Image (REST)")
  18. class ImageInput(BaseModel):
  19. """参考图:Base64 或 data URL;字段名与 REST inline_data 对应。"""
  20. mime_type: str = Field(default="image/png", description="如 image/png、image/jpeg")
  21. data: str = Field(..., description="图片 Base64,或 data:image/...;base64,...")
  22. class GenerateRequest(BaseModel):
  23. prompt: str = Field(..., description="主提示词")
  24. model: str | None = Field(
  25. default=None,
  26. description=(
  27. "模型 ID,如 gemini-2.5-flash-image、gemini-3.1-flash-image-preview;"
  28. "省略则使用 GEMINI_IMAGE_MODEL / 内置默认"
  29. ),
  30. )
  31. aspect_ratio: str | None = Field(
  32. default=None,
  33. description='宽高比,如 "1:1"、"16:9"(见官方文档 imageConfig.aspectRatio)',
  34. )
  35. image_size: str | None = Field(
  36. default=None,
  37. description='Gemini 3.x 输出分辨率:512、1K、2K、4K(generationConfig.imageConfig.imageSize)',
  38. )
  39. response_modalities: list[str] | None = Field(
  40. default=None,
  41. description='如 ["TEXT","IMAGE"] 或 ["IMAGE"];省略则由 API 默认',
  42. )
  43. images: list[ImageInput] | None = Field(
  44. default=None,
  45. description="可选参考图列表(图生图 / 编辑),对应 REST parts 中的 inline_data",
  46. )
  47. @app.get("/health")
  48. def health() -> dict[str, str]:
  49. return {"status": "ok"}
  50. @app.post("/generate")
  51. def generate(req: GenerateRequest) -> dict:
  52. try:
  53. imgs = None
  54. if req.images:
  55. imgs = [{"mime_type": i.mime_type, "data": i.data} for i in req.images]
  56. return generate_content(
  57. prompt=req.prompt,
  58. model=req.model,
  59. aspect_ratio=req.aspect_ratio,
  60. image_size=req.image_size,
  61. response_modalities=req.response_modalities,
  62. images=imgs,
  63. )
  64. except ValueError as e:
  65. raise HTTPException(status_code=503, detail=str(e)) from e
  66. except RuntimeError as e:
  67. raise HTTPException(status_code=502, detail=str(e)) from e
  68. except Exception as e:
  69. raise HTTPException(status_code=502, detail=str(e)) from e
  70. if __name__ == "__main__":
  71. parser = argparse.ArgumentParser()
  72. parser.add_argument("--port", type=int, default=8001)
  73. args = parser.parse_args()
  74. uvicorn.run(app, host="0.0.0.0", port=args.port)