Просмотр исходного кода

refactor: refactor openai related code

JustSong 2 лет назад
Родитель
Сommit
12a0e7105e
2 измененных файлов с 139 добавлено и 108 удалено
  1. 133 0
      controller/relay-openai.go
  2. 6 108
      controller/relay-text.go

+ 133 - 0
controller/relay-openai.go

@@ -0,0 +1,133 @@
+package controller
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strings"
+)
+
+func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
+	responseText := ""
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 6 { // ignore blank line or wrong format
+				continue
+			}
+			dataChan <- data
+			data = data[6:]
+			if !strings.HasPrefix(data, "[DONE]") {
+				switch relayMode {
+				case RelayModeChatCompletions:
+					var streamResponse ChatCompletionsStreamResponse
+					err := json.Unmarshal([]byte(data), &streamResponse)
+					if err != nil {
+						common.SysError("error unmarshalling stream response: " + err.Error())
+						return
+					}
+					for _, choice := range streamResponse.Choices {
+						responseText += choice.Delta.Content
+					}
+				case RelayModeCompletions:
+					var streamResponse CompletionsStreamResponse
+					err := json.Unmarshal([]byte(data), &streamResponse)
+					if err != nil {
+						common.SysError("error unmarshalling stream response: " + err.Error())
+						return
+					}
+					for _, choice := range streamResponse.Choices {
+						responseText += choice.Text
+					}
+				}
+			}
+		}
+		stopChan <- true
+	}()
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			if strings.HasPrefix(data, "data: [DONE]") {
+				data = data[:12]
+			}
+			// some implementations may add \r at the end of data
+			data = strings.TrimSuffix(data, "\r")
+			c.Render(-1, common.CustomEvent{Data: data})
+			return true
+		case <-stopChan:
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+	}
+	return nil, responseText
+}
+
+func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
+	var textResponse TextResponse
+	if consumeQuota {
+		responseBody, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		}
+		err = resp.Body.Close()
+		if err != nil {
+			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		}
+		err = json.Unmarshal(responseBody, &textResponse)
+		if err != nil {
+			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		}
+		if textResponse.Error.Type != "" {
+			return &OpenAIErrorWithStatusCode{
+				OpenAIError: textResponse.Error,
+				StatusCode:  resp.StatusCode,
+			}, nil
+		}
+		// Reset response body
+		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	}
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the client will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err := io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &textResponse.Usage
+}

+ 6 - 108
controller/relay-text.go

@@ -1,7 +1,6 @@
 package controller
 
 import (
-	"bufio"
 	"bytes"
 	"encoding/json"
 	"errors"
@@ -256,119 +255,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	switch apiType {
 	case APITypeOpenAI:
 		if isStream {
-			scanner := bufio.NewScanner(resp.Body)
-			scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-				if atEOF && len(data) == 0 {
-					return 0, nil, nil
-				}
-				if i := strings.Index(string(data), "\n"); i >= 0 {
-					return i + 1, data[0:i], nil
-				}
-				if atEOF {
-					return len(data), data, nil
-				}
-				return 0, nil, nil
-			})
-			dataChan := make(chan string)
-			stopChan := make(chan bool)
-			go func() {
-				for scanner.Scan() {
-					data := scanner.Text()
-					if len(data) < 6 { // ignore blank line or wrong format
-						continue
-					}
-					dataChan <- data
-					data = data[6:]
-					if !strings.HasPrefix(data, "[DONE]") {
-						switch relayMode {
-						case RelayModeChatCompletions:
-							var streamResponse ChatCompletionsStreamResponse
-							err = json.Unmarshal([]byte(data), &streamResponse)
-							if err != nil {
-								common.SysError("error unmarshalling stream response: " + err.Error())
-								return
-							}
-							for _, choice := range streamResponse.Choices {
-								streamResponseText += choice.Delta.Content
-							}
-						case RelayModeCompletions:
-							var streamResponse CompletionsStreamResponse
-							err = json.Unmarshal([]byte(data), &streamResponse)
-							if err != nil {
-								common.SysError("error unmarshalling stream response: " + err.Error())
-								return
-							}
-							for _, choice := range streamResponse.Choices {
-								streamResponseText += choice.Text
-							}
-						}
-					}
-				}
-				stopChan <- true
-			}()
-			c.Writer.Header().Set("Content-Type", "text/event-stream")
-			c.Writer.Header().Set("Cache-Control", "no-cache")
-			c.Writer.Header().Set("Connection", "keep-alive")
-			c.Writer.Header().Set("Transfer-Encoding", "chunked")
-			c.Writer.Header().Set("X-Accel-Buffering", "no")
-			c.Stream(func(w io.Writer) bool {
-				select {
-				case data := <-dataChan:
-					if strings.HasPrefix(data, "data: [DONE]") {
-						data = data[:12]
-					}
-					// some implementations may add \r at the end of data
-					data = strings.TrimSuffix(data, "\r")
-					c.Render(-1, common.CustomEvent{Data: data})
-					return true
-				case <-stopChan:
-					return false
-				}
-			})
-			err = resp.Body.Close()
+			err, responseText := openaiStreamHandler(c, resp, relayMode)
 			if err != nil {
-				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+				return err
 			}
+			streamResponseText = responseText
 			return nil
 		} else {
-			if consumeQuota {
-				responseBody, err := io.ReadAll(resp.Body)
-				if err != nil {
-					return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-				}
-				err = resp.Body.Close()
-				if err != nil {
-					return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-				}
-				err = json.Unmarshal(responseBody, &textResponse)
-				if err != nil {
-					return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
-				}
-				if textResponse.Error.Type != "" {
-					return &OpenAIErrorWithStatusCode{
-						OpenAIError: textResponse.Error,
-						StatusCode:  resp.StatusCode,
-					}
-				}
-				// Reset response body
-				resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-			}
-			// We shouldn't set the header before we parse the response body, because the parse part may fail.
-			// And then we will have to send an error response, but in this case, the header has already been set.
-			// So the client will be confused by the response.
-			// For example, Postman will report error, and we cannot check the response at all.
-			for k, v := range resp.Header {
-				c.Writer.Header().Set(k, v[0])
-			}
-			c.Writer.WriteHeader(resp.StatusCode)
-			_, err = io.Copy(c.Writer, resp.Body)
+			err, usage := openaiHandler(c, resp, consumeQuota)
 			if err != nil {
-				return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
-			}
-			err = resp.Body.Close()
-			if err != nil {
-				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+				return err
 			}
+			textResponse.Usage = *usage
 			return nil
 		}
 	case APITypeClaude: