Jelajahi Sumber

feat: gemini-3-pro-image-preview add extra param

feitianbubu 3 bulan lalu
induk
komit
d859872e0d
3 file diubah dengan 163 penambahan dan 14 penghapusan
  1. 3 3
      dto/openai_image.go
  2. 93 11
      relay/channel/gemini/adaptor.go
  3. 67 0
      relay/channel/gemini/relay-gemini.go

+ 3 - 3
dto/openai_image.go

@@ -169,7 +169,7 @@ type ImageResponse struct {
 	Extra   any         `json:"extra,omitempty"`
 }
 type ImageData struct { // one image entry in an image response's Data list
-	Url           string `json:"url"`
-	B64Json       string `json:"b64_json"`
-	RevisedPrompt string `json:"revised_prompt"`
+	Url           string `json:"url,omitempty"` // left out of the JSON when empty (omitempty)
+	B64Json       string `json:"b64_json,omitempty"` // left out of the JSON when empty (omitempty)
+	RevisedPrompt string `json:"revised_prompt,omitempty"` // left out of the JSON when empty (omitempty)
 }

+ 93 - 11
relay/channel/gemini/adaptor.go

@@ -1,6 +1,7 @@
 package gemini
 
 import (
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -55,6 +56,78 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 	return nil, errors.New("not implemented")
 }
 
+type ImageConfig struct { // image settings produced by processSizeParameters; both fields are optional
+	AspectRatio string `json:"aspectRatio,omitempty"` // ratio string such as "16:9"; omitted from JSON when empty
+	ImageSize   string `json:"imageSize,omitempty"` // size label such as "1K", "2K" or "4K"; omitted from JSON when empty
+}
+
+type SizeMapping struct { // value type of the table returned by getSizeMappings
+	AspectRatio string // aspect ratio for the size, or "" when the size does not determine one
+	ImageSize   string // size label for the size, or "" when the size does not determine one
+}
+
+type QualityMapping struct { // one image-size label per named quality level; filled in by getImageSizeMapping
+	Standard string // label used for "standard"-class qualities
+	HD       string // label used for "hd"-class qualities
+	High     string // NOTE(review): not read by processSizeParameters, which uses HD for "high"
+	FourK    string // label used for the "4k" quality
+	Auto     string // NOTE(review): not read by processSizeParameters, which uses Standard for "auto"
+}
+
+func getImageSizeMapping() QualityMapping { // returns the fixed quality-level to image-size-label table
+	return QualityMapping{ // a fresh value is built on every call
+		Standard: "1K",
+		HD:       "2K",
+		High:     "2K", // same label as HD
+		FourK:    "4K",
+		Auto:     "1K", // same label as Standard
+	}
+}
+
+func getSizeMappings() map[string]SizeMapping { // returns the lookup table from "WIDTHxHEIGHT" size strings to an aspect ratio or a size label
+	return map[string]SizeMapping{ // NOTE(review): square sizes such as "1024x1024" have no entry — confirm callers fall back to a suitable default
+		"1536x1024": {AspectRatio: "3:2", ImageSize: ""}, // non-square sizes only set the aspect ratio
+		"1024x1536": {AspectRatio: "2:3", ImageSize: ""},
+		"1024x1792": {AspectRatio: "9:16", ImageSize: ""},
+		"1792x1024": {AspectRatio: "16:9", ImageSize: ""},
+		"2048x2048": {AspectRatio: "", ImageSize: "2K"}, // large square sizes only set the size label
+		"4096x4096": {AspectRatio: "", ImageSize: "4K"},
+	}
+}
+
+func processSizeParameters(size, quality string) ImageConfig { // builds an ImageConfig from a size string and a quality string; unrecognized values leave the fields empty
+	config := ImageConfig{} // defaults to empty values
+
+	if size != "" {
+		if strings.Contains(size, ":") { // a value containing ":" is taken to be an aspect ratio already, e.g. "16:9"
+			config.AspectRatio = size // set directly, without comparing against a default value
+		} else {
+			if mapping, exists := getSizeMappings()[size]; exists { // sizes missing from the table are ignored
+				if mapping.AspectRatio != "" {
+					config.AspectRatio = mapping.AspectRatio
+				}
+				if mapping.ImageSize != "" {
+					config.ImageSize = mapping.ImageSize
+				}
+			}
+		}
+	}
+
+	if quality != "" { // handled after size, so a recognized quality overwrites any ImageSize taken from the size table
+		qualityMapping := getImageSizeMapping()
+		switch strings.ToLower(strings.TrimSpace(quality)) { // matching ignores case and surrounding whitespace; unknown values change nothing
+		case "hd", "high":
+			config.ImageSize = qualityMapping.HD // "high" shares the HD label; the High field is not read here
+		case "4k":
+			config.ImageSize = qualityMapping.FourK
+		case "standard", "medium", "low", "auto", "1k":
+			config.ImageSize = qualityMapping.Standard // "auto" also uses Standard; the Auto field is not read here
+		}
+	}
+
+	return config
+}
+
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
 	if strings.HasPrefix(info.UpstreamModelName, "gemini-3-pro-image") {
 		chatRequest := dto.GeneralOpenAIRequest{
@@ -64,6 +137,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 			},
 			N: int(request.N),
 		}
