|
|
@@ -7,6 +7,7 @@ import (
|
|
|
"fmt"
|
|
|
"github.com/bytedance/gopkg/util/gopool"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+ "github.com/gorilla/websocket"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
@@ -231,7 +232,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
|
|
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
|
|
completionTokens := 0
|
|
|
for _, choice := range simpleResponse.Choices {
|
|
|
- ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
|
|
|
+ ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
|
|
|
completionTokens += ctkm
|
|
|
}
|
|
|
simpleResponse.Usage = dto.Usage{
|
|
|
@@ -324,7 +325,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
|
|
|
|
usage := &dto.Usage{}
|
|
|
usage.PromptTokens = info.PromptTokens
|
|
|
- usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
|
|
|
+ usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
|
|
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
|
return nil, usage
|
|
|
}
|
|
|
@@ -373,3 +374,210 @@ func getTextFromJSON(body []byte) (string, error) {
|
|
|
}
|
|
|
return whisperResponse.Text, nil
|
|
|
}
|
|
|
+
|
|
|
+func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
|
|
+ info.IsStream = true
|
|
|
+ clientConn := info.ClientWs
|
|
|
+ targetConn := info.TargetWs
|
|
|
+
|
|
|
+ clientClosed := make(chan struct{})
|
|
|
+ targetClosed := make(chan struct{})
|
|
|
+ sendChan := make(chan []byte, 100)
|
|
|
+ receiveChan := make(chan []byte, 100)
|
|
|
+ errChan := make(chan error, 2)
|
|
|
+
|
|
|
+ usage := &dto.RealtimeUsage{}
|
|
|
+ localUsage := &dto.RealtimeUsage{}
|
|
|
+ sumUsage := &dto.RealtimeUsage{}
|
|
|
+
|
|
|
+ gopool.Go(func() {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-c.Done():
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ _, message, err := clientConn.ReadMessage()
|
|
|
+ if err != nil {
|
|
|
+ if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
|
+ errChan <- fmt.Errorf("error reading from client: %v", err)
|
|
|
+ }
|
|
|
+ close(clientClosed)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ realtimeEvent := &dto.RealtimeEvent{}
|
|
|
+ err = json.Unmarshal(message, realtimeEvent)
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
|
|
+ if realtimeEvent.Session != nil {
|
|
|
+ if realtimeEvent.Session.Tools != nil {
|
|
|
+ info.RealtimeTools = realtimeEvent.Session.Tools
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
|
+ localUsage.TotalTokens += textToken + audioToken
|
|
|
+ localUsage.InputTokens += textToken + audioToken
|
|
|
+ localUsage.InputTokenDetails.TextTokens += textToken
|
|
|
+ localUsage.InputTokenDetails.AudioTokens += audioToken
|
|
|
+
|
|
|
+ err = service.WssString(c, targetConn, string(message))
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error writing to target: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case sendChan <- message:
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ gopool.Go(func() {
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-c.Done():
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ _, message, err := targetConn.ReadMessage()
|
|
|
+ if err != nil {
|
|
|
+ if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
|
+ errChan <- fmt.Errorf("error reading from target: %v", err)
|
|
|
+ }
|
|
|
+ close(targetClosed)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ info.SetFirstResponseTime()
|
|
|
+ realtimeEvent := &dto.RealtimeEvent{}
|
|
|
+ err = json.Unmarshal(message, realtimeEvent)
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
|
|
+ realtimeUsage := realtimeEvent.Response.Usage
|
|
|
+ if realtimeUsage != nil {
|
|
|
+ usage.TotalTokens += realtimeUsage.TotalTokens
|
|
|
+ usage.InputTokens += realtimeUsage.InputTokens
|
|
|
+ usage.OutputTokens += realtimeUsage.OutputTokens
|
|
|
+ usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
|
|
+ usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
|
|
+ usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
|
|
+ usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
|
|
+ usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
|
|
+ err := preConsumeUsage(c, info, usage, sumUsage)
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error consume usage: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ // 本次计费完成,清除
|
|
|
+ usage = &dto.RealtimeUsage{}
|
|
|
+
|
|
|
+ localUsage = &dto.RealtimeUsage{}
|
|
|
+ } else {
|
|
|
+ textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
|
+ localUsage.TotalTokens += textToken + audioToken
|
|
|
+ info.IsFirstRequest = false
|
|
|
+ localUsage.InputTokens += textToken + audioToken
|
|
|
+ localUsage.InputTokenDetails.TextTokens += textToken
|
|
|
+ localUsage.InputTokenDetails.AudioTokens += audioToken
|
|
|
+ err = preConsumeUsage(c, info, localUsage, sumUsage)
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error consume usage: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ // 本次计费完成,清除
|
|
|
+ localUsage = &dto.RealtimeUsage{}
|
|
|
+ // print now usage
|
|
|
+ }
|
|
|
+ //common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
|
|
+ //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
|
|
+ //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
|
|
+
|
|
|
+ } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
|
|
+ realtimeSession := realtimeEvent.Session
|
|
|
+ if realtimeSession != nil {
|
|
|
+ // update audio format
|
|
|
+ info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
|
|
+ info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error counting text token: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
|
|
+ localUsage.TotalTokens += textToken + audioToken
|
|
|
+ localUsage.OutputTokens += textToken + audioToken
|
|
|
+ localUsage.OutputTokenDetails.TextTokens += textToken
|
|
|
+ localUsage.OutputTokenDetails.AudioTokens += audioToken
|
|
|
+ }
|
|
|
+
|
|
|
+ err = service.WssString(c, clientConn, string(message))
|
|
|
+ if err != nil {
|
|
|
+ errChan <- fmt.Errorf("error writing to client: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case receiveChan <- message:
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-clientClosed:
|
|
|
+ case <-targetClosed:
|
|
|
+ case err := <-errChan:
|
|
|
+ //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
|
|
+ common.LogError(c, "realtime error: "+err.Error())
|
|
|
+ case <-c.Done():
|
|
|
+ }
|
|
|
+
|
|
|
+ if usage.TotalTokens != 0 {
|
|
|
+ _ = preConsumeUsage(c, info, usage, sumUsage)
|
|
|
+ }
|
|
|
+
|
|
|
+ if localUsage.TotalTokens != 0 {
|
|
|
+ _ = preConsumeUsage(c, info, localUsage, sumUsage)
|
|
|
+ }
|
|
|
+
|
|
|
+ // check usage total tokens, if 0, use local usage
|
|
|
+
|
|
|
+ return nil, sumUsage
|
|
|
+}
|
|
|
+
|
|
|
+func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
|
|
+ totalUsage.TotalTokens += usage.TotalTokens
|
|
|
+ totalUsage.InputTokens += usage.InputTokens
|
|
|
+ totalUsage.OutputTokens += usage.OutputTokens
|
|
|
+ totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
|
|
+ totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
|
|
+ totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
|
|
+ totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
|
|
+ totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
|
|
+ // clear usage
|
|
|
+ err := service.PreWssConsumeQuota(ctx, info, usage)
|
|
|
+ return err
|
|
|
+}
|