|
|
@@ -9,11 +9,10 @@ import (
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"strings"
|
|
|
- "sync"
|
|
|
)
|
|
|
|
|
|
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
|
|
|
- var responseTextBuilder strings.Builder
|
|
|
+ responseText := ""
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
if atEOF && len(data) == 0 {
|
|
|
@@ -29,10 +28,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
|
|
})
|
|
|
dataChan := make(chan string)
|
|
|
stopChan := make(chan bool)
|
|
|
- var wg sync.WaitGroup
|
|
|
go func() {
|
|
|
- wg.Add(1)
|
|
|
- var streamItems []string
|
|
|
for scanner.Scan() {
|
|
|
data := scanner.Text()
|
|
|
if len(data) < 6 { // ignore blank line or wrong format
|
|
|
@@ -44,39 +40,30 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
|
|
dataChan <- data
|
|
|
data = data[6:]
|
|
|
if !strings.HasPrefix(data, "[DONE]") {
|
|
|
- streamItems = append(streamItems, data)
|
|
|
- }
|
|
|
- }
|
|
|
- streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
|
|
- switch relayMode {
|
|
|
- case RelayModeChatCompletions:
|
|
|
- var streamResponses []ChatCompletionsStreamResponseSimple
|
|
|
- err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
|
|
- if err != nil {
|
|
|
- common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
- wg.Done()
|
|
|
- return // just ignore the error
|
|
|
- }
|
|
|
- for _, streamResponse := range streamResponses {
|
|
|
- for _, choice := range streamResponse.Choices {
|
|
|
- responseTextBuilder.WriteString(choice.Delta.Content)
|
|
|
- }
|
|
|
- }
|
|
|
- case RelayModeCompletions:
|
|
|
- var streamResponses []CompletionsStreamResponse
|
|
|
- err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
|
|
- if err != nil {
|
|
|
- common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
- wg.Done()
|
|
|
- return // just ignore the error
|
|
|
- }
|
|
|
- for _, streamResponse := range streamResponses {
|
|
|
- for _, choice := range streamResponse.Choices {
|
|
|
- responseTextBuilder.WriteString(choice.Text)
|
|
|
+ switch relayMode {
|
|
|
+ case RelayModeChatCompletions:
|
|
|
+ var streamResponse ChatCompletionsStreamResponseSimple
|
|
|
+ err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
+ continue // just ignore the error
|
|
|
+ }
|
|
|
+ for _, choice := range streamResponse.Choices {
|
|
|
+ responseText += choice.Delta.Content
|
|
|
+ }
|
|
|
+ case RelayModeCompletions:
|
|
|
+ var streamResponse CompletionsStreamResponse
|
|
|
+ err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ for _, choice := range streamResponse.Choices {
|
|
|
+ responseText += choice.Text
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- wg.Done()
|
|
|
stopChan <- true
|
|
|
}()
|
|
|
setEventStreamHeaders(c)
|
|
|
@@ -98,8 +85,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
|
|
if err != nil {
|
|
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
|
}
|
|
|
- wg.Wait()
|
|
|
- return nil, responseTextBuilder.String()
|
|
|
+ return nil, responseText
|
|
|
}
|
|
|
|
|
|
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|