relay-openai.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. package openai
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/bytedance/gopkg/util/gopool"
  8. "github.com/gin-gonic/gin"
  9. "github.com/gorilla/websocket"
  10. "io"
  11. "log"
  12. "net/http"
  13. "one-api/common"
  14. "one-api/constant"
  15. "one-api/dto"
  16. relaycommon "one-api/relay/common"
  17. relayconstant "one-api/relay/constant"
  18. "one-api/service"
  19. "strings"
  20. "sync"
  21. "time"
  22. )
  23. func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  24. containStreamUsage := false
  25. var responseId string
  26. var createAt int64 = 0
  27. var systemFingerprint string
  28. model := info.UpstreamModelName
  29. var responseTextBuilder strings.Builder
  30. var usage = &dto.Usage{}
  31. var streamItems []string // store stream items
  32. toolCount := 0
  33. scanner := bufio.NewScanner(resp.Body)
  34. scanner.Split(bufio.ScanLines)
  35. service.SetEventStreamHeaders(c)
  36. ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
  37. defer ticker.Stop()
  38. stopChan := make(chan bool)
  39. defer close(stopChan)
  40. var (
  41. lastStreamData string
  42. mu sync.Mutex
  43. )
  44. gopool.Go(func() {
  45. for scanner.Scan() {
  46. info.SetFirstResponseTime()
  47. ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
  48. data := scanner.Text()
  49. if len(data) < 6 { // ignore blank line or wrong format
  50. continue
  51. }
  52. if data[:6] != "data: " && data[:6] != "[DONE]" {
  53. continue
  54. }
  55. mu.Lock()
  56. data = data[6:]
  57. if !strings.HasPrefix(data, "[DONE]") {
  58. if lastStreamData != "" {
  59. err := service.StringData(c, lastStreamData)
  60. if err != nil {
  61. common.LogError(c, "streaming error: "+err.Error())
  62. }
  63. }
  64. lastStreamData = data
  65. streamItems = append(streamItems, data)
  66. }
  67. mu.Unlock()
  68. }
  69. common.SafeSendBool(stopChan, true)
  70. })
  71. select {
  72. case <-ticker.C:
  73. // 超时处理逻辑
  74. common.LogError(c, "streaming timeout")
  75. case <-stopChan:
  76. // 正常结束
  77. }
  78. shouldSendLastResp := true
  79. var lastStreamResponse dto.ChatCompletionsStreamResponse
  80. err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
  81. if err == nil {
  82. responseId = lastStreamResponse.Id
  83. createAt = lastStreamResponse.Created
  84. systemFingerprint = lastStreamResponse.GetSystemFingerprint()
  85. model = lastStreamResponse.Model
  86. if service.ValidUsage(lastStreamResponse.Usage) {
  87. containStreamUsage = true
  88. usage = lastStreamResponse.Usage
  89. if !info.ShouldIncludeUsage {
  90. shouldSendLastResp = false
  91. }
  92. }
  93. }
  94. if shouldSendLastResp {
  95. service.StringData(c, lastStreamData)
  96. }
  97. // 计算token
  98. streamResp := "[" + strings.Join(streamItems, ",") + "]"
  99. switch info.RelayMode {
  100. case relayconstant.RelayModeChatCompletions:
  101. var streamResponses []dto.ChatCompletionsStreamResponse
  102. err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
  103. if err != nil {
  104. // 一次性解析失败,逐个解析
  105. common.SysError("error unmarshalling stream response: " + err.Error())
  106. for _, item := range streamItems {
  107. var streamResponse dto.ChatCompletionsStreamResponse
  108. err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
  109. if err == nil {
  110. //if service.ValidUsage(streamResponse.Usage) {
  111. // usage = streamResponse.Usage
  112. //}
  113. for _, choice := range streamResponse.Choices {
  114. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  115. if choice.Delta.ToolCalls != nil {
  116. if len(choice.Delta.ToolCalls) > toolCount {
  117. toolCount = len(choice.Delta.ToolCalls)
  118. }
  119. for _, tool := range choice.Delta.ToolCalls {
  120. responseTextBuilder.WriteString(tool.Function.Name)
  121. responseTextBuilder.WriteString(tool.Function.Arguments)
  122. }
  123. }
  124. }
  125. }
  126. }
  127. } else {
  128. for _, streamResponse := range streamResponses {
  129. //if service.ValidUsage(streamResponse.Usage) {
  130. // usage = streamResponse.Usage
  131. // containStreamUsage = true
  132. //}
  133. for _, choice := range streamResponse.Choices {
  134. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  135. if choice.Delta.ToolCalls != nil {
  136. if len(choice.Delta.ToolCalls) > toolCount {
  137. toolCount = len(choice.Delta.ToolCalls)
  138. }
  139. for _, tool := range choice.Delta.ToolCalls {
  140. responseTextBuilder.WriteString(tool.Function.Name)
  141. responseTextBuilder.WriteString(tool.Function.Arguments)
  142. }
  143. }
  144. }
  145. }
  146. }
  147. case relayconstant.RelayModeCompletions:
  148. var streamResponses []dto.CompletionsStreamResponse
  149. err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
  150. if err != nil {
  151. // 一次性解析失败,逐个解析
  152. common.SysError("error unmarshalling stream response: " + err.Error())
  153. for _, item := range streamItems {
  154. var streamResponse dto.CompletionsStreamResponse
  155. err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
  156. if err == nil {
  157. for _, choice := range streamResponse.Choices {
  158. responseTextBuilder.WriteString(choice.Text)
  159. }
  160. }
  161. }
  162. } else {
  163. for _, streamResponse := range streamResponses {
  164. for _, choice := range streamResponse.Choices {
  165. responseTextBuilder.WriteString(choice.Text)
  166. }
  167. }
  168. }
  169. }
  170. if !containStreamUsage {
  171. usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  172. usage.CompletionTokens += toolCount * 7
  173. }
  174. if info.ShouldIncludeUsage && !containStreamUsage {
  175. response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
  176. response.SetSystemFingerprint(systemFingerprint)
  177. service.ObjectData(c, response)
  178. }
  179. service.Done(c)
  180. resp.Body.Close()
  181. return nil, usage
  182. }
  183. func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  184. var simpleResponse dto.SimpleResponse
  185. responseBody, err := io.ReadAll(resp.Body)
  186. if err != nil {
  187. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  188. }
  189. err = resp.Body.Close()
  190. if err != nil {
  191. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  192. }
  193. err = json.Unmarshal(responseBody, &simpleResponse)
  194. if err != nil {
  195. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  196. }
  197. if simpleResponse.Error.Type != "" {
  198. return &dto.OpenAIErrorWithStatusCode{
  199. Error: simpleResponse.Error,
  200. StatusCode: resp.StatusCode,
  201. }, nil
  202. }
  203. // Reset response body
  204. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  205. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  206. // And then we will have to send an error response, but in this case, the header has already been set.
  207. // So the httpClient will be confused by the response.
  208. // For example, Postman will report error, and we cannot check the response at all.
  209. for k, v := range resp.Header {
  210. c.Writer.Header().Set(k, v[0])
  211. }
  212. c.Writer.WriteHeader(resp.StatusCode)
  213. _, err = io.Copy(c.Writer, resp.Body)
  214. if err != nil {
  215. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  216. }
  217. resp.Body.Close()
  218. if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
  219. completionTokens := 0
  220. for _, choice := range simpleResponse.Choices {
  221. ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
  222. completionTokens += ctkm
  223. }
  224. simpleResponse.Usage = dto.Usage{
  225. PromptTokens: promptTokens,
  226. CompletionTokens: completionTokens,
  227. TotalTokens: promptTokens + completionTokens,
  228. }
  229. }
  230. return nil, &simpleResponse.Usage
  231. }
  232. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  233. responseBody, err := io.ReadAll(resp.Body)
  234. if err != nil {
  235. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  236. }
  237. err = resp.Body.Close()
  238. if err != nil {
  239. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  240. }
  241. // Reset response body
  242. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  243. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  244. // And then we will have to send an error response, but in this case, the header has already been set.
  245. // So the httpClient will be confused by the response.
  246. // For example, Postman will report error, and we cannot check the response at all.
  247. for k, v := range resp.Header {
  248. c.Writer.Header().Set(k, v[0])
  249. }
  250. c.Writer.WriteHeader(resp.StatusCode)
  251. _, err = io.Copy(c.Writer, resp.Body)
  252. if err != nil {
  253. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  254. }
  255. err = resp.Body.Close()
  256. if err != nil {
  257. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  258. }
  259. usage := &dto.Usage{}
  260. usage.PromptTokens = info.PromptTokens
  261. usage.TotalTokens = info.PromptTokens
  262. return nil, usage
  263. }
  264. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  265. var audioResp dto.AudioResponse
  266. responseBody, err := io.ReadAll(resp.Body)
  267. if err != nil {
  268. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  269. }
  270. err = resp.Body.Close()
  271. if err != nil {
  272. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  273. }
  274. err = json.Unmarshal(responseBody, &audioResp)
  275. if err != nil {
  276. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  277. }
  278. // Reset response body
  279. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  280. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  281. // And then we will have to send an error response, but in this case, the header has already been set.
  282. // So the httpClient will be confused by the response.
  283. // For example, Postman will report error, and we cannot check the response at all.
  284. for k, v := range resp.Header {
  285. c.Writer.Header().Set(k, v[0])
  286. }
  287. c.Writer.WriteHeader(resp.StatusCode)
  288. _, err = io.Copy(c.Writer, resp.Body)
  289. if err != nil {
  290. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  291. }
  292. resp.Body.Close()
  293. var text string
  294. switch responseFormat {
  295. case "json":
  296. text, err = getTextFromJSON(responseBody)
  297. case "text":
  298. text, err = getTextFromText(responseBody)
  299. case "srt":
  300. text, err = getTextFromSRT(responseBody)
  301. case "verbose_json":
  302. text, err = getTextFromVerboseJSON(responseBody)
  303. case "vtt":
  304. text, err = getTextFromVTT(responseBody)
  305. }
  306. usage := &dto.Usage{}
  307. usage.PromptTokens = info.PromptTokens
  308. usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
  309. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  310. return nil, usage
  311. }
  312. func getTextFromVTT(body []byte) (string, error) {
  313. return getTextFromSRT(body)
  314. }
  315. func getTextFromVerboseJSON(body []byte) (string, error) {
  316. var whisperResponse dto.WhisperVerboseJSONResponse
  317. if err := json.Unmarshal(body, &whisperResponse); err != nil {
  318. return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
  319. }
  320. return whisperResponse.Text, nil
  321. }
  322. func getTextFromSRT(body []byte) (string, error) {
  323. scanner := bufio.NewScanner(strings.NewReader(string(body)))
  324. var builder strings.Builder
  325. var textLine bool
  326. for scanner.Scan() {
  327. line := scanner.Text()
  328. if textLine {
  329. builder.WriteString(line)
  330. textLine = false
  331. continue
  332. } else if strings.Contains(line, "-->") {
  333. textLine = true
  334. continue
  335. }
  336. }
  337. if err := scanner.Err(); err != nil {
  338. return "", err
  339. }
  340. return builder.String(), nil
  341. }
  342. func getTextFromText(body []byte) (string, error) {
  343. return strings.TrimSuffix(string(body), "\n"), nil
  344. }
  345. func getTextFromJSON(body []byte) (string, error) {
  346. var whisperResponse dto.AudioResponse
  347. if err := json.Unmarshal(body, &whisperResponse); err != nil {
  348. return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
  349. }
  350. return whisperResponse.Text, nil
  351. }
  352. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
  353. info.IsStream = true
  354. clientConn := info.ClientWs
  355. targetConn := info.TargetWs
  356. clientClosed := make(chan struct{})
  357. targetClosed := make(chan struct{})
  358. sendChan := make(chan []byte, 100)
  359. receiveChan := make(chan []byte, 100)
  360. errChan := make(chan error, 2)
  361. usage := &dto.RealtimeUsage{}
  362. localUsage := &dto.RealtimeUsage{}
  363. sumUsage := &dto.RealtimeUsage{}
  364. go func() {
  365. for {
  366. select {
  367. case <-c.Done():
  368. return
  369. default:
  370. _, message, err := clientConn.ReadMessage()
  371. if err != nil {
  372. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  373. errChan <- fmt.Errorf("error reading from client: %v", err)
  374. }
  375. close(clientClosed)
  376. return
  377. }
  378. realtimeEvent := &dto.RealtimeEvent{}
  379. err = json.Unmarshal(message, realtimeEvent)
  380. if err != nil {
  381. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  382. return
  383. }
  384. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  385. if realtimeEvent.Session != nil {
  386. if realtimeEvent.Session.Tools != nil {
  387. info.RealtimeTools = realtimeEvent.Session.Tools
  388. }
  389. }
  390. }
  391. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  392. if err != nil {
  393. errChan <- fmt.Errorf("error counting text token: %v", err)
  394. return
  395. }
  396. log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
  397. localUsage.TotalTokens += textToken + audioToken
  398. localUsage.InputTokens += textToken
  399. localUsage.InputTokenDetails.TextTokens += textToken
  400. localUsage.InputTokenDetails.AudioTokens += audioToken
  401. err = service.WssString(c, targetConn, string(message))
  402. if err != nil {
  403. errChan <- fmt.Errorf("error writing to target: %v", err)
  404. return
  405. }
  406. select {
  407. case sendChan <- message:
  408. default:
  409. }
  410. }
  411. }
  412. }()
  413. go func() {
  414. for {
  415. select {
  416. case <-c.Done():
  417. return
  418. default:
  419. _, message, err := targetConn.ReadMessage()
  420. if err != nil {
  421. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  422. errChan <- fmt.Errorf("error reading from target: %v", err)
  423. }
  424. close(targetClosed)
  425. return
  426. }
  427. info.SetFirstResponseTime()
  428. realtimeEvent := &dto.RealtimeEvent{}
  429. err = json.Unmarshal(message, realtimeEvent)
  430. if err != nil {
  431. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  432. return
  433. }
  434. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  435. realtimeUsage := realtimeEvent.Response.Usage
  436. if realtimeUsage != nil {
  437. usage.TotalTokens += realtimeUsage.TotalTokens
  438. usage.InputTokens += realtimeUsage.InputTokens
  439. usage.OutputTokens += realtimeUsage.OutputTokens
  440. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  441. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  442. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  443. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  444. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  445. err := preConsumeUsage(c, info, usage, sumUsage)
  446. if err != nil {
  447. errChan <- fmt.Errorf("error consume usage: %v", err)
  448. return
  449. }
  450. usage = &dto.RealtimeUsage{}
  451. } else {
  452. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  453. if err != nil {
  454. errChan <- fmt.Errorf("error counting text token: %v", err)
  455. return
  456. }
  457. log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
  458. localUsage.TotalTokens += textToken + audioToken
  459. info.IsFirstRequest = false
  460. localUsage.InputTokens += textToken + audioToken
  461. localUsage.InputTokenDetails.TextTokens += textToken
  462. localUsage.InputTokenDetails.AudioTokens += audioToken
  463. err = preConsumeUsage(c, info, localUsage, sumUsage)
  464. if err != nil {
  465. errChan <- fmt.Errorf("error consume usage: %v", err)
  466. return
  467. }
  468. localUsage = &dto.RealtimeUsage{}
  469. // print now usage
  470. }
  471. common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  472. common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  473. common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  474. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  475. realtimeSession := realtimeEvent.Session
  476. if realtimeSession != nil {
  477. // update audio format
  478. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  479. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  480. }
  481. } else {
  482. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  483. if err != nil {
  484. errChan <- fmt.Errorf("error counting text token: %v", err)
  485. return
  486. }
  487. log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
  488. localUsage.TotalTokens += textToken + audioToken
  489. localUsage.OutputTokens += textToken + audioToken
  490. localUsage.OutputTokenDetails.TextTokens += textToken
  491. localUsage.OutputTokenDetails.AudioTokens += audioToken
  492. }
  493. err = service.WssString(c, clientConn, string(message))
  494. if err != nil {
  495. errChan <- fmt.Errorf("error writing to client: %v", err)
  496. return
  497. }
  498. select {
  499. case receiveChan <- message:
  500. default:
  501. }
  502. }
  503. }
  504. }()
  505. select {
  506. case <-clientClosed:
  507. case <-targetClosed:
  508. case err := <-errChan:
  509. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  510. common.LogError(c, "realtime error: "+err.Error())
  511. case <-c.Done():
  512. }
  513. if usage.TotalTokens != 0 {
  514. _ = preConsumeUsage(c, info, usage, sumUsage)
  515. }
  516. if localUsage.TotalTokens != 0 {
  517. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  518. }
  519. // check usage total tokens, if 0, use local usage
  520. return nil, sumUsage
  521. }
  522. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  523. totalUsage.TotalTokens += usage.TotalTokens
  524. totalUsage.InputTokens += usage.InputTokens
  525. totalUsage.OutputTokens += usage.OutputTokens
  526. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  527. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  528. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  529. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  530. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  531. // clear usage
  532. err := service.PreWssConsumeQuota(ctx, info, usage)
  533. if err == nil {
  534. common.LogInfo(ctx, "realtime streaming consume usage success")
  535. }
  536. return err
  537. }