+
+		config := processSizeParameters(strings.TrimSpace(request.Size), request.Quality)
+		googleGenerationConfig := map[string]interface{}{
+			"response_modalities": []string{"TEXT", "IMAGE"},
+			"image_config":        config,
+		}
+
+		extraBody := map[string]interface{}{
+			"google": map[string]interface{}{
+				"generation_config": googleGenerationConfig,
+			},
+		}
+		chatRequest.ExtraBody, _ = json.Marshal(extraBody)
+
 		return a.ConvertOpenAIRequest(c, info, &chatRequest)
 	}
 
@@ -74,17 +161,8 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 		if strings.Contains(size, ":") {
 			aspectRatio = size
 		} else {
-			switch size {
-			case "256x256", "512x512", "1024x1024":
-				aspectRatio = "1:1"
-			case "1536x1024":
-				aspectRatio = "3:2"
-			case "1024x1536":
-				aspectRatio = "2:3"
-			case "1024x1792":
-				aspectRatio = "9:16"
-			case "1792x1024":
-				aspectRatio = "16:9"
+			if mapping, exists := getSizeMappings()[size]; exists && mapping.AspectRatio != "" {
+				aspectRatio = mapping.AspectRatio
 			}
 		}
 	}
@@ -265,6 +343,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		return GeminiImageHandler(c, info, resp)
 	}
 
+	if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
+		return ChatImageHandler(c, info, resp)
+	}
+
 	// check if the model is an embedding model
 	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "embedding") ||

+ 67 - 0
relay/channel/gemini/relay-gemini.go

@@ -1264,3 +1264,70 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 
 	return usage, nil
 }
+
+func convertToOaiImageResponse(geminiResponse *dto.GeminiChatResponse) (*dto.ImageResponse, error) { // collects inline image parts of a Gemini chat response into an ImageResponse; errors when none are found
+	openAIResponse := &dto.ImageResponse{
+		Created: common.GetTimestamp(),
+		Data:    make([]dto.ImageData, 0), // non-nil empty slice; images are appended below
+	}
+
+	// extract images from candidates' inlineData
+	for _, candidate := range geminiResponse.Candidates {
+		for _, part := range candidate.Content.Parts {
+			if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image") { // parts without inline data or with a non-image MIME type are skipped
+				openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
+					B64Json: part.InlineData.Data, // passed through unchanged; presumably already base64 — TODO confirm
+				})
+			}
+		}
+	}
+
+	if len(openAIResponse.Data) == 0 { // a response with no image parts is reported as an error rather than an empty list
+		return nil, errors.New("no images found in response")
+	}
+
+	return openAIResponse, nil
+}
+
+func ChatImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { // reads a Gemini chat response, writes it to the client as an image response, and returns token usage
+	responseBody, readErr := io.ReadAll(resp.Body) // the whole upstream body is buffered in memory
+	if readErr != nil {
+		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) // NOTE(review): this path returns before CloseResponseBodyGracefully is called — confirm the body is closed elsewhere
+	}
+	service.CloseResponseBodyGracefully(resp)
+
+	if common.DebugEnabled {
+		println("ChatImageHandler response:", string(responseBody)) // prints the raw upstream body when debug is enabled
+	}
+
+	var geminiResponse dto.GeminiChatResponse
+	if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	if len(geminiResponse.Candidates) == 0 { // no candidates at all is treated as a failed generation
+		return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	openAIResponse, err := convertToOaiImageResponse(&geminiResponse) // fails when the candidates contain no inline image parts
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	jsonResponse, jsonErr := json.Marshal(openAIResponse)
+	if jsonErr != nil {
+		return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode) // the upstream status code is forwarded to the client
+	_, _ = c.Writer.Write(jsonResponse) // write result is deliberately discarded
+
+	usage := &dto.Usage{ // token counts are copied from the Gemini usage metadata
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
+	}
+
+	return usage, nil
+}