| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628 |
- package openai
- import (
- "bufio"
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/service"
- "strings"
- "sync"
- "time"
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- )
// sendStreamData writes one SSE payload to the client. An empty payload is a
// no-op. Without forceFormat the raw string is forwarded via service.StringData;
// with forceFormat it is first decoded into a dto.ChatCompletionsStreamResponse
// and that object is written via service.ObjectData. A decode failure is
// returned and nothing is written.
func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
	switch {
	case data == "":
		return nil
	case !forceFormat:
		return service.StringData(c, data)
	}
	var chunk dto.ChatCompletionsStreamResponse
	err := json.Unmarshal(common.StringToByteSlice(data), &chunk)
	if err != nil {
		return err
	}
	return service.ObjectData(c, chunk)
}
// OaiStreamHandler relays an OpenAI-compatible SSE stream from resp to the client
// and returns the usage for the request.
//
// Upstream "data:" payloads are forwarded one event late: the most recent payload
// is held back so that, once the stream ends (or times out), it can be inspected
// for upstream usage. A final chunk that carries valid usage is suppressed unless
// info.ShouldIncludeUsage is set or the chunk also finishes a choice.
//
// If upstream reports no valid usage, usage is computed from the concatenated
// response text via service.ResponseText2Usage, plus 7 completion tokens per
// tool-call slot (the largest number of tool calls seen in a single delta).
//
// If no line arrives for constant.StreamingTimeout seconds, relaying stops early;
// whatever was received so far is still flushed and accounted for.
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	if resp == nil || resp.Body == nil {
		common.LogError(c, "invalid response or response body")
		return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
	}
	// Closing the body on return also unblocks the reader goroutine if we stopped
	// waiting for it because of a timeout.
	defer resp.Body.Close()

	// bufio.Scanner's default 64 KiB line limit is too small for large chunks (for
	// example long tool-call arguments); exceeding it used to end the stream silently.
	const (
		initialLineBuffer = 64 * 1024
		maxLineSize       = 16 * 1024 * 1024
	)

	containStreamUsage := false
	var responseId string
	var createAt int64
	var systemFingerprint string
	model := info.UpstreamModelName
	var responseTextBuilder strings.Builder
	usage := &dto.Usage{}
	toolCount := 0

	var forceFormat bool
	if info.ChannelType == common.ChannelTypeCustom {
		if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok {
			forceFormat = forceFmt
		}
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, initialLineBuffer), maxLineSize)
	scanner.Split(bufio.ScanLines)
	service.SetEventStreamHeaders(c)

	streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
	ticker := time.NewTicker(streamingTimeout)
	defer ticker.Stop()

	var (
		mu             sync.Mutex // guards the fields below and every write to c made by the reader
		lastStreamData string     // most recent payload, not yet sent to the client
		streamItems    []string   // every payload received, kept for token counting
		stopped        bool       // set once the handler no longer consumes the reader's output
	)

	// handleLine processes one upstream line and reports whether the reader should
	// keep going. It runs entirely under mu (released by defer, even on panic) so it
	// never overlaps with the post-processing below.
	handleLine := func(line string) bool {
		mu.Lock()
		defer mu.Unlock()
		if stopped {
			return false
		}
		// Any upstream line, including blank keep-alives, counts as activity.
		info.SetFirstResponseTime()
		ticker.Reset(streamingTimeout)

		// SSE data field: "data:" followed by an optional single space. Other fields,
		// comments, blank lines and the [DONE] terminator are not forwarded.
		if !strings.HasPrefix(line, "data:") {
			return true
		}
		data := strings.TrimPrefix(line[len("data:"):], " ")
		if data == "" || strings.HasPrefix(data, "[DONE]") {
			return true
		}
		if lastStreamData != "" {
			if err := sendStreamData(c, lastStreamData, forceFormat); err != nil {
				common.LogError(c, "streaming error: "+err.Error())
			}
		}
		lastStreamData = data
		streamItems = append(streamItems, data)
		return true
	}

	done := make(chan struct{})
	gopool.Go(func() {
		defer close(done)
		for scanner.Scan() {
			if !handleLine(scanner.Text()) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			mu.Lock()
			defer mu.Unlock()
			// After a timeout the body is closed under the reader on purpose; only
			// report errors while the handler is still consuming the stream.
			if !stopped {
				common.LogError(c, "streaming read error: "+err.Error())
			}
		}
	})

	select {
	case <-ticker.C:
		common.LogError(c, "streaming timeout")
	case <-done:
	}
	// The reader may still be running after a timeout, but from here on it no longer
	// touches c or the shared state, so they can be used without holding mu.
	mu.Lock()
	stopped = true
	mu.Unlock()

	shouldSendLastResp := true
	var lastStreamResponse dto.ChatCompletionsStreamResponse
	if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err == nil {
		responseId = lastStreamResponse.Id
		createAt = lastStreamResponse.Created
		systemFingerprint = lastStreamResponse.GetSystemFingerprint()
		model = lastStreamResponse.Model
		if service.ValidUsage(lastStreamResponse.Usage) {
			containStreamUsage = true
			usage = lastStreamResponse.Usage
			// A usage-carrying final chunk is dropped unless the client asked for usage...
			if !info.ShouldIncludeUsage {
				shouldSendLastResp = false
			}
		}
		// ...but a chunk that also finishes a choice must always reach the client.
		for _, choice := range lastStreamResponse.Choices {
			if choice.FinishReason != nil {
				shouldSendLastResp = true
			}
		}
	}
	if shouldSendLastResp {
		if err := sendStreamData(c, lastStreamData, forceFormat); err != nil {
			common.LogError(c, "streaming error: "+err.Error())
		}
	}

	// Rebuild the response text for token counting; unparsable items are skipped.
	parseFailures := 0
	switch info.RelayMode {
	case relayconstant.RelayModeChatCompletions:
		for _, item := range streamItems {
			var streamResponse dto.ChatCompletionsStreamResponse
			if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
				parseFailures++
				continue
			}
			for _, choice := range streamResponse.Choices {
				responseTextBuilder.WriteString(choice.Delta.GetContentString())
				if len(choice.Delta.ToolCalls) > toolCount {
					toolCount = len(choice.Delta.ToolCalls)
				}
				for _, tool := range choice.Delta.ToolCalls {
					responseTextBuilder.WriteString(tool.Function.Name)
					responseTextBuilder.WriteString(tool.Function.Arguments)
				}
			}
		}
	case relayconstant.RelayModeCompletions:
		for _, item := range streamItems {
			var streamResponse dto.CompletionsStreamResponse
			if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
				parseFailures++
				continue
			}
			for _, choice := range streamResponse.Choices {
				responseTextBuilder.WriteString(choice.Text)
			}
		}
	}
	if parseFailures > 0 {
		common.SysError(fmt.Sprintf("error unmarshalling stream response: %d of %d items could not be parsed", parseFailures, len(streamItems)))
	}

	if !containStreamUsage {
		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
		// Fixed per-tool-call allowance; keep the total consistent with the parts.
		usage.CompletionTokens += toolCount * 7
		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	}
	if info.ShouldIncludeUsage && !containStreamUsage {
		response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
		response.SetSystemFingerprint(systemFingerprint)
		if err := service.ObjectData(c, response); err != nil {
			common.LogError(c, "streaming error: "+err.Error())
		}
	}
	service.Done(c)
	return nil, usage
}
// OpenaiHandler relays a non-streaming OpenAI-compatible response to the client and
// returns its usage.
//
// The upstream body is fully buffered and decoded before anything is written, so
// an unreadable body or an upstream error object can still be reported as a proper
// error response: once headers and status are written they cannot be changed (some
// clients, e.g. Postman, otherwise report a confusing error). If upstream reports
// no usable usage, completion tokens are counted from each choice's message content
// with service.CountTextToken and promptTokens is used as the prompt count.
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	if closeErr := resp.Body.Close(); closeErr != nil {
		return service.OpenAIErrorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	var simpleResponse dto.SimpleResponse
	if decodeErr := json.Unmarshal(body, &simpleResponse); decodeErr != nil {
		return service.OpenAIErrorWrapper(decodeErr, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if simpleResponse.Error.Type != "" {
		return &dto.OpenAIErrorWithStatusCode{
			Error:      simpleResponse.Error,
			StatusCode: resp.StatusCode,
		}, nil
	}

	// Parsing succeeded; only now is it safe to commit headers and status.
	resp.Body = io.NopCloser(bytes.NewReader(body))
	header := c.Writer.Header()
	for name, values := range resp.Header {
		header.Set(name, values[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)
	if _, copyErr := io.Copy(c.Writer, resp.Body); copyErr != nil {
		return service.OpenAIErrorWrapper(copyErr, "copy_response_body_failed", http.StatusInternalServerError), nil
	}
	resp.Body.Close()

	reported := &simpleResponse.Usage
	usageMissing := reported.TotalTokens == 0 || (reported.PromptTokens == 0 && reported.CompletionTokens == 0)
	if !usageMissing {
		return nil, reported
	}

	completionTokens := 0
	for _, choice := range simpleResponse.Choices {
		tokens, _ := service.CountTextToken(string(choice.Message.Content), model)
		completionTokens += tokens
	}
	*reported = dto.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	return nil, reported
}
// OpenaiTTSHandler relays a text-to-speech response body to the client unchanged.
// The body is read fully before any header is written so that a failed read can
// still be reported with its own status. Usage is prompt-only: info.PromptTokens is
// reported as both the prompt and the total token count.
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	if err = resp.Body.Close(); err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	resp.Body = io.NopCloser(bytes.NewReader(payload))
	header := c.Writer.Header()
	for name, values := range resp.Header {
		header.Set(name, values[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)
	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
	}
	if err = resp.Body.Close(); err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	return nil, &dto.Usage{
		PromptTokens: info.PromptTokens,
		TotalTokens:  info.PromptTokens,
	}
}
// OpenaiSTTHandler relays a speech-to-text response body to the client unchanged
// and returns usage: info.PromptTokens as the prompt count, plus completion tokens
// counted from the transcription text with service.CountTextToken.
//
// The text is extracted according to responseFormat ("json", "text", "srt",
// "verbose_json" or "vtt"); for any other format no text is extracted and the
// completion count is taken for an empty string. The body is read fully before any
// header is written so that a failed read can still be reported with its own status.
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	// Reset response body so it can be copied to the client below.
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
	}
	resp.Body.Close()

	var (
		text       string
		extractErr error
	)
	switch responseFormat {
	case "json":
		text, extractErr = getTextFromJSON(responseBody)
	case "text":
		text, extractErr = getTextFromText(responseBody)
	case "srt":
		text, extractErr = getTextFromSRT(responseBody)
	case "verbose_json":
		text, extractErr = getTextFromVerboseJSON(responseBody)
	case "vtt":
		text, extractErr = getTextFromVTT(responseBody)
	}
	if extractErr != nil {
		// The response has already been sent, so the failure cannot be reported to
		// the client; log it so that the resulting under-count is traceable.
		common.LogError(c, "error extracting text from "+responseFormat+" transcription response: "+extractErr.Error())
	}

	usage := &dto.Usage{}
	usage.PromptTokens = info.PromptTokens
	usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return nil, usage
}
- func getTextFromVTT(body []byte) (string, error) {
- return getTextFromSRT(body)
- }
- func getTextFromVerboseJSON(body []byte) (string, error) {
- var whisperResponse dto.WhisperVerboseJSONResponse
- if err := json.Unmarshal(body, &whisperResponse); err != nil {
- return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
- }
- return whisperResponse.Text, nil
- }
// getTextFromSRT extracts the spoken text from an SRT (or WebVTT) transcription.
//
// A cue starts at its timing line (any line containing "-->") and its text runs
// until the next blank line, so multi-line cues are captured in full. Sequence
// numbers, cue identifiers, headers and NOTE blocks are skipped because they do not
// follow a timing line. Text lines are joined with "\n" so that words from adjacent
// lines or cues are not glued together (which would distort token counts).
// The returned error is the underlying scanner error, if any.
func getTextFromSRT(body []byte) (string, error) {
	scanner := bufio.NewScanner(bytes.NewReader(body))
	var builder strings.Builder
	inCue := false
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case strings.Contains(line, "-->"):
			inCue = true
		case strings.TrimSpace(line) == "":
			inCue = false
		case inCue:
			if builder.Len() > 0 {
				builder.WriteByte('\n')
			}
			builder.WriteString(line)
		}
	}
	if err := scanner.Err(); err != nil {
		return "", err
	}
	return builder.String(), nil
}
// getTextFromText returns a plain-text transcription body as a string, dropping a
// single trailing "\n" if present. It never returns an error.
func getTextFromText(body []byte) (string, error) {
	text := string(body)
	if strings.HasSuffix(text, "\n") {
		text = text[:len(text)-1]
	}
	return text, nil
}
- func getTextFromJSON(body []byte) (string, error) {
- var whisperResponse dto.AudioResponse
- if err := json.Unmarshal(body, &whisperResponse); err != nil {
- return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
- }
- return whisperResponse.Text, nil
- }
- func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
- if info == nil || info.ClientWs == nil || info.TargetWs == nil {
- return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
- }
- info.IsStream = true
- clientConn := info.ClientWs
- targetConn := info.TargetWs
- clientClosed := make(chan struct{})
- targetClosed := make(chan struct{})
- sendChan := make(chan []byte, 100)
- receiveChan := make(chan []byte, 100)
- errChan := make(chan error, 2)
- usage := &dto.RealtimeUsage{}
- localUsage := &dto.RealtimeUsage{}
- sumUsage := &dto.RealtimeUsage{}
- gopool.Go(func() {
- defer func() {
- if r := recover(); r != nil {
- errChan <- fmt.Errorf("panic in client reader: %v", r)
- }
- }()
- for {
- select {
- case <-c.Done():
- return
- default:
- _, message, err := clientConn.ReadMessage()
- if err != nil {
- if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
- errChan <- fmt.Errorf("error reading from client: %v", err)
- }
- close(clientClosed)
- return
- }
- realtimeEvent := &dto.RealtimeEvent{}
- err = json.Unmarshal(message, realtimeEvent)
- if err != nil {
- errChan <- fmt.Errorf("error unmarshalling message: %v", err)
- return
- }
- if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
- if realtimeEvent.Session != nil {
- if realtimeEvent.Session.Tools != nil {
- info.RealtimeTools = realtimeEvent.Session.Tools
- }
- }
- }
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- localUsage.InputTokens += textToken + audioToken
- localUsage.InputTokenDetails.TextTokens += textToken
- localUsage.InputTokenDetails.AudioTokens += audioToken
- err = service.WssString(c, targetConn, string(message))
- if err != nil {
- errChan <- fmt.Errorf("error writing to target: %v", err)
- return
- }
- select {
- case sendChan <- message:
- default:
- }
- }
- }
- })
- gopool.Go(func() {
- defer func() {
- if r := recover(); r != nil {
- errChan <- fmt.Errorf("panic in target reader: %v", r)
- }
- }()
- for {
- select {
- case <-c.Done():
- return
- default:
- _, message, err := targetConn.ReadMessage()
- if err != nil {
- if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
- errChan <- fmt.Errorf("error reading from target: %v", err)
- }
- close(targetClosed)
- return
- }
- info.SetFirstResponseTime()
- realtimeEvent := &dto.RealtimeEvent{}
- err = json.Unmarshal(message, realtimeEvent)
- if err != nil {
- errChan <- fmt.Errorf("error unmarshalling message: %v", err)
- return
- }
- if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
- realtimeUsage := realtimeEvent.Response.Usage
- if realtimeUsage != nil {
- usage.TotalTokens += realtimeUsage.TotalTokens
- usage.InputTokens += realtimeUsage.InputTokens
- usage.OutputTokens += realtimeUsage.OutputTokens
- usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
- usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
- usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
- usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
- usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
- err := preConsumeUsage(c, info, usage, sumUsage)
- if err != nil {
- errChan <- fmt.Errorf("error consume usage: %v", err)
- return
- }
- // 本次计费完成,清除
- usage = &dto.RealtimeUsage{}
- localUsage = &dto.RealtimeUsage{}
- } else {
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- info.IsFirstRequest = false
- localUsage.InputTokens += textToken + audioToken
- localUsage.InputTokenDetails.TextTokens += textToken
- localUsage.InputTokenDetails.AudioTokens += audioToken
- err = preConsumeUsage(c, info, localUsage, sumUsage)
- if err != nil {
- errChan <- fmt.Errorf("error consume usage: %v", err)
- return
- }
- // 本次计费完成,清除
- localUsage = &dto.RealtimeUsage{}
- // print now usage
- }
- //common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
- //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
- //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
- } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
- realtimeSession := realtimeEvent.Session
- if realtimeSession != nil {
- // update audio format
- info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
- info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
- }
- } else {
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- localUsage.OutputTokens += textToken + audioToken
- localUsage.OutputTokenDetails.TextTokens += textToken
- localUsage.OutputTokenDetails.AudioTokens += audioToken
- }
- err = service.WssString(c, clientConn, string(message))
- if err != nil {
- errChan <- fmt.Errorf("error writing to client: %v", err)
- return
- }
- select {
- case receiveChan <- message:
- default:
- }
- }
- }
- })
- select {
- case <-clientClosed:
- case <-targetClosed:
- case err := <-errChan:
- //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
- common.LogError(c, "realtime error: "+err.Error())
- case <-c.Done():
- }
- if usage.TotalTokens != 0 {
- _ = preConsumeUsage(c, info, usage, sumUsage)
- }
- if localUsage.TotalTokens != 0 {
- _ = preConsumeUsage(c, info, localUsage, sumUsage)
- }
- // check usage total tokens, if 0, use local usage
- return nil, sumUsage
- }
// preConsumeUsage adds usage to the running totals in totalUsage and then charges
// usage through service.PreWssConsumeQuota, returning that call's error. It fails
// without side effects if either pointer is nil. usage itself is not reset here;
// callers that want a fresh accumulator must replace it themselves.
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
	if totalUsage == nil || usage == nil {
		return fmt.Errorf("invalid usage pointer")
	}

	// Top-level counters.
	totalUsage.TotalTokens += usage.TotalTokens
	totalUsage.InputTokens += usage.InputTokens
	totalUsage.OutputTokens += usage.OutputTokens

	// Detail breakdowns.
	totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
	totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
	totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
	totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
	totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens

	return service.PreWssConsumeQuota(ctx, info, usage)
}
|