Преглед изворни кода

Merge remote-tracking branch 'origin/alpha' into alpha

t0ng7u пре 6 месеци
родитељ
комит
13aee98d4a
2 измењених фајлова са 71 додато и 10 уклоњено
  1. 15 8
      relay/channel/gemini/adaptor.go
  2. 56 2
      relay/channel/vertex/adaptor.go

+ 15 - 8
relay/channel/gemini/adaptor.go

@@ -59,15 +59,22 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 		return nil, errors.New("not supported model for image generation")
 	}
 
-	// convert size to aspect ratio
+	// convert size to aspect ratio but allow user to specify aspect ratio
 	aspectRatio := "1:1" // default aspect ratio
-	switch request.Size {
-	case "1024x1024":
-		aspectRatio = "1:1"
-	case "1024x1792":
-		aspectRatio = "9:16"
-	case "1792x1024":
-		aspectRatio = "16:9"
+	size := strings.TrimSpace(request.Size)
+	if size != "" {
+		if strings.Contains(size, ":") {
+			aspectRatio = size
+		} else {
+			switch size {
+			case "1024x1024":
+				aspectRatio = "1:1"
+			case "1024x1792":
+				aspectRatio = "9:16"
+			case "1792x1024":
+				aspectRatio = "16:9"
+			}
+		}
 	}
 
 	// build gemini imagen request

+ 56 - 2
relay/channel/vertex/adaptor.go

@@ -66,8 +66,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	geminiAdaptor := gemini.Adaptor{}
+	return geminiAdaptor.ConvertImageRequest(c, info, request)
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -181,6 +181,60 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") {
+		prompt := ""
+		for _, m := range request.Messages {
+			if m.Role == "user" {
+				prompt = m.StringContent()
+				if prompt != "" {
+					break
+				}
+			}
+		}
+		if prompt == "" {
+			if p, ok := request.Prompt.(string); ok {
+				prompt = p
+			}
+		}
+		if prompt == "" {
+			return nil, errors.New("prompt is required for image generation")
+		}
+
+		imgReq := dto.ImageRequest{
+			Model:  request.Model,
+			Prompt: prompt,
+			N:      1,
+			Size:   "1024x1024",
+		}
+		if request.N > 0 {
+			imgReq.N = uint(request.N)
+		}
+		if request.Size != "" {
+			imgReq.Size = request.Size
+		}
+		if len(request.ExtraBody) > 0 {
+			var extra map[string]any
+			if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
+				if n, ok := extra["n"].(float64); ok && n > 0 {
+					imgReq.N = uint(n)
+				}
+				if size, ok := extra["size"].(string); ok {
+					imgReq.Size = size
+				}
+				// accept aspectRatio in extra body (top-level or under parameters)
+				if ar, ok := extra["aspectRatio"].(string); ok && ar != "" {
+					imgReq.Size = ar
+				}
+				if params, ok := extra["parameters"].(map[string]any); ok {
+					if ar, ok := params["aspectRatio"].(string); ok && ar != "" {
+						imgReq.Size = ar
+					}
+				}
+			}
+		}
+		c.Set("request_model", request.Model)
+		return a.ConvertImageRequest(c, info, imgReq)
+	}
 	if a.RequestMode == RequestModeClaude {
 		claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
 		if err != nil {