Переглянути джерело

fix: Gemini embedding model only embeds the first text in a batch

antecanis8 7 місяців тому
батько
коміт
43263a3bc8
3 змінених файлів з 42 додано та 30 видалено
  1. 7 3
      dto/gemini.go
  2. 24 18
      relay/channel/gemini/adaptor.go
  3. 11 9
      relay/channel/gemini/relay-gemini.go

+ 7 - 3
dto/gemini.go

@@ -216,10 +216,14 @@ type GeminiEmbeddingRequest struct {
 	OutputDimensionality int               `json:"outputDimensionality,omitempty"`
 }
 
-type GeminiEmbeddingResponse struct {
-	Embedding ContentEmbedding `json:"embedding"`
+type GeminiBatchEmbeddingRequest struct {
+	Requests []*GeminiEmbeddingRequest `json:"requests"`
 }
 
-type ContentEmbedding struct {
+type GeminiEmbedding struct {
 	Values []float64 `json:"values"`
 }
+
+type GeminiBatchEmbeddingResponse struct {
+	Embeddings []*GeminiEmbedding `json:"embeddings"`
+}

+ 24 - 18
relay/channel/gemini/adaptor.go

@@ -114,7 +114,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
-		return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
+		return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, info.UpstreamModelName), nil
 	}
 
 	action := "generateContent"
@@ -156,29 +156,35 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	if len(inputs) == 0 {
 		return nil, errors.New("input is empty")
 	}
-
-	// only process the first input
-	geminiRequest := dto.GeminiEmbeddingRequest{
-		Content: dto.GeminiChatContent{
-			Parts: []dto.GeminiPart{
-				{
-					Text: inputs[0],
+	// process all inputs
+	geminiRequests := make([]map[string]interface{}, 0, len(inputs))
+	for _, input := range inputs {
+		geminiRequest := map[string]interface{}{
+			"model": fmt.Sprintf("models/%s", info.UpstreamModelName),
+			"content": dto.GeminiChatContent{
+				Parts: []dto.GeminiPart{
+					{
+						Text: input,
+					},
 				},
 			},
-		},
-	}
+		}
 
-	// set specific parameters for different models
-	// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
-	switch info.UpstreamModelName {
-	case "text-embedding-004":
-		// except embedding-001 supports setting `OutputDimensionality`
-		if request.Dimensions > 0 {
-			geminiRequest.OutputDimensionality = request.Dimensions
+		// set specific parameters for different models
+		// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
+		switch info.UpstreamModelName {
+		case "text-embedding-004":
+			// except embedding-001 supports setting `OutputDimensionality`
+			if request.Dimensions > 0 {
+				geminiRequest["outputDimensionality"] = request.Dimensions
+			}
 		}
+		geminiRequests = append(geminiRequests, geminiRequest)
 	}
 
-	return geminiRequest, nil
+	return map[string]interface{}{
+		"requests": geminiRequests,
+	}, nil
 }
 
 func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {

+ 11 - 9
relay/channel/gemini/relay-gemini.go

@@ -974,7 +974,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 
-	var geminiResponse dto.GeminiEmbeddingResponse
+	var geminiResponse dto.GeminiBatchEmbeddingResponse
 	if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
 		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
@@ -982,14 +982,16 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	// convert to openai format response
 	openAIResponse := dto.OpenAIEmbeddingResponse{
 		Object: "list",
-		Data: []dto.OpenAIEmbeddingResponseItem{
-			{
-				Object:    "embedding",
-				Embedding: geminiResponse.Embedding.Values,
-				Index:     0,
-			},
-		},
-		Model: info.UpstreamModelName,
+		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
+		Model:  info.UpstreamModelName,
+	}
+
+	for i, embedding := range geminiResponse.Embeddings {
+		openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
+			Object:    "embedding",
+			Embedding: embedding.Values,
+			Index:     i,
+		})
 	}
 
 	// calculate usage