|
|
@@ -4,14 +4,10 @@ import (
|
|
|
"bufio"
|
|
|
"bytes"
|
|
|
"encoding/json"
|
|
|
- "errors"
|
|
|
- "fmt"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"io"
|
|
|
- "log"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
- "one-api/constant"
|
|
|
"one-api/dto"
|
|
|
relayconstant "one-api/relay/constant"
|
|
|
"one-api/service"
|
|
|
@@ -21,7 +17,7 @@ import (
|
|
|
)
|
|
|
|
|
|
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
|
|
|
- checkSensitive := constant.ShouldCheckCompletionSensitive()
|
|
|
+ //checkSensitive := constant.ShouldCheckCompletionSensitive()
|
|
|
var responseTextBuilder strings.Builder
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
@@ -53,20 +49,11 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
|
|
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
|
|
continue
|
|
|
}
|
|
|
- sensitive := false
|
|
|
- if checkSensitive {
|
|
|
- // check sensitive
|
|
|
- sensitive, _, data = service.SensitiveWordReplace(data, false)
|
|
|
- }
|
|
|
dataChan <- data
|
|
|
data = data[6:]
|
|
|
if !strings.HasPrefix(data, "[DONE]") {
|
|
|
streamItems = append(streamItems, data)
|
|
|
}
|
|
|
- if sensitive && constant.StopOnSensitiveEnabled {
|
|
|
- dataChan <- "data: [DONE]"
|
|
|
- break
|
|
|
- }
|
|
|
}
|
|
|
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
|
|
switch relayMode {
|
|
|
@@ -142,118 +129,56 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
|
|
return nil, responseTextBuilder.String()
|
|
|
}
|
|
|
|
|
|
-func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
|
|
|
- var responseWithError dto.TextResponseWithError
|
|
|
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
+ var simpleResponse dto.SimpleResponse
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
|
if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
|
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
|
}
|
|
|
err = resp.Body.Close()
|
|
|
if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
|
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
|
}
|
|
|
- err = json.Unmarshal(responseBody, &responseWithError)
|
|
|
+ err = json.Unmarshal(responseBody, &simpleResponse)
|
|
|
if err != nil {
|
|
|
- log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
|
|
|
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
|
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
}
|
|
|
- if responseWithError.Error.Type != "" {
|
|
|
+ if simpleResponse.Error.Type != "" {
|
|
|
return &dto.OpenAIErrorWithStatusCode{
|
|
|
- Error: responseWithError.Error,
|
|
|
+ Error: simpleResponse.Error,
|
|
|
StatusCode: resp.StatusCode,
|
|
|
- }, nil, nil
|
|
|
+ }, nil
|
|
|
}
|
|
|
-
|
|
|
- checkSensitive := constant.ShouldCheckCompletionSensitive()
|
|
|
- sensitiveWords := make([]string, 0)
|
|
|
- triggerSensitive := false
|
|
|
-
|
|
|
- usage := &responseWithError.Usage
|
|
|
-
|
|
|
- //textResponse := &dto.TextResponse{
|
|
|
- // Choices: responseWithError.Choices,
|
|
|
- // Usage: responseWithError.Usage,
|
|
|
- //}
|
|
|
- var doResponseBody []byte
|
|
|
-
|
|
|
- switch relayMode {
|
|
|
- case relayconstant.RelayModeEmbeddings:
|
|
|
- embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
|
|
- Object: responseWithError.Object,
|
|
|
- Data: responseWithError.Data,
|
|
|
- Model: responseWithError.Model,
|
|
|
- Usage: *usage,
|
|
|
- }
|
|
|
- doResponseBody, err = json.Marshal(embeddingResponse)
|
|
|
- default:
|
|
|
- if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
|
|
|
- completionTokens := 0
|
|
|
- for i, choice := range responseWithError.Choices {
|
|
|
- stringContent := string(choice.Message.Content)
|
|
|
- ctkm, _, _ := service.CountTokenText(stringContent, model, false)
|
|
|
- completionTokens += ctkm
|
|
|
- if checkSensitive {
|
|
|
- sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
|
|
|
- if sensitive {
|
|
|
- triggerSensitive = true
|
|
|
- msg := choice.Message
|
|
|
- msg.Content = common.StringToByteSlice(stringContent)
|
|
|
- responseWithError.Choices[i].Message = msg
|
|
|
- sensitiveWords = append(sensitiveWords, words...)
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- responseWithError.Usage = dto.Usage{
|
|
|
- PromptTokens: promptTokens,
|
|
|
- CompletionTokens: completionTokens,
|
|
|
- TotalTokens: promptTokens + completionTokens,
|
|
|
- }
|
|
|
- }
|
|
|
- textResponse := &dto.TextResponse{
|
|
|
- Id: responseWithError.Id,
|
|
|
- Created: responseWithError.Created,
|
|
|
- Object: responseWithError.Object,
|
|
|
- Choices: responseWithError.Choices,
|
|
|
- Model: responseWithError.Model,
|
|
|
- Usage: *usage,
|
|
|
- }
|
|
|
- doResponseBody, err = json.Marshal(textResponse)
|
|
|
+ // Reset response body
|
|
|
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
|
+ // We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
|
+ // And then we will have to send an error response, but in this case, the header has already been set.
|
|
|
+ // So the httpClient will be confused by the response.
|
|
|
+ // For example, Postman will report error, and we cannot check the response at all.
|
|
|
+ for k, v := range resp.Header {
|
|
|
+ c.Writer.Header().Set(k, v[0])
|
|
|
+ }
|
|
|
+ c.Writer.WriteHeader(resp.StatusCode)
|
|
|
+ _, err = io.Copy(c.Writer, resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ err = resp.Body.Close()
|
|
|
+ if err != nil {
|
|
|
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
|
}
|
|
|
|
|
|
- if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
|
|
|
- sensitiveWords = common.RemoveDuplicate(sensitiveWords)
|
|
|
- return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
|
|
|
- strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
|
|
|
- usage, &dto.SensitiveResponse{
|
|
|
- SensitiveWords: sensitiveWords,
|
|
|
- }
|
|
|
- } else {
|
|
|
- // Reset response body
|
|
|
- resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
|
|
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
|
- // And then we will have to send an error response, but in this case, the header has already been set.
|
|
|
- // So the httpClient will be confused by the response.
|
|
|
- // For example, Postman will report error, and we cannot check the response at all.
|
|
|
- // Copy headers
|
|
|
- for k, v := range resp.Header {
|
|
|
- // 删除任何现有的相同头部,以防止重复添加头部
|
|
|
- c.Writer.Header().Del(k)
|
|
|
- for _, vv := range v {
|
|
|
- c.Writer.Header().Add(k, vv)
|
|
|
- }
|
|
|
- }
|
|
|
- // reset content length
|
|
|
- c.Writer.Header().Del("Content-Length")
|
|
|
- c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
|
|
|
- c.Writer.WriteHeader(resp.StatusCode)
|
|
|
- _, err = io.Copy(c.Writer, resp.Body)
|
|
|
- if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
|
+ if simpleResponse.Usage.TotalTokens == 0 {
|
|
|
+ completionTokens := 0
|
|
|
+ for _, choice := range simpleResponse.Choices {
|
|
|
+ ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
|
|
|
+ completionTokens += ctkm
|
|
|
}
|
|
|
- err = resp.Body.Close()
|
|
|
- if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
|
|
|
+ simpleResponse.Usage = dto.Usage{
|
|
|
+ PromptTokens: promptTokens,
|
|
|
+ CompletionTokens: completionTokens,
|
|
|
+ TotalTokens: promptTokens + completionTokens,
|
|
|
}
|
|
|
}
|
|
|
- return nil, usage, nil
|
|
|
+ return nil, &simpleResponse.Usage
|
|
|
}
|