Kaynağa Gözat

fix: fix embedding

CaIon 1 yıl önce
ebeveyn
işleme
c4b3d3a975

+ 5 - 1
dto/text_response.go

@@ -1,12 +1,16 @@
 package dto
 
 type TextResponseWithError struct {
-	Choices []OpenAITextResponseChoice `json:"choices"`
+	Choices []OpenAITextResponseChoice    `json:"choices"`
+	Object  string                        `json:"object"`
+	Data    []OpenAIEmbeddingResponseItem `json:"data"`
+	Model   string                        `json:"model"`
 	Usage   `json:"usage"`
 	Error   OpenAIError `json:"error"`
 }
 
 type TextResponse struct {
+	Model   string                     `json:"model"`
 	Choices []OpenAITextResponseChoice `json:"choices"`
 	Usage   `json:"usage"`
 }

+ 1 - 1
relay/channel/ollama/adaptor.go

@@ -45,7 +45,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 	}
 	return
 }

+ 1 - 1
relay/channel/openai/adaptor.go

@@ -77,7 +77,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 	}
 	return
 }

+ 51 - 32
relay/channel/openai/relay-openai.go

@@ -124,8 +124,8 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 	return nil, responseTextBuilder.String()
 }
 
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
-	var textResponseWithError dto.TextResponseWithError
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+	var responseWithError dto.TextResponseWithError
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
@@ -134,62 +134,81 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
 	}
-	err = json.Unmarshal(responseBody, &textResponseWithError)
+	err = json.Unmarshal(responseBody, &responseWithError)
 	if err != nil {
 		log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
 	}
-	if textResponseWithError.Error.Type != "" {
+	if responseWithError.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			Error:      textResponseWithError.Error,
+			Error:      responseWithError.Error,
 			StatusCode: resp.StatusCode,
 		}, nil, nil
 	}
 
-	textResponse := &dto.TextResponse{
-		Choices: textResponseWithError.Choices,
-		Usage:   textResponseWithError.Usage,
-	}
-
 	checkSensitive := constant.ShouldCheckCompletionSensitive()
 	sensitiveWords := make([]string, 0)
 	triggerSensitive := false
 
-	if textResponse.Usage.TotalTokens == 0 || checkSensitive {
-		completionTokens := 0
-		for i, choice := range textResponse.Choices {
-			stringContent := string(choice.Message.Content)
-			ctkm, _, _ := service.CountTokenText(stringContent, model, false)
-			completionTokens += ctkm
-			if checkSensitive {
-				sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
-				if sensitive {
-					triggerSensitive = true
-					msg := choice.Message
-					msg.Content = common.StringToByteSlice(stringContent)
-					textResponse.Choices[i].Message = msg
-					sensitiveWords = append(sensitiveWords, words...)
+	usage := &responseWithError.Usage
+
+	//textResponse := &dto.TextResponse{
+	//	Choices: responseWithError.Choices,
+	//	Usage:   responseWithError.Usage,
+	//}
+	var doResponseBody []byte
+
+	switch relayMode {
+	case relayconstant.RelayModeEmbeddings:
+		embeddingResponse := &dto.OpenAIEmbeddingResponse{
+			Object: responseWithError.Object,
+			Data:   responseWithError.Data,
+			Model:  responseWithError.Model,
+			Usage:  *usage,
+		}
+		doResponseBody, err = json.Marshal(embeddingResponse)
+	default:
+		if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
+			completionTokens := 0
+			for i, choice := range responseWithError.Choices {
+				stringContent := string(choice.Message.Content)
+				ctkm, _, _ := service.CountTokenText(stringContent, model, false)
+				completionTokens += ctkm
+				if checkSensitive {
+					sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
+					if sensitive {
+						triggerSensitive = true
+						msg := choice.Message
+						msg.Content = common.StringToByteSlice(stringContent)
+						responseWithError.Choices[i].Message = msg
+						sensitiveWords = append(sensitiveWords, words...)
+					}
 				}
 			}
+			responseWithError.Usage = dto.Usage{
+				PromptTokens:     promptTokens,
+				CompletionTokens: completionTokens,
+				TotalTokens:      promptTokens + completionTokens,
+			}
 		}
-		textResponse.Usage = dto.Usage{
-			PromptTokens:     promptTokens,
-			CompletionTokens: completionTokens,
-			TotalTokens:      promptTokens + completionTokens,
+		textResponse := &dto.TextResponse{
+			Choices: responseWithError.Choices,
+			Model:   responseWithError.Model,
+			Usage:   *usage,
 		}
+		doResponseBody, err = json.Marshal(textResponse)
 	}
 
 	if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
 		sensitiveWords = common.RemoveDuplicate(sensitiveWords)
 		return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
 				strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
-			&textResponse.Usage, &dto.SensitiveResponse{
+			usage, &dto.SensitiveResponse{
 				SensitiveWords: sensitiveWords,
 			}
 	} else {
-		responseBody, err = json.Marshal(textResponse)
 		// Reset response body
-		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+		resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
 		// We shouldn't set the header before we parse the response body, because the parse part may fail.
 		// And then we will have to send an error response, but in this case, the header has already been set.
 		// So the httpClient will be confused by the response.
@@ -207,5 +226,5 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
 		}
 	}
-	return nil, &textResponse.Usage, nil
+	return nil, usage, nil
 }

+ 1 - 1
relay/channel/perplexity/adaptor.go

@@ -49,7 +49,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 	}
 	return
 }

+ 1 - 1
relay/channel/zhipu_4v/adaptor.go

@@ -50,7 +50,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 	}
 	return
 }