Просмотр исходного кода

fix typo; add ParamOverride for Gemini Embedding

RedwindA 7 месяцев назад
Родитель
Сommit
7a31e481a6

+ 3 - 2
relay/channel/gemini/adaptor.go

@@ -115,7 +115,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
 		action := "embedContent"
-		if info.IsGeminiBatchEmbdding {
+		if info.IsGeminiBatchEmbedding {
 			action = "batchEmbedContents"
 		}
 		return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
@@ -199,7 +199,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeGemini {
-		if strings.Contains(info.RequestURLPath, "embed") {
+		if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
+			strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
 			return NativeGeminiEmbeddingHandler(c, resp, info)
 		}
 		if info.IsStream {

+ 1 - 1
relay/channel/gemini/relay-gemini-native.go

@@ -81,7 +81,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
 		TotalTokens:  info.PromptTokens,
 	}
 
-	if info.IsGeminiBatchEmbdding {
+	if info.IsGeminiBatchEmbedding {
 		var geminiResponse dto.GeminiBatchEmbeddingResponse
 		err = common.Unmarshal(responseBody, &geminiResponse)
 		if err != nil {

+ 8 - 8
relay/common/relay_info.go

@@ -74,14 +74,14 @@ type RelayInfo struct {
 	FirstResponseTime    time.Time
 	isFirstResponse      bool
 	//SendLastReasoningResponse bool
-	ApiType               int
-	IsStream              bool
-	IsGeminiBatchEmbdding bool
-	IsPlayground          bool
-	UsePrice              bool
-	RelayMode             int
-	UpstreamModelName     string
-	OriginModelName       string
+	ApiType                int
+	IsStream               bool
+	IsGeminiBatchEmbedding bool
+	IsPlayground           bool
+	UsePrice               bool
+	RelayMode              int
+	UpstreamModelName      string
+	OriginModelName        string
 	//RecodeModelName      string
 	RequestURLPath       string
 	ApiVersion           string

+ 14 - 1
relay/gemini_handler.go

@@ -269,7 +269,7 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
 	relayInfo := relaycommon.GenRelayInfoGemini(c)
 
 	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
-	relayInfo.IsGeminiBatchEmbdding = isBatch
+	relayInfo.IsGeminiBatchEmbedding = isBatch
 
 	var promptTokens int
 	var req any
@@ -338,6 +338,19 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
+
+	// apply param override
+	if len(relayInfo.ParamOverride) > 0 {
+		reqMap := make(map[string]interface{})
+		_ = common.Unmarshal(jsonData, &reqMap)
+		for key, value := range relayInfo.ParamOverride {
+			reqMap[key] = value
+		}
+		jsonData, err = common.Marshal(reqMap)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+		}
+	}
 	requestBody = bytes.NewReader(jsonData)
 
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)