fix: remove sensitive check on completion (close #157)

CaIon · 1 year ago
commit 44a8ade4ba

+ 5 - 4
constant/sensitive.go

@@ -4,7 +4,8 @@ import "strings"
 
 var CheckSensitiveEnabled = true
 var CheckSensitiveOnPromptEnabled = true
-var CheckSensitiveOnCompletionEnabled = true
+
+//var CheckSensitiveOnCompletionEnabled = true
 
 // StopOnSensitiveEnabled: whether to stop generation immediately when a sensitive word is detected, or to replace the word instead
 var StopOnSensitiveEnabled = true
@@ -37,6 +38,6 @@ func ShouldCheckPromptSensitive() bool {
 	return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
 }
 
-func ShouldCheckCompletionSensitive() bool {
-	return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
-}
+//func ShouldCheckCompletionSensitive() bool {
+//	return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
+//}

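With the completion-side switch commented out, only the prompt gate survives, and every relay handler touched below now passes a literal false where it used to consult the removed helper. A minimal before/after sketch at a typical call site (service.CountTokenText as it is used elsewhere in this diff):

	// before: the completion-side check was configurable
	// tokens, err, _ := service.CountTokenText(text, model, constant.ShouldCheckCompletionSensitive())

	// after this commit: the completion-side sensitive check is unconditionally off
	tokens, err, _ := service.CountTokenText(text, model, false)
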
+ 1 - 1
controller/channel-test.go

@@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 		err := relaycommon.RelayErrorHandler(resp)
 		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
 	}
-	usage, respErr, _ := adaptor.DoResponse(c, resp, meta)
+	usage, respErr := adaptor.DoResponse(c, resp, meta)
 	if respErr != nil {
 		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
 	}

+ 6 - 0
dto/text_response.go

@@ -11,6 +11,12 @@ type TextResponseWithError struct {
 	Error   OpenAIError `json:"error"`
 }
 
+type SimpleResponse struct {
+	Usage   `json:"usage"`
+	Error   OpenAIError                `json:"error"`
+	Choices []OpenAITextResponseChoice `json:"choices"`
+}
+
 type TextResponse struct {
 	Id      string                     `json:"id"`
 	Object  string                     `json:"object"`

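The new SimpleResponse keeps only what OpenaiHandler still needs: usage, an error object, and the choices. A self-contained sketch of how a completion body decodes into it; the Usage and OpenAIError stand-ins below are simplified, since their dto definitions are not shown in this diff:

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// Simplified stand-ins for the dto types referenced by SimpleResponse.
	type Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	}

	type OpenAIError struct {
		Type    string `json:"type"`
		Message string `json:"message"`
	}

	// The json tag on the embedded Usage maps the nested "usage" object,
	// while field promotion still allows r.TotalTokens on the Go side.
	type SimpleResponse struct {
		Usage `json:"usage"`
		Error OpenAIError `json:"error"`
	}

	func main() {
		body := []byte(`{"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`)
		var r SimpleResponse
		if err := json.Unmarshal(body, &r); err != nil {
			panic(err)
		}
		fmt.Println(r.TotalTokens, r.Error.Type == "") // 21 true
	}
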
+ 3 - 3
model/option.go

@@ -93,7 +93,7 @@ func InitOptionMap() {
 	common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
 	common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
 	common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
-	common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
+	//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
 	common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
 	common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
 	common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
@@ -196,8 +196,8 @@ func updateOptionMap(key string, value string) (err error) {
 			constant.CheckSensitiveEnabled = boolValue
 		case "CheckSensitiveOnPromptEnabled":
 			constant.CheckSensitiveOnPromptEnabled = boolValue
-		case "CheckSensitiveOnCompletionEnabled":
-			constant.CheckSensitiveOnCompletionEnabled = boolValue
+		//case "CheckSensitiveOnCompletionEnabled":
+		//	constant.CheckSensitiveOnCompletionEnabled = boolValue
 		case "StopOnSensitiveEnabled":
 			constant.StopOnSensitiveEnabled = boolValue
 		case "SMTPSSLEnabled":

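For context, the option machinery these commented-out lines opt out of: InitOptionMap publishes the current value as a string, and updateOptionMap parses it back when an admin changes it. A toy, self-contained model of that round-trip (repo types omitted):

	package main

	import (
		"fmt"
		"strconv"
	)

	var checkSensitiveOnPromptEnabled = true

	func updateOption(key, value string) error {
		boolValue, err := strconv.ParseBool(value)
		if err != nil {
			return err
		}
		switch key {
		case "CheckSensitiveOnPromptEnabled":
			checkSensitiveOnPromptEnabled = boolValue
			// "CheckSensitiveOnCompletionEnabled" is no longer handled after this commit.
		}
		return nil
	}

	func main() {
		_ = updateOption("CheckSensitiveOnPromptEnabled", "false")
		fmt.Println(checkSensitiveOnPromptEnabled) // false
	}
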
+ 1 - 1
relay/channel/adapter.go

@@ -15,7 +15,7 @@ type Adaptor interface {
 	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
-	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse)
+	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string
 	GetChannelName() string
 }

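Dropping the third return value simplifies every call site. The resulting pattern, assembled from this diff's own lines in controller/channel-test.go and the adaptors below:

	usage, respErr := adaptor.DoResponse(c, resp, meta)
	if respErr != nil {
		// no sensitiveResp to inspect anymore; the error alone decides the outcome
		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
	}
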
+ 1 - 1
relay/channel/ali/adaptor.go

@@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		err, usage = aliStreamHandler(c, resp)
 	} else {

+ 1 - 1
relay/channel/baidu/adaptor.go

@@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		err, usage = baiduStreamHandler(c, resp)
 	} else {

+ 1 - 1
relay/channel/claude/adaptor.go

@@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
 	} else {

+ 1 - 2
relay/channel/claude/relay-claude.go

@@ -8,7 +8,6 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	"one-api/service"
 	"strings"
@@ -317,7 +316,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
 		}, nil
 	}
 	fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
-	completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 	}

+ 1 - 1
relay/channel/gemini/adaptor.go

@@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = geminiChatStreamHandler(c, resp)

+ 1 - 2
relay/channel/gemini/relay-gemini.go

@@ -7,7 +7,6 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
@@ -257,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

+ 3 - 3
relay/channel/ollama/adaptor.go

@@ -49,16 +49,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
-			err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
 		} else {
-			err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 		}
 	}
 	return

+ 8 - 8
relay/channel/ollama/relay-ollama.go

@@ -45,19 +45,19 @@ func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbedding
 	}
 }
 
-func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var ollamaEmbeddingResponse OllamaEmbeddingResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
 	data = append(data, dto.OpenAIEmbeddingResponseItem{
@@ -77,7 +77,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	}
 	doResponseBody, err := json.Marshal(embeddingResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
 	// We shouldn't set the header before we parse the response body, because the parse part may fail.
@@ -98,11 +98,11 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	return nil, usage, nil
+	return nil, usage
 }

+ 2 - 2
relay/channel/openai/adaptor.go

@@ -69,13 +69,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	return
 }

+ 37 - 112
relay/channel/openai/relay-openai.go

@@ -4,14 +4,10 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/json"
-	"errors"
-	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
-	"log"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
@@ -21,7 +17,7 @@ import (
 )
 
 func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
-	checkSensitive := constant.ShouldCheckCompletionSensitive()
+	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -53,20 +49,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 			if data[:6] != "data: " && data[:6] != "[DONE]" {
 				continue
 			}
-			sensitive := false
-			if checkSensitive {
-				// check sensitive
-				sensitive, _, data = service.SensitiveWordReplace(data, false)
-			}
 			dataChan <- data
 			data = data[6:]
 			if !strings.HasPrefix(data, "[DONE]") {
 				streamItems = append(streamItems, data)
 			}
-			if sensitive && constant.StopOnSensitiveEnabled {
-				dataChan <- "data: [DONE]"
-				break
-			}
 		}
 		streamResp := "[" + strings.Join(streamItems, ",") + "]"
 		switch relayMode {
@@ -142,118 +129,56 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 	return nil, responseTextBuilder.String()
 }
 
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
-	var responseWithError dto.TextResponseWithError
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	var simpleResponse dto.SimpleResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = json.Unmarshal(responseBody, &responseWithError)
+	err = json.Unmarshal(responseBody, &simpleResponse)
 	if err != nil {
-		log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	if responseWithError.Error.Type != "" {
+	if simpleResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			Error:      responseWithError.Error,
+			Error:      simpleResponse.Error,
 			StatusCode: resp.StatusCode,
-		}, nil, nil
+		}, nil
 	}
-
-	checkSensitive := constant.ShouldCheckCompletionSensitive()
-	sensitiveWords := make([]string, 0)
-	triggerSensitive := false
-
-	usage := &responseWithError.Usage
-
-	//textResponse := &dto.TextResponse{
-	//	Choices: responseWithError.Choices,
-	//	Usage:   responseWithError.Usage,
-	//}
-	var doResponseBody []byte
-
-	switch relayMode {
-	case relayconstant.RelayModeEmbeddings:
-		embeddingResponse := &dto.OpenAIEmbeddingResponse{
-			Object: responseWithError.Object,
-			Data:   responseWithError.Data,
-			Model:  responseWithError.Model,
-			Usage:  *usage,
-		}
-		doResponseBody, err = json.Marshal(embeddingResponse)
-	default:
-		if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
-			completionTokens := 0
-			for i, choice := range responseWithError.Choices {
-				stringContent := string(choice.Message.Content)
-				ctkm, _, _ := service.CountTokenText(stringContent, model, false)
-				completionTokens += ctkm
-				if checkSensitive {
-					sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
-					if sensitive {
-						triggerSensitive = true
-						msg := choice.Message
-						msg.Content = common.StringToByteSlice(stringContent)
-						responseWithError.Choices[i].Message = msg
-						sensitiveWords = append(sensitiveWords, words...)
-					}
-				}
-			}
-			responseWithError.Usage = dto.Usage{
-				PromptTokens:     promptTokens,
-				CompletionTokens: completionTokens,
-				TotalTokens:      promptTokens + completionTokens,
-			}
-		}
-		textResponse := &dto.TextResponse{
-			Id:      responseWithError.Id,
-			Created: responseWithError.Created,
-			Object:  responseWithError.Object,
-			Choices: responseWithError.Choices,
-			Model:   responseWithError.Model,
-			Usage:   *usage,
-		}
-		doResponseBody, err = json.Marshal(textResponse)
+	// Reset response body
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 
-	if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
-		sensitiveWords = common.RemoveDuplicate(sensitiveWords)
-		return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
-				strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
-			usage, &dto.SensitiveResponse{
-				SensitiveWords: sensitiveWords,
-			}
-	} else {
-		// Reset response body
-		resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
-		// We shouldn't set the header before we parse the response body, because the parse part may fail.
-		// And then we will have to send an error response, but in this case, the header has already been set.
-		// So the httpClient will be confused by the response.
-		// For example, Postman will report error, and we cannot check the response at all.
-		// Copy headers
-		for k, v := range resp.Header {
-			// remove any existing headers with the same name to avoid adding duplicates
-			c.Writer.Header().Del(k)
-			for _, vv := range v {
-				c.Writer.Header().Add(k, vv)
-			}
-		}
-		// reset content length
-		c.Writer.Header().Del("Content-Length")
-		c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
-		c.Writer.WriteHeader(resp.StatusCode)
-		_, err = io.Copy(c.Writer, resp.Body)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+	if simpleResponse.Usage.TotalTokens == 0 {
+		completionTokens := 0
+		for _, choice := range simpleResponse.Choices {
+			ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
+			completionTokens += ctkm
 		}
-		err = resp.Body.Close()
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+		simpleResponse.Usage = dto.Usage{
+			PromptTokens:     promptTokens,
+			CompletionTokens: completionTokens,
+			TotalTokens:      promptTokens + completionTokens,
 		}
 	}
-	return nil, usage, nil
+	return nil, &simpleResponse.Usage
 }

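The rewritten OpenaiHandler is now a plain pass-through: decode once to detect an upstream error and read usage, replay the untouched upstream bytes to the client, and only synthesize usage when the upstream omitted it. A condensed sketch of that order of operations, using plain net/http in place of gin (step comments mark where the dto decoding and service.CountTokenText calls sit in the real handler):

	package example

	import (
		"bytes"
		"io"
		"net/http"
	)

	func passThrough(w http.ResponseWriter, resp *http.Response) error {
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		resp.Body.Close()
		// 1. decode body here: bail out early if it carries an error object
		// 2. replay the original bytes to the client, headers first
		resp.Body = io.NopCloser(bytes.NewBuffer(body))
		for k, v := range resp.Header {
			w.Header().Set(k, v[0]) // first value only, matching the diff
		}
		w.WriteHeader(resp.StatusCode)
		if _, err := io.Copy(w, resp.Body); err != nil {
			return err
		}
		// 3. only after responding: if usage.TotalTokens == 0, count tokens
		//    from the choices and build a dto.Usage for billing
		return resp.Body.Close()
	}

Note the diff also drops the old Content-Length rewrite: because the body is now forwarded unmodified rather than re-marshaled, the upstream header stays valid as-is.
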
+ 1 - 1
relay/channel/palm/adaptor.go

@@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)

+ 1 - 2
relay/channel/palm/relay-palm.go

@@ -7,7 +7,6 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
@@ -157,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

+ 2 - 2
relay/channel/perplexity/adaptor.go

@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	return
 }

+ 1 - 1
relay/channel/tencent/adaptor.go

@@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = tencentStreamHandler(c, resp)

+ 3 - 3
relay/channel/xunfei/adaptor.go

@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return dummyResp, nil
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	splits := strings.Split(info.ApiKey, "|")
 	if len(splits) != 3 {
-		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil
+		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 	}
 	if a.request == nil {
-		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil
+		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
 	}
 	if info.IsStream {
 		err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])

+ 1 - 1
relay/channel/zhipu/adaptor.go

@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		err, usage = zhipuStreamHandler(c, resp)
 	} else {

+ 2 - 2
relay/channel/zhipu_4v/adaptor.go

@@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	return
 }

+ 1 - 1
relay/relay-audio.go

@@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 			if strings.HasPrefix(audioRequest.Model, "tts-1") {
 				quota = promptTokens
 			} else {
-				quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
+				quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false)
 			}
 			quota = int(float64(quota) * ratio)
 			if ratio != 0 && quota <= 0 {

+ 8 - 17
relay/relay-text.go

@@ -165,21 +165,12 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
 	}
 
-	usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)
+	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 	if openaiErr != nil {
-		if sensitiveResp == nil { // no sensitive-word check result
-			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
-			return openaiErr
-		} else {
-			// with a sensitive-word check result, do not refund the pre-consumed quota; keep consuming it
-			postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp)
-			if constant.StopOnSensitiveEnabled { // return the error directly?
-				return openaiErr
-			}
-			return nil
-		}
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil)
+	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
 	return nil
 }
 
@@ -258,7 +249,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64, sensitiveResp *dto.SensitiveResponse) {
+	modelPrice float64) {
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
@@ -293,9 +284,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 		logContent += fmt.Sprintf("(可能是上游超时)")
 		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
 	} else {
-		if sensitiveResp != nil {
-			logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
-		}
+		//if sensitiveResp != nil {
+		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
+		//}
 		quotaDelta := quota - preConsumedQuota
 		err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
 		if err != nil {

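With the sensitive branch gone, the error path always refunds the pre-consumed quota and the success path settles it. Inside postConsumeQuota the settlement is a simple delta, as the surviving context lines above show (comment added):

	// actual usage is known only now, so charge or refund the difference
	quotaDelta := quota - preConsumedQuota
	err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
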
+ 14 - 14
web/src/components/OperationSetting.js

@@ -330,21 +330,21 @@ const OperationSetting = () => {
               name='CheckSensitiveOnPromptEnabled'
               onChange={handleInputChange}
             />
-            <Form.Checkbox
-              checked={inputs.CheckSensitiveOnCompletionEnabled === 'true'}
-              label='启用生成内容检查'
-              name='CheckSensitiveOnCompletionEnabled'
-              onChange={handleInputChange}
-            />
-          </Form.Group>
-          <Form.Group inline>
-            <Form.Checkbox
-              checked={inputs.StopOnSensitiveEnabled === 'true'}
-              label='在检测到屏蔽词时,立刻停止生成,否则替换屏蔽词'
-              name='StopOnSensitiveEnabled'
-              onChange={handleInputChange}
-            />
+            {/*<Form.Checkbox*/}
+            {/*  checked={inputs.CheckSensitiveOnCompletionEnabled === 'true'}*/}
+            {/*  label='启用生成内容检查'*/}
+            {/*  name='CheckSensitiveOnCompletionEnabled'*/}
+            {/*  onChange={handleInputChange}*/}
+            {/*/>*/}
           </Form.Group>
+          {/*<Form.Group inline>*/}
+          {/*  <Form.Checkbox*/}
+          {/*    checked={inputs.StopOnSensitiveEnabled === 'true'}*/}
+          {/*    label='在检测到屏蔽词时,立刻停止生成,否则替换屏蔽词'*/}
+          {/*    name='StopOnSensitiveEnabled'*/}
+          {/*    onChange={handleInputChange}*/}
+          {/*  />*/}
+          {/*</Form.Group>*/}
           {/*<Form.Group>*/}
           {/*  <Form.Input*/}
           {/*    label="流模式下缓存队列,默认不缓存,设置越大检测越准确,但是回复会有卡顿感"*/}