瀏覽代碼

feat: add Gemini Imagen image generation support

Sh1n3zZ 1 年之前
父節點
當前提交
61d2a2f92d
共有 3 個文件被更改,包括 128 次插入4 次删除
  1. 99 4
      relay/channel/gemini/adaptor.go
  2. 2 0
      relay/channel/gemini/constant.go
  3. 27 0
      relay/channel/gemini/dto.go

+ 99 - 4
relay/channel/gemini/adaptor.go

@@ -1,15 +1,21 @@
 package gemini
 
 import (
+	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 type Adaptor struct {
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
+		return nil, errors.New("not supported model for image generation")
+	}
+
+	// convert size to aspect ratio
+	aspectRatio := "1:1" // default aspect ratio
+	switch request.Size {
+	case "1024x1024":
+		aspectRatio = "1:1"
+	case "1024x1792":
+		aspectRatio = "9:16"
+	case "1792x1024":
+		aspectRatio = "16:9"
+	}
+
+	// build gemini imagen request
+	geminiRequest := GeminiImageRequest{
+		Instances: []GeminiImageInstance{
+			{
+				Prompt: request.Prompt,
+			},
+		},
+		Parameters: GeminiImageParameters{
+			SampleCount:      request.N,
+			AspectRatio:      aspectRatio,
+			PersonGeneration: "allow_adult", // default allow adult
+		},
+	}
+
+	return geminiRequest, nil
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		}
 	}
 
+	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+		return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
+	}
+
 	action := "generateContent"
 	if info.IsStream {
 		action = "streamGenerateContent?alt=sse"
@@ -73,12 +111,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }
 
-
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+		return GeminiImageHandler(c, resp, info)
+	}
+
 	if info.IsStream {
 		err, usage = GeminiChatStreamHandler(c, resp, info)
 	} else {
@@ -87,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	return
 }
 
+func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	responseBody, readErr := io.ReadAll(resp.Body)
+	if readErr != nil {
+		return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	_ = resp.Body.Close()
+
+	var geminiResponse GeminiImageResponse
+	if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+		return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	if len(geminiResponse.Predictions) == 0 {
+		return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
+	}
+
+	// convert to openai format response
+	openAIResponse := dto.ImageResponse{
+		Created: common.GetTimestamp(),
+		Data:    make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
+	}
+
+	for _, prediction := range geminiResponse.Predictions {
+		if prediction.RaiFilteredReason != "" {
+			continue // skip filtered image
+		}
+		openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
+			B64Json: prediction.BytesBase64Encoded,
+		})
+	}
+
+	jsonResponse, jsonErr := json.Marshal(openAIResponse)
+	if jsonErr != nil {
+		return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
+	// each image has fixed 258 tokens
+	const imageTokens = 258
+	generatedImages := len(openAIResponse.Data)
+
+	usage = &dto.Usage{
+		PromptTokens:     imageTokens * generatedImages, // each generated image has fixed 258 tokens
+		CompletionTokens: 0,                             // image generation does not calculate completion tokens
+		TotalTokens:      imageTokens * generatedImages,
+	}
+
+	return usage, nil
+}
+
 func (a *Adaptor) GetModelList() []string {
 	return ModelList
 }

+ 2 - 0
relay/channel/gemini/constant.go

@@ -16,6 +16,8 @@ var ModelList = []string{
 	"gemini-2.0-pro-exp",
 	// thinking exp
 	"gemini-2.0-flash-thinking-exp",
+	// imagen models
+	"imagen-3.0-generate-002",
 }
 
 var ChannelName = "google gemini"

+ 27 - 0
relay/channel/gemini/dto.go

@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
 	CandidatesTokenCount int `json:"candidatesTokenCount"`
 	TotalTokenCount      int `json:"totalTokenCount"`
 }
+
+// Imagen related structs
+type GeminiImageRequest struct {
+	Instances  []GeminiImageInstance `json:"instances"`
+	Parameters GeminiImageParameters `json:"parameters"`
+}
+
+type GeminiImageInstance struct {
+	Prompt string `json:"prompt"`
+}
+
+type GeminiImageParameters struct {
+	SampleCount      int    `json:"sampleCount,omitempty"`
+	AspectRatio      string `json:"aspectRatio,omitempty"`
+	PersonGeneration string `json:"personGeneration,omitempty"`
+}
+
+type GeminiImageResponse struct {
+	Predictions []GeminiImagePrediction `json:"predictions"`
+}
+
+type GeminiImagePrediction struct {
+	MimeType           string `json:"mimeType"`
+	BytesBase64Encoded string `json:"bytesBase64Encoded"`
+	RaiFilteredReason  string `json:"raiFilteredReason,omitempty"`
+	SafetyAttributes   any    `json:"safetyAttributes,omitempty"`
+}