Просмотр исходного кода

feat: gemini Embeddings support

Sh1n3zZ 1 год назад
Родитель
Сommit
e1b9f164f9

+ 44 - 2
relay/channel/gemini/adaptor.go

@@ -70,6 +70,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
 	}
 
+	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
+		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
+		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
+		return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
+	}
+
 	action := "generateContent"
 	if info.IsStream {
 		action = "streamGenerateContent?alt=sse"
@@ -99,8 +105,37 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	if request.Input == nil {
+		return nil, errors.New("input is required")
+	}
+
+	inputs := request.ParseInput()
+	if len(inputs) == 0 {
+		return nil, errors.New("input is empty")
+	}
+
+	// only process the first input
+	geminiRequest := GeminiEmbeddingRequest{
+		Content: GeminiChatContent{
+			Parts: []GeminiPart{
+				{
+					Text: inputs[0],
+				},
+			},
+		},
+	}
+
+	// set specific parameters for different models
+	// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
+	switch info.UpstreamModelName {
+	case "text-embedding-004":
+		// except embedding-001 supports setting `OutputDimensionality`
+		if request.Dimensions > 0 {
+			geminiRequest.OutputDimensionality = request.Dimensions
+		}
+	}
+
+	return geminiRequest, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
@@ -112,6 +147,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		return GeminiImageHandler(c, resp, info)
 	}
 
+	// check if the model is an embedding model
+	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
+		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
+		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
+		return GeminiEmbeddingHandler(c, resp, info)
+	}
+
 	if info.IsStream {
 		err, usage = GeminiChatStreamHandler(c, resp, info)
 	} else {

+ 4 - 0
relay/channel/gemini/constant.go

@@ -18,6 +18,10 @@ var ModelList = []string{
 	"gemini-2.0-flash-thinking-exp",
 	// imagen models
 	"imagen-3.0-generate-002",
+	// embedding models
+	"gemini-embedding-exp-03-07",
+	"text-embedding-004",
+	"embedding-001",
 }
 
 var SafetySettingList = []string{

+ 16 - 0
relay/channel/gemini/dto.go

@@ -136,3 +136,19 @@ type GeminiImagePrediction struct {
 	RaiFilteredReason  string `json:"raiFilteredReason,omitempty"`
 	SafetyAttributes   any    `json:"safetyAttributes,omitempty"`
 }
+
+// Embedding related structs
+type GeminiEmbeddingRequest struct {
+	Content              GeminiChatContent `json:"content"`
+	TaskType             string            `json:"taskType,omitempty"`
+	Title                string            `json:"title,omitempty"`
+	OutputDimensionality int               `json:"outputDimensionality,omitempty"`
+}
+
+type GeminiEmbeddingResponse struct {
+	Embedding ContentEmbedding `json:"embedding"`
+}
+
+type ContentEmbedding struct {
+	Values []float64 `json:"values"`
+}

+ 49 - 0
relay/channel/gemini/relay-gemini.go

@@ -580,3 +580,52 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }
+
+func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	responseBody, readErr := io.ReadAll(resp.Body)
+	if readErr != nil {
+		return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	_ = resp.Body.Close()
+
+	var geminiResponse GeminiEmbeddingResponse
+	if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+		return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	// convert to openai format response
+	openAIResponse := dto.OpenAIEmbeddingResponse{
+		Object: "list",
+		Data: []dto.OpenAIEmbeddingResponseItem{
+			{
+				Object:    "embedding",
+				Embedding: geminiResponse.Embedding.Values,
+				Index:     0,
+			},
+		},
+		Model: info.UpstreamModelName,
+	}
+
+	// calculate usage
+	// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
+	// Google has not yet clarified how embedding models will be billed
+	// refer to openai billing method to use input tokens billing
+	// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
+	usage = &dto.Usage{
+		PromptTokens:     info.PromptTokens,
+		CompletionTokens: 0,
+		TotalTokens:      info.PromptTokens,
+	}
+	openAIResponse.Usage = *usage.(*dto.Usage)
+
+	jsonResponse, jsonErr := json.Marshal(openAIResponse)
+	if jsonErr != nil {
+		return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	return usage, nil
+}