relay-openai.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. package openai
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. relaycommon "one-api/relay/common"
  13. relayconstant "one-api/relay/constant"
  14. "one-api/service"
  15. "strings"
  16. "sync"
  17. "time"
  18. "github.com/bytedance/gopkg/util/gopool"
  19. "github.com/gin-gonic/gin"
  20. "github.com/gorilla/websocket"
  21. )
  22. func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
  23. if data == "" {
  24. return nil
  25. }
  26. if forceFormat {
  27. var lastStreamResponse dto.ChatCompletionsStreamResponse
  28. if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
  29. return err
  30. }
  31. return service.ObjectData(c, lastStreamResponse)
  32. }
  33. return service.StringData(c, data)
  34. }
  35. func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  36. if resp == nil || resp.Body == nil {
  37. common.LogError(c, "invalid response or response body")
  38. return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
  39. }
  40. containStreamUsage := false
  41. var responseId string
  42. var createAt int64 = 0
  43. var systemFingerprint string
  44. model := info.UpstreamModelName
  45. var responseTextBuilder strings.Builder
  46. var usage = &dto.Usage{}
  47. var streamItems []string // store stream items
  48. var forceFormat bool
  49. if info.ChannelType == common.ChannelTypeCustom {
  50. if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok {
  51. forceFormat = forceFmt
  52. }
  53. }
  54. toolCount := 0
  55. scanner := bufio.NewScanner(resp.Body)
  56. scanner.Split(bufio.ScanLines)
  57. service.SetEventStreamHeaders(c)
  58. ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
  59. defer ticker.Stop()
  60. stopChan := make(chan bool)
  61. defer close(stopChan)
  62. var (
  63. lastStreamData string
  64. mu sync.Mutex
  65. )
  66. gopool.Go(func() {
  67. for scanner.Scan() {
  68. info.SetFirstResponseTime()
  69. ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
  70. data := scanner.Text()
  71. if len(data) < 6 { // ignore blank line or wrong format
  72. continue
  73. }
  74. if data[:6] != "data: " && data[:6] != "[DONE]" {
  75. continue
  76. }
  77. mu.Lock()
  78. data = data[6:]
  79. if !strings.HasPrefix(data, "[DONE]") {
  80. if lastStreamData != "" {
  81. err := sendStreamData(c, lastStreamData, forceFormat)
  82. if err != nil {
  83. common.LogError(c, "streaming error: "+err.Error())
  84. }
  85. }
  86. lastStreamData = data
  87. streamItems = append(streamItems, data)
  88. }
  89. mu.Unlock()
  90. }
  91. common.SafeSendBool(stopChan, true)
  92. })
  93. select {
  94. case <-ticker.C:
  95. // 超时处理逻辑
  96. common.LogError(c, "streaming timeout")
  97. case <-stopChan:
  98. // 正常结束
  99. }
  100. shouldSendLastResp := true
  101. var lastStreamResponse dto.ChatCompletionsStreamResponse
  102. err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
  103. if err == nil {
  104. responseId = lastStreamResponse.Id
  105. createAt = lastStreamResponse.Created
  106. systemFingerprint = lastStreamResponse.GetSystemFingerprint()
  107. model = lastStreamResponse.Model
  108. if service.ValidUsage(lastStreamResponse.Usage) {
  109. containStreamUsage = true
  110. usage = lastStreamResponse.Usage
  111. if !info.ShouldIncludeUsage {
  112. shouldSendLastResp = false
  113. }
  114. }
  115. for _, choice := range lastStreamResponse.Choices {
  116. if choice.FinishReason != nil {
  117. shouldSendLastResp = true
  118. }
  119. }
  120. }
  121. if shouldSendLastResp {
  122. sendStreamData(c, lastStreamData, forceFormat)
  123. }
  124. // 计算token
  125. streamResp := "[" + strings.Join(streamItems, ",") + "]"
  126. switch info.RelayMode {
  127. case relayconstant.RelayModeChatCompletions:
  128. var streamResponses []dto.ChatCompletionsStreamResponse
  129. err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
  130. if err != nil {
  131. // 一次性解析失败,逐个解析
  132. common.SysError("error unmarshalling stream response: " + err.Error())
  133. for _, item := range streamItems {
  134. var streamResponse dto.ChatCompletionsStreamResponse
  135. err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
  136. if err == nil {
  137. //if service.ValidUsage(streamResponse.Usage) {
  138. // usage = streamResponse.Usage
  139. //}
  140. for _, choice := range streamResponse.Choices {
  141. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  142. if choice.Delta.ToolCalls != nil {
  143. if len(choice.Delta.ToolCalls) > toolCount {
  144. toolCount = len(choice.Delta.ToolCalls)
  145. }
  146. for _, tool := range choice.Delta.ToolCalls {
  147. responseTextBuilder.WriteString(tool.Function.Name)
  148. responseTextBuilder.WriteString(tool.Function.Arguments)
  149. }
  150. }
  151. }
  152. }
  153. }
  154. } else {
  155. for _, streamResponse := range streamResponses {
  156. //if service.ValidUsage(streamResponse.Usage) {
  157. // usage = streamResponse.Usage
  158. // containStreamUsage = true
  159. //}
  160. for _, choice := range streamResponse.Choices {
  161. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  162. if choice.Delta.ToolCalls != nil {
  163. if len(choice.Delta.ToolCalls) > toolCount {
  164. toolCount = len(choice.Delta.ToolCalls)
  165. }
  166. for _, tool := range choice.Delta.ToolCalls {
  167. responseTextBuilder.WriteString(tool.Function.Name)
  168. responseTextBuilder.WriteString(tool.Function.Arguments)
  169. }
  170. }
  171. }
  172. }
  173. }
  174. case relayconstant.RelayModeCompletions:
  175. var streamResponses []dto.CompletionsStreamResponse
  176. err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
  177. if err != nil {
  178. // 一次性解析失败,逐个解析
  179. common.SysError("error unmarshalling stream response: " + err.Error())
  180. for _, item := range streamItems {
  181. var streamResponse dto.CompletionsStreamResponse
  182. err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
  183. if err == nil {
  184. for _, choice := range streamResponse.Choices {
  185. responseTextBuilder.WriteString(choice.Text)
  186. }
  187. }
  188. }
  189. } else {
  190. for _, streamResponse := range streamResponses {
  191. for _, choice := range streamResponse.Choices {
  192. responseTextBuilder.WriteString(choice.Text)
  193. }
  194. }
  195. }
  196. }
  197. if !containStreamUsage {
  198. usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  199. usage.CompletionTokens += toolCount * 7
  200. }
  201. if info.ShouldIncludeUsage && !containStreamUsage {
  202. response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
  203. response.SetSystemFingerprint(systemFingerprint)
  204. service.ObjectData(c, response)
  205. }
  206. service.Done(c)
  207. resp.Body.Close()
  208. return nil, usage
  209. }
  210. func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  211. var simpleResponse dto.SimpleResponse
  212. responseBody, err := io.ReadAll(resp.Body)
  213. if err != nil {
  214. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  215. }
  216. err = resp.Body.Close()
  217. if err != nil {
  218. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  219. }
  220. err = json.Unmarshal(responseBody, &simpleResponse)
  221. if err != nil {
  222. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  223. }
  224. if simpleResponse.Error.Type != "" {
  225. return &dto.OpenAIErrorWithStatusCode{
  226. Error: simpleResponse.Error,
  227. StatusCode: resp.StatusCode,
  228. }, nil
  229. }
  230. // Reset response body
  231. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  232. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  233. // And then we will have to send an error response, but in this case, the header has already been set.
  234. // So the httpClient will be confused by the response.
  235. // For example, Postman will report error, and we cannot check the response at all.
  236. for k, v := range resp.Header {
  237. c.Writer.Header().Set(k, v[0])
  238. }
  239. c.Writer.WriteHeader(resp.StatusCode)
  240. _, err = io.Copy(c.Writer, resp.Body)
  241. if err != nil {
  242. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  243. }
  244. resp.Body.Close()
  245. if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
  246. completionTokens := 0
  247. for _, choice := range simpleResponse.Choices {
  248. ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
  249. completionTokens += ctkm
  250. }
  251. simpleResponse.Usage = dto.Usage{
  252. PromptTokens: promptTokens,
  253. CompletionTokens: completionTokens,
  254. TotalTokens: promptTokens + completionTokens,
  255. }
  256. }
  257. return nil, &simpleResponse.Usage
  258. }
  259. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  260. responseBody, err := io.ReadAll(resp.Body)
  261. if err != nil {
  262. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  263. }
  264. err = resp.Body.Close()
  265. if err != nil {
  266. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  267. }
  268. // Reset response body
  269. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  270. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  271. // And then we will have to send an error response, but in this case, the header has already been set.
  272. // So the httpClient will be confused by the response.
  273. // For example, Postman will report error, and we cannot check the response at all.
  274. for k, v := range resp.Header {
  275. c.Writer.Header().Set(k, v[0])
  276. }
  277. c.Writer.WriteHeader(resp.StatusCode)
  278. _, err = io.Copy(c.Writer, resp.Body)
  279. if err != nil {
  280. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  281. }
  282. err = resp.Body.Close()
  283. if err != nil {
  284. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  285. }
  286. usage := &dto.Usage{}
  287. usage.PromptTokens = info.PromptTokens
  288. usage.TotalTokens = info.PromptTokens
  289. return nil, usage
  290. }
  291. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  292. responseBody, err := io.ReadAll(resp.Body)
  293. if err != nil {
  294. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  295. }
  296. err = resp.Body.Close()
  297. if err != nil {
  298. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  299. }
  300. // Reset response body
  301. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  302. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  303. // And then we will have to send an error response, but in this case, the header has already been set.
  304. // So the httpClient will be confused by the response.
  305. // For example, Postman will report error, and we cannot check the response at all.
  306. for k, v := range resp.Header {
  307. c.Writer.Header().Set(k, v[0])
  308. }
  309. c.Writer.WriteHeader(resp.StatusCode)
  310. _, err = io.Copy(c.Writer, resp.Body)
  311. if err != nil {
  312. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  313. }
  314. resp.Body.Close()
  315. var text string
  316. switch responseFormat {
  317. case "json":
  318. text, err = getTextFromJSON(responseBody)
  319. case "text":
  320. text, err = getTextFromText(responseBody)
  321. case "srt":
  322. text, err = getTextFromSRT(responseBody)
  323. case "verbose_json":
  324. text, err = getTextFromVerboseJSON(responseBody)
  325. case "vtt":
  326. text, err = getTextFromVTT(responseBody)
  327. }
  328. usage := &dto.Usage{}
  329. usage.PromptTokens = info.PromptTokens
  330. usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
  331. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  332. return nil, usage
  333. }
  334. func getTextFromVTT(body []byte) (string, error) {
  335. return getTextFromSRT(body)
  336. }
  337. func getTextFromVerboseJSON(body []byte) (string, error) {
  338. var whisperResponse dto.WhisperVerboseJSONResponse
  339. if err := json.Unmarshal(body, &whisperResponse); err != nil {
  340. return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
  341. }
  342. return whisperResponse.Text, nil
  343. }
  344. func getTextFromSRT(body []byte) (string, error) {
  345. scanner := bufio.NewScanner(strings.NewReader(string(body)))
  346. var builder strings.Builder
  347. var textLine bool
  348. for scanner.Scan() {
  349. line := scanner.Text()
  350. if textLine {
  351. builder.WriteString(line)
  352. textLine = false
  353. continue
  354. } else if strings.Contains(line, "-->") {
  355. textLine = true
  356. continue
  357. }
  358. }
  359. if err := scanner.Err(); err != nil {
  360. return "", err
  361. }
  362. return builder.String(), nil
  363. }
  364. func getTextFromText(body []byte) (string, error) {
  365. return strings.TrimSuffix(string(body), "\n"), nil
  366. }
  367. func getTextFromJSON(body []byte) (string, error) {
  368. var whisperResponse dto.AudioResponse
  369. if err := json.Unmarshal(body, &whisperResponse); err != nil {
  370. return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
  371. }
  372. return whisperResponse.Text, nil
  373. }
  374. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
  375. if info == nil || info.ClientWs == nil || info.TargetWs == nil {
  376. return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
  377. }
  378. info.IsStream = true
  379. clientConn := info.ClientWs
  380. targetConn := info.TargetWs
  381. clientClosed := make(chan struct{})
  382. targetClosed := make(chan struct{})
  383. sendChan := make(chan []byte, 100)
  384. receiveChan := make(chan []byte, 100)
  385. errChan := make(chan error, 2)
  386. usage := &dto.RealtimeUsage{}
  387. localUsage := &dto.RealtimeUsage{}
  388. sumUsage := &dto.RealtimeUsage{}
  389. gopool.Go(func() {
  390. defer func() {
  391. if r := recover(); r != nil {
  392. errChan <- fmt.Errorf("panic in client reader: %v", r)
  393. }
  394. }()
  395. for {
  396. select {
  397. case <-c.Done():
  398. return
  399. default:
  400. _, message, err := clientConn.ReadMessage()
  401. if err != nil {
  402. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  403. errChan <- fmt.Errorf("error reading from client: %v", err)
  404. }
  405. close(clientClosed)
  406. return
  407. }
  408. realtimeEvent := &dto.RealtimeEvent{}
  409. err = json.Unmarshal(message, realtimeEvent)
  410. if err != nil {
  411. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  412. return
  413. }
  414. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  415. if realtimeEvent.Session != nil {
  416. if realtimeEvent.Session.Tools != nil {
  417. info.RealtimeTools = realtimeEvent.Session.Tools
  418. }
  419. }
  420. }
  421. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  422. if err != nil {
  423. errChan <- fmt.Errorf("error counting text token: %v", err)
  424. return
  425. }
  426. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  427. localUsage.TotalTokens += textToken + audioToken
  428. localUsage.InputTokens += textToken + audioToken
  429. localUsage.InputTokenDetails.TextTokens += textToken
  430. localUsage.InputTokenDetails.AudioTokens += audioToken
  431. err = service.WssString(c, targetConn, string(message))
  432. if err != nil {
  433. errChan <- fmt.Errorf("error writing to target: %v", err)
  434. return
  435. }
  436. select {
  437. case sendChan <- message:
  438. default:
  439. }
  440. }
  441. }
  442. })
  443. gopool.Go(func() {
  444. defer func() {
  445. if r := recover(); r != nil {
  446. errChan <- fmt.Errorf("panic in target reader: %v", r)
  447. }
  448. }()
  449. for {
  450. select {
  451. case <-c.Done():
  452. return
  453. default:
  454. _, message, err := targetConn.ReadMessage()
  455. if err != nil {
  456. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  457. errChan <- fmt.Errorf("error reading from target: %v", err)
  458. }
  459. close(targetClosed)
  460. return
  461. }
  462. info.SetFirstResponseTime()
  463. realtimeEvent := &dto.RealtimeEvent{}
  464. err = json.Unmarshal(message, realtimeEvent)
  465. if err != nil {
  466. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  467. return
  468. }
  469. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  470. realtimeUsage := realtimeEvent.Response.Usage
  471. if realtimeUsage != nil {
  472. usage.TotalTokens += realtimeUsage.TotalTokens
  473. usage.InputTokens += realtimeUsage.InputTokens
  474. usage.OutputTokens += realtimeUsage.OutputTokens
  475. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  476. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  477. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  478. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  479. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  480. err := preConsumeUsage(c, info, usage, sumUsage)
  481. if err != nil {
  482. errChan <- fmt.Errorf("error consume usage: %v", err)
  483. return
  484. }
  485. // 本次计费完成,清除
  486. usage = &dto.RealtimeUsage{}
  487. localUsage = &dto.RealtimeUsage{}
  488. } else {
  489. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  490. if err != nil {
  491. errChan <- fmt.Errorf("error counting text token: %v", err)
  492. return
  493. }
  494. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  495. localUsage.TotalTokens += textToken + audioToken
  496. info.IsFirstRequest = false
  497. localUsage.InputTokens += textToken + audioToken
  498. localUsage.InputTokenDetails.TextTokens += textToken
  499. localUsage.InputTokenDetails.AudioTokens += audioToken
  500. err = preConsumeUsage(c, info, localUsage, sumUsage)
  501. if err != nil {
  502. errChan <- fmt.Errorf("error consume usage: %v", err)
  503. return
  504. }
  505. // 本次计费完成,清除
  506. localUsage = &dto.RealtimeUsage{}
  507. // print now usage
  508. }
  509. //common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  510. //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  511. //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  512. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  513. realtimeSession := realtimeEvent.Session
  514. if realtimeSession != nil {
  515. // update audio format
  516. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  517. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  518. }
  519. } else {
  520. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  521. if err != nil {
  522. errChan <- fmt.Errorf("error counting text token: %v", err)
  523. return
  524. }
  525. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  526. localUsage.TotalTokens += textToken + audioToken
  527. localUsage.OutputTokens += textToken + audioToken
  528. localUsage.OutputTokenDetails.TextTokens += textToken
  529. localUsage.OutputTokenDetails.AudioTokens += audioToken
  530. }
  531. err = service.WssString(c, clientConn, string(message))
  532. if err != nil {
  533. errChan <- fmt.Errorf("error writing to client: %v", err)
  534. return
  535. }
  536. select {
  537. case receiveChan <- message:
  538. default:
  539. }
  540. }
  541. }
  542. })
  543. select {
  544. case <-clientClosed:
  545. case <-targetClosed:
  546. case err := <-errChan:
  547. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  548. common.LogError(c, "realtime error: "+err.Error())
  549. case <-c.Done():
  550. }
  551. if usage.TotalTokens != 0 {
  552. _ = preConsumeUsage(c, info, usage, sumUsage)
  553. }
  554. if localUsage.TotalTokens != 0 {
  555. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  556. }
  557. // check usage total tokens, if 0, use local usage
  558. return nil, sumUsage
  559. }
  560. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  561. if usage == nil || totalUsage == nil {
  562. return fmt.Errorf("invalid usage pointer")
  563. }
  564. totalUsage.TotalTokens += usage.TotalTokens
  565. totalUsage.InputTokens += usage.InputTokens
  566. totalUsage.OutputTokens += usage.OutputTokens
  567. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  568. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  569. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  570. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  571. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  572. // clear usage
  573. err := service.PreWssConsumeQuota(ctx, info, usage)
  574. return err
  575. }