|
@@ -124,8 +124,8 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
|
|
return nil, responseTextBuilder.String()
|
|
return nil, responseTextBuilder.String()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
|
|
|
|
|
- var textResponseWithError dto.TextResponseWithError
|
|
|
|
|
|
|
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
|
|
|
|
|
+ var responseWithError dto.TextResponseWithError
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
@@ -134,62 +134,81 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
|
}
|
|
}
|
|
|
- err = json.Unmarshal(responseBody, &textResponseWithError)
|
|
|
|
|
|
|
+ err = json.Unmarshal(responseBody, &responseWithError)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
|
|
log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
|
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
|
}
|
|
}
|
|
|
- if textResponseWithError.Error.Type != "" {
|
|
|
|
|
|
|
+ if responseWithError.Error.Type != "" {
|
|
|
return &dto.OpenAIErrorWithStatusCode{
|
|
return &dto.OpenAIErrorWithStatusCode{
|
|
|
- Error: textResponseWithError.Error,
|
|
|
|
|
|
|
+ Error: responseWithError.Error,
|
|
|
StatusCode: resp.StatusCode,
|
|
StatusCode: resp.StatusCode,
|
|
|
}, nil, nil
|
|
}, nil, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- textResponse := &dto.TextResponse{
|
|
|
|
|
- Choices: textResponseWithError.Choices,
|
|
|
|
|
- Usage: textResponseWithError.Usage,
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
checkSensitive := constant.ShouldCheckCompletionSensitive()
|
|
checkSensitive := constant.ShouldCheckCompletionSensitive()
|
|
|
sensitiveWords := make([]string, 0)
|
|
sensitiveWords := make([]string, 0)
|
|
|
triggerSensitive := false
|
|
triggerSensitive := false
|
|
|
|
|
|
|
|
- if textResponse.Usage.TotalTokens == 0 || checkSensitive {
|
|
|
|
|
- completionTokens := 0
|
|
|
|
|
- for i, choice := range textResponse.Choices {
|
|
|
|
|
- stringContent := string(choice.Message.Content)
|
|
|
|
|
- ctkm, _, _ := service.CountTokenText(stringContent, model, false)
|
|
|
|
|
- completionTokens += ctkm
|
|
|
|
|
- if checkSensitive {
|
|
|
|
|
- sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
|
|
|
|
|
- if sensitive {
|
|
|
|
|
- triggerSensitive = true
|
|
|
|
|
- msg := choice.Message
|
|
|
|
|
- msg.Content = common.StringToByteSlice(stringContent)
|
|
|
|
|
- textResponse.Choices[i].Message = msg
|
|
|
|
|
- sensitiveWords = append(sensitiveWords, words...)
|
|
|
|
|
|
|
+ usage := &responseWithError.Usage
|
|
|
|
|
+
|
|
|
|
|
+ //textResponse := &dto.TextResponse{
|
|
|
|
|
+ // Choices: responseWithError.Choices,
|
|
|
|
|
+ // Usage: responseWithError.Usage,
|
|
|
|
|
+ //}
|
|
|
|
|
+ var doResponseBody []byte
|
|
|
|
|
+
|
|
|
|
|
+ switch relayMode {
|
|
|
|
|
+ case relayconstant.RelayModeEmbeddings:
|
|
|
|
|
+ embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
|
|
|
|
+ Object: responseWithError.Object,
|
|
|
|
|
+ Data: responseWithError.Data,
|
|
|
|
|
+ Model: responseWithError.Model,
|
|
|
|
|
+ Usage: *usage,
|
|
|
|
|
+ }
|
|
|
|
|
+ doResponseBody, err = json.Marshal(embeddingResponse)
|
|
|
|
|
+ default:
|
|
|
|
|
+ if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
|
|
|
|
|
+ completionTokens := 0
|
|
|
|
|
+ for i, choice := range responseWithError.Choices {
|
|
|
|
|
+ stringContent := string(choice.Message.Content)
|
|
|
|
|
+ ctkm, _, _ := service.CountTokenText(stringContent, model, false)
|
|
|
|
|
+ completionTokens += ctkm
|
|
|
|
|
+ if checkSensitive {
|
|
|
|
|
+ sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
|
|
|
|
|
+ if sensitive {
|
|
|
|
|
+ triggerSensitive = true
|
|
|
|
|
+ msg := choice.Message
|
|
|
|
|
+ msg.Content = common.StringToByteSlice(stringContent)
|
|
|
|
|
+ responseWithError.Choices[i].Message = msg
|
|
|
|
|
+ sensitiveWords = append(sensitiveWords, words...)
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ responseWithError.Usage = dto.Usage{
|
|
|
|
|
+ PromptTokens: promptTokens,
|
|
|
|
|
+ CompletionTokens: completionTokens,
|
|
|
|
|
+ TotalTokens: promptTokens + completionTokens,
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
- textResponse.Usage = dto.Usage{
|
|
|
|
|
- PromptTokens: promptTokens,
|
|
|
|
|
- CompletionTokens: completionTokens,
|
|
|
|
|
- TotalTokens: promptTokens + completionTokens,
|
|
|
|
|
|
|
+ textResponse := &dto.TextResponse{
|
|
|
|
|
+ Choices: responseWithError.Choices,
|
|
|
|
|
+ Model: responseWithError.Model,
|
|
|
|
|
+ Usage: *usage,
|
|
|
}
|
|
}
|
|
|
|
|
+ doResponseBody, err = json.Marshal(textResponse)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
|
|
if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
|
|
|
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
|
|
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
|
|
|
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
|
|
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
|
|
|
strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
|
|
strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
|
|
|
- &textResponse.Usage, &dto.SensitiveResponse{
|
|
|
|
|
|
|
+ usage, &dto.SensitiveResponse{
|
|
|
SensitiveWords: sensitiveWords,
|
|
SensitiveWords: sensitiveWords,
|
|
|
}
|
|
}
|
|
|
} else {
|
|
} else {
|
|
|
- responseBody, err = json.Marshal(textResponse)
|
|
|
|
|
// Reset response body
|
|
// Reset response body
|
|
|
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
|
|
|
|
|
+ resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
|
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
|
// So the httpClient will be confused by the response.
|
|
// So the httpClient will be confused by the response.
|
|
@@ -207,5 +226,5 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- return nil, &textResponse.Usage, nil
|
|
|
|
|
|
|
+ return nil, usage, nil
|
|
|
}
|
|
}
|