Browse Source

fix: fix SensitiveWords error

CaIon 1 year ago
parent
commit
222a55387d
4 changed files with 22 additions and 12 deletions
  1. 7 2
      dto/text_response.go
  2. 12 7
      relay/channel/openai/relay-openai.go
  3. 2 2
      relay/common/relay_utils.go
  4. 1 1
      service/sensitive.go

+ 7 - 2
dto/text_response.go

@@ -1,9 +1,14 @@
 package dto
 
+type TextResponseWithError struct {
+	Choices []OpenAITextResponseChoice `json:"choices"`
+	Usage   `json:"usage"`
+	Error   OpenAIError `json:"error"`
+}
+
 type TextResponse struct {
-	Choices []*OpenAITextResponseChoice `json:"choices"`
+	Choices []OpenAITextResponseChoice `json:"choices"`
 	Usage   `json:"usage"`
-	Error   *OpenAIError `json:"error,omitempty"`
 }
 
 type OpenAITextResponseChoice struct {

+ 12 - 7
relay/channel/openai/relay-openai.go

@@ -125,7 +125,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
-	var textResponse dto.TextResponse
+	var textResponseWithError dto.TextResponseWithError
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
@@ -134,18 +134,23 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
 	}
-	err = json.Unmarshal(responseBody, &textResponse)
+	err = json.Unmarshal(responseBody, &textResponseWithError)
 	if err != nil {
+		log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
 	}
-	log.Printf("textResponse: %+v", textResponse)
-	if textResponse.Error != nil {
+	if textResponseWithError.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			Error:      *textResponse.Error,
+			Error:      textResponseWithError.Error,
 			StatusCode: resp.StatusCode,
 		}, nil, nil
 	}
 
+	textResponse := &dto.TextResponse{
+		Choices: textResponseWithError.Choices,
+		Usage:   textResponseWithError.Usage,
+	}
+
 	checkSensitive := constant.ShouldCheckCompletionSensitive()
 	sensitiveWords := make([]string, 0)
 	triggerSensitive := false
@@ -174,7 +179,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 		}
 	}
 
-	if constant.StopOnSensitiveEnabled {
+	if checkSensitive && constant.StopOnSensitiveEnabled && triggerSensitive {
 
 	} else {
 		responseBody, err = json.Marshal(textResponse)
@@ -200,7 +205,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 
 	if checkSensitive && triggerSensitive {
 		sensitiveWords = common.RemoveDuplicate(sensitiveWords)
-		return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
+		return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
 			SensitiveWords: sensitiveWords,
 		}
 	}

+ 2 - 2
relay/common/relay_utils.go

@@ -35,12 +35,12 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
 	if err != nil {
 		return
 	}
-	var textResponse dto.TextResponse
+	var textResponse dto.TextResponseWithError
 	err = json.Unmarshal(responseBody, &textResponse)
 	if err != nil {
 		return
 	}
-	OpenAIErrorWithStatusCode.Error = *textResponse.Error
+	OpenAIErrorWithStatusCode.Error = textResponse.Error
 	return
 }
 

+ 1 - 1
service/sensitive.go

@@ -40,7 +40,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
 		for _, hit := range hits {
 			pos := hit.Pos
 			word := string(hit.Word)
-			text = text[:pos] + "*###*" + text[pos+len(word):]
+			text = text[:pos] + "**###**" + text[pos+len(word):]
 			words = append(words, word)
 		}
 		return true, words, text