Explorar o código

feat: Enhance sensitive word detection with detailed logging

1808837298@qq.com hai 1 ano
pai
achega
9cc6385b0c
Modificáronse 4 ficheiros con 28 adicións e 32 borrados
  1. 3 1
      relay/relay-audio.go
  2. 2 1
      relay/relay-image.go
  3. 9 7
      relay/relay-text.go
  4. 14 23
      service/sensitive.go

+ 3 - 1
relay/relay-audio.go

@@ -13,6 +13,7 @@ import (
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting"
+	"strings"
 )
 
 func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
@@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 			return nil, errors.New("model is required")
 		}
 		if setting.ShouldCheckPromptSensitive() {
-			err := service.CheckSensitiveInput(audioRequest.Input)
+			words, err := service.CheckSensitiveInput(audioRequest.Input)
 			if err != nil {
+				common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
 				return nil, err
 			}
 		}

+ 2 - 1
relay/relay-image.go

@@ -61,8 +61,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 	//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
 	//}
 	if setting.ShouldCheckPromptSensitive() {
-		err := service.CheckSensitiveInput(imageRequest.Prompt)
+		words, err := service.CheckSensitiveInput(imageRequest.Prompt)
 		if err != nil {
+			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
 			return nil, err
 		}
 	}

+ 9 - 7
relay/relay-text.go

@@ -78,8 +78,9 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	}
 
 	if setting.ShouldCheckPromptSensitive() {
-		err = checkRequestSensitive(textRequest, relayInfo)
+		words, err := checkRequestSensitive(textRequest, relayInfo)
 		if err != nil {
+			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
 			return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
 		}
 	}
@@ -219,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 	return promptTokens, err
 }
 
-func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
+func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
 	var err error
+	var words []string
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
-		err = service.CheckSensitiveMessages(textRequest.Messages)
+		words, err = service.CheckSensitiveMessages(textRequest.Messages)
 	case relayconstant.RelayModeCompletions:
-		err = service.CheckSensitiveInput(textRequest.Prompt)
+		words, err = service.CheckSensitiveInput(textRequest.Prompt)
 	case relayconstant.RelayModeModerations:
-		err = service.CheckSensitiveInput(textRequest.Input)
+		words, err = service.CheckSensitiveInput(textRequest.Input)
 	case relayconstant.RelayModeEmbeddings:
-		err = service.CheckSensitiveInput(textRequest.Input)
+		words, err = service.CheckSensitiveInput(textRequest.Input)
 	}
-	return err
+	return words, err
 }
 
 // 预扣费并返回用户剩余配额

+ 14 - 23
service/sensitive.go

@@ -8,39 +8,30 @@ import (
 	"strings"
 )
 
-func CheckSensitiveMessages(messages []dto.Message) error {
+func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
 	for _, message := range messages {
-		if len(message.Content) > 0 {
-			if message.IsStringContent() {
-				stringContent := message.StringContent()
-				if ok, words := SensitiveWordContains(stringContent); ok {
-					return errors.New("sensitive words: " + strings.Join(words, ","))
-				}
-			}
-		} else {
-			arrayContent := message.ParseContent()
-			for _, m := range arrayContent {
-				if m.Type == "image_url" {
-					// TODO: check image url
-				} else {
-					if ok, words := SensitiveWordContains(m.Text); ok {
-						return errors.New("sensitive words: " + strings.Join(words, ","))
-					}
+		arrayContent := message.ParseContent()
+		for _, m := range arrayContent {
+			if m.Type == "image_url" {
+				// TODO: check image url
+			} else {
+				if ok, words := SensitiveWordContains(m.Text); ok {
+					return words, errors.New("sensitive words detected")
 				}
 			}
 		}
 	}
-	return nil
+	return nil, nil
 }
 
-func CheckSensitiveText(text string) error {
+func CheckSensitiveText(text string) ([]string, error) {
 	if ok, words := SensitiveWordContains(text); ok {
-		return errors.New("sensitive words: " + strings.Join(words, ","))
+		return words, errors.New("sensitive words detected")
 	}
-	return nil
+	return nil, nil
 }
 
-func CheckSensitiveInput(input any) error {
+func CheckSensitiveInput(input any) ([]string, error) {
 	switch v := input.(type) {
 	case string:
 		return CheckSensitiveText(v)
@@ -60,7 +51,7 @@ func SensitiveWordContains(text string) (bool, []string) {
 		return false, nil
 	}
 	checkText := strings.ToLower(text)
-	return AcSearch(checkText, setting.SensitiveWords, false)
+	return AcSearch(checkText, setting.SensitiveWords, true)
 }
 
 // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本