relay-openai.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. package openai
  import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"mime/multipart"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/bytedance/gopkg/util/gopool"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/pkg/errors"
)
  24. func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  25. if data == "" {
  26. return nil
  27. }
  28. if !forceFormat && !thinkToContent {
  29. return helper.StringData(c, data)
  30. }
  31. var lastStreamResponse dto.ChatCompletionsStreamResponse
  32. if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
  33. return err
  34. }
  35. if !thinkToContent {
  36. return helper.ObjectData(c, lastStreamResponse)
  37. }
  38. hasThinkingContent := false
  39. hasContent := false
  40. var thinkingContent strings.Builder
  41. for _, choice := range lastStreamResponse.Choices {
  42. if len(choice.Delta.GetReasoningContent()) > 0 {
  43. hasThinkingContent = true
  44. thinkingContent.WriteString(choice.Delta.GetReasoningContent())
  45. }
  46. if len(choice.Delta.GetContentString()) > 0 {
  47. hasContent = true
  48. }
  49. }
  50. // Handle think to content conversion
  51. if info.ThinkingContentInfo.IsFirstThinkingContent {
  52. if hasThinkingContent {
  53. response := lastStreamResponse.Copy()
  54. for i := range response.Choices {
  55. // send `think` tag with thinking content
  56. response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
  57. response.Choices[i].Delta.ReasoningContent = nil
  58. response.Choices[i].Delta.Reasoning = nil
  59. }
  60. info.ThinkingContentInfo.IsFirstThinkingContent = false
  61. info.ThinkingContentInfo.HasSentThinkingContent = true
  62. return helper.ObjectData(c, response)
  63. }
  64. }
  65. if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
  66. return helper.ObjectData(c, lastStreamResponse)
  67. }
  68. // Process each choice
  69. for i, choice := range lastStreamResponse.Choices {
  70. // Handle transition from thinking to content
  71. // only send `</think>` tag when previous thinking content has been sent
  72. if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
  73. response := lastStreamResponse.Copy()
  74. for j := range response.Choices {
  75. response.Choices[j].Delta.SetContentString("\n</think>\n")
  76. response.Choices[j].Delta.ReasoningContent = nil
  77. response.Choices[j].Delta.Reasoning = nil
  78. }
  79. info.ThinkingContentInfo.SendLastThinkingContent = true
  80. helper.ObjectData(c, response)
  81. }
  82. // Convert reasoning content to regular content if any
  83. if len(choice.Delta.GetReasoningContent()) > 0 {
  84. lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
  85. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  86. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  87. } else if !hasThinkingContent && !hasContent {
  88. // flush thinking content
  89. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  90. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  91. }
  92. }
  93. return helper.ObjectData(c, lastStreamResponse)
  94. }
  95. func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  96. if resp == nil || resp.Body == nil {
  97. common.LogError(c, "invalid response or response body")
  98. return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
  99. }
  100. containStreamUsage := false
  101. var responseId string
  102. var createAt int64 = 0
  103. var systemFingerprint string
  104. model := info.UpstreamModelName
  105. var responseTextBuilder strings.Builder
  106. var toolCount int
  107. var usage = &dto.Usage{}
  108. var streamItems []string // store stream items
  109. var forceFormat bool
  110. var thinkToContent bool
  111. if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
  112. forceFormat = forceFmt
  113. }
  114. if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
  115. thinkToContent = think2Content
  116. }
  117. var (
  118. lastStreamData string
  119. )
  120. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  121. if lastStreamData != "" {
  122. err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
  123. if err != nil {
  124. common.SysError("error handling stream format: " + err.Error())
  125. }
  126. }
  127. lastStreamData = data
  128. streamItems = append(streamItems, data)
  129. return true
  130. })
  131. shouldSendLastResp := true
  132. var lastStreamResponse dto.ChatCompletionsStreamResponse
  133. err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
  134. if err == nil {
  135. responseId = lastStreamResponse.Id
  136. createAt = lastStreamResponse.Created
  137. systemFingerprint = lastStreamResponse.GetSystemFingerprint()
  138. model = lastStreamResponse.Model
  139. if service.ValidUsage(lastStreamResponse.Usage) {
  140. containStreamUsage = true
  141. usage = lastStreamResponse.Usage
  142. if !info.ShouldIncludeUsage {
  143. shouldSendLastResp = false
  144. }
  145. }
  146. for _, choice := range lastStreamResponse.Choices {
  147. if choice.FinishReason != nil {
  148. shouldSendLastResp = true
  149. }
  150. }
  151. }
  152. if shouldSendLastResp {
  153. sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
  154. //err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
  155. }
  156. // 处理token计算
  157. if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
  158. common.SysError("error processing tokens: " + err.Error())
  159. }
  160. if !containStreamUsage {
  161. usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  162. usage.CompletionTokens += toolCount * 7
  163. } else {
  164. if info.ChannelType == common.ChannelTypeDeepSeek {
  165. if usage.PromptCacheHitTokens != 0 {
  166. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  167. }
  168. }
  169. }
  170. handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
  171. return nil, usage
  172. }
  173. func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  174. var simpleResponse dto.OpenAITextResponse
  175. responseBody, err := io.ReadAll(resp.Body)
  176. if err != nil {
  177. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  178. }
  179. err = resp.Body.Close()
  180. if err != nil {
  181. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  182. }
  183. err = common.DecodeJson(responseBody, &simpleResponse)
  184. if err != nil {
  185. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  186. }
  187. if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
  188. return &dto.OpenAIErrorWithStatusCode{
  189. Error: *simpleResponse.Error,
  190. StatusCode: resp.StatusCode,
  191. }, nil
  192. }
  193. forceFormat := false
  194. if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
  195. forceFormat = forceFmt
  196. }
  197. if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
  198. completionTokens := 0
  199. for _, choice := range simpleResponse.Choices {
  200. ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
  201. completionTokens += ctkm
  202. }
  203. simpleResponse.Usage = dto.Usage{
  204. PromptTokens: info.PromptTokens,
  205. CompletionTokens: completionTokens,
  206. TotalTokens: info.PromptTokens + completionTokens,
  207. }
  208. }
  209. switch info.RelayFormat {
  210. case relaycommon.RelayFormatOpenAI:
  211. if forceFormat {
  212. responseBody, err = json.Marshal(simpleResponse)
  213. if err != nil {
  214. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  215. }
  216. } else {
  217. break
  218. }
  219. case relaycommon.RelayFormatClaude:
  220. claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
  221. claudeRespStr, err := json.Marshal(claudeResp)
  222. if err != nil {
  223. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  224. }
  225. responseBody = claudeRespStr
  226. }
  227. // Reset response body
  228. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  229. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  230. // And then we will have to send an error response, but in this case, the header has already been set.
  231. // So the httpClient will be confused by the response.
  232. // For example, Postman will report error, and we cannot check the response at all.
  233. for k, v := range resp.Header {
  234. c.Writer.Header().Set(k, v[0])
  235. }
  236. c.Writer.WriteHeader(resp.StatusCode)
  237. _, err = io.Copy(c.Writer, resp.Body)
  238. if err != nil {
  239. //return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  240. common.SysError("error copying response body: " + err.Error())
  241. }
  242. resp.Body.Close()
  243. return nil, &simpleResponse.Usage
  244. }
  245. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  246. // the status code has been judged before, if there is a body reading failure,
  247. // it should be regarded as a non-recoverable error, so it should not return err for external retry.
  248. // Analogous to nginx's load balancing, it will only retry if it can't be requested or
  249. // if the upstream returns a specific status code, once the upstream has already written the header,
  250. // the subsequent failure of the response body should be regarded as a non-recoverable error,
  251. // and can be terminated directly.
  252. defer resp.Body.Close()
  253. usage := &dto.Usage{}
  254. usage.PromptTokens = info.PromptTokens
  255. usage.TotalTokens = info.PromptTokens
  256. for k, v := range resp.Header {
  257. c.Writer.Header().Set(k, v[0])
  258. }
  259. c.Writer.WriteHeader(resp.StatusCode)
  260. c.Writer.WriteHeaderNow()
  261. _, err := io.Copy(c.Writer, resp.Body)
  262. if err != nil {
  263. common.LogError(c, err.Error())
  264. }
  265. return nil, usage
  266. }
  267. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  268. // count tokens by audio file duration
  269. audioTokens, err := countAudioTokens(c)
  270. if err != nil {
  271. return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
  272. }
  273. responseBody, err := io.ReadAll(resp.Body)
  274. if err != nil {
  275. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  276. }
  277. err = resp.Body.Close()
  278. if err != nil {
  279. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  280. }
  281. // Reset response body
  282. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  283. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  284. // And then we will have to send an error response, but in this case, the header has already been set.
  285. // So the httpClient will be confused by the response.
  286. // For example, Postman will report error, and we cannot check the response at all.
  287. for k, v := range resp.Header {
  288. c.Writer.Header().Set(k, v[0])
  289. }
  290. c.Writer.WriteHeader(resp.StatusCode)
  291. _, err = io.Copy(c.Writer, resp.Body)
  292. if err != nil {
  293. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  294. }
  295. resp.Body.Close()
  296. usage := &dto.Usage{}
  297. usage.PromptTokens = audioTokens
  298. usage.CompletionTokens = 0
  299. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  300. return nil, usage
  301. }
  302. func countAudioTokens(c *gin.Context) (int, error) {
  303. body, err := common.GetRequestBody(c)
  304. if err != nil {
  305. return 0, errors.WithStack(err)
  306. }
  307. var reqBody struct {
  308. File *multipart.FileHeader `form:"file" binding:"required"`
  309. }
  310. c.Request.Body = io.NopCloser(bytes.NewReader(body))
  311. if err = c.ShouldBind(&reqBody); err != nil {
  312. return 0, errors.WithStack(err)
  313. }
  314. ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
  315. reqFp, err := reqBody.File.Open()
  316. if err != nil {
  317. return 0, errors.WithStack(err)
  318. }
  319. defer reqFp.Close()
  320. tmpFp, err := os.CreateTemp("", "audio-*"+ext)
  321. if err != nil {
  322. return 0, errors.WithStack(err)
  323. }
  324. defer os.Remove(tmpFp.Name())
  325. _, err = io.Copy(tmpFp, reqFp)
  326. if err != nil {
  327. return 0, errors.WithStack(err)
  328. }
  329. if err = tmpFp.Close(); err != nil {
  330. return 0, errors.WithStack(err)
  331. }
  332. duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
  333. if err != nil {
  334. return 0, errors.WithStack(err)
  335. }
  336. return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
  337. }
  338. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
  339. if info == nil || info.ClientWs == nil || info.TargetWs == nil {
  340. return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
  341. }
  342. info.IsStream = true
  343. clientConn := info.ClientWs
  344. targetConn := info.TargetWs
  345. clientClosed := make(chan struct{})
  346. targetClosed := make(chan struct{})
  347. sendChan := make(chan []byte, 100)
  348. receiveChan := make(chan []byte, 100)
  349. errChan := make(chan error, 2)
  350. usage := &dto.RealtimeUsage{}
  351. localUsage := &dto.RealtimeUsage{}
  352. sumUsage := &dto.RealtimeUsage{}
  353. gopool.Go(func() {
  354. defer func() {
  355. if r := recover(); r != nil {
  356. errChan <- fmt.Errorf("panic in client reader: %v", r)
  357. }
  358. }()
  359. for {
  360. select {
  361. case <-c.Done():
  362. return
  363. default:
  364. _, message, err := clientConn.ReadMessage()
  365. if err != nil {
  366. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  367. errChan <- fmt.Errorf("error reading from client: %v", err)
  368. }
  369. close(clientClosed)
  370. return
  371. }
  372. realtimeEvent := &dto.RealtimeEvent{}
  373. err = json.Unmarshal(message, realtimeEvent)
  374. if err != nil {
  375. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  376. return
  377. }
  378. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  379. if realtimeEvent.Session != nil {
  380. if realtimeEvent.Session.Tools != nil {
  381. info.RealtimeTools = realtimeEvent.Session.Tools
  382. }
  383. }
  384. }
  385. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  386. if err != nil {
  387. errChan <- fmt.Errorf("error counting text token: %v", err)
  388. return
  389. }
  390. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  391. localUsage.TotalTokens += textToken + audioToken
  392. localUsage.InputTokens += textToken + audioToken
  393. localUsage.InputTokenDetails.TextTokens += textToken
  394. localUsage.InputTokenDetails.AudioTokens += audioToken
  395. err = helper.WssString(c, targetConn, string(message))
  396. if err != nil {
  397. errChan <- fmt.Errorf("error writing to target: %v", err)
  398. return
  399. }
  400. select {
  401. case sendChan <- message:
  402. default:
  403. }
  404. }
  405. }
  406. })
  407. gopool.Go(func() {
  408. defer func() {
  409. if r := recover(); r != nil {
  410. errChan <- fmt.Errorf("panic in target reader: %v", r)
  411. }
  412. }()
  413. for {
  414. select {
  415. case <-c.Done():
  416. return
  417. default:
  418. _, message, err := targetConn.ReadMessage()
  419. if err != nil {
  420. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  421. errChan <- fmt.Errorf("error reading from target: %v", err)
  422. }
  423. close(targetClosed)
  424. return
  425. }
  426. info.SetFirstResponseTime()
  427. realtimeEvent := &dto.RealtimeEvent{}
  428. err = json.Unmarshal(message, realtimeEvent)
  429. if err != nil {
  430. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  431. return
  432. }
  433. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  434. realtimeUsage := realtimeEvent.Response.Usage
  435. if realtimeUsage != nil {
  436. usage.TotalTokens += realtimeUsage.TotalTokens
  437. usage.InputTokens += realtimeUsage.InputTokens
  438. usage.OutputTokens += realtimeUsage.OutputTokens
  439. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  440. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  441. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  442. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  443. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  444. err := preConsumeUsage(c, info, usage, sumUsage)
  445. if err != nil {
  446. errChan <- fmt.Errorf("error consume usage: %v", err)
  447. return
  448. }
  449. // 本次计费完成,清除
  450. usage = &dto.RealtimeUsage{}
  451. localUsage = &dto.RealtimeUsage{}
  452. } else {
  453. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  454. if err != nil {
  455. errChan <- fmt.Errorf("error counting text token: %v", err)
  456. return
  457. }
  458. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  459. localUsage.TotalTokens += textToken + audioToken
  460. info.IsFirstRequest = false
  461. localUsage.InputTokens += textToken + audioToken
  462. localUsage.InputTokenDetails.TextTokens += textToken
  463. localUsage.InputTokenDetails.AudioTokens += audioToken
  464. err = preConsumeUsage(c, info, localUsage, sumUsage)
  465. if err != nil {
  466. errChan <- fmt.Errorf("error consume usage: %v", err)
  467. return
  468. }
  469. // 本次计费完成,清除
  470. localUsage = &dto.RealtimeUsage{}
  471. // print now usage
  472. }
  473. //common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  474. //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  475. //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  476. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  477. realtimeSession := realtimeEvent.Session
  478. if realtimeSession != nil {
  479. // update audio format
  480. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  481. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  482. }
  483. } else {
  484. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  485. if err != nil {
  486. errChan <- fmt.Errorf("error counting text token: %v", err)
  487. return
  488. }
  489. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  490. localUsage.TotalTokens += textToken + audioToken
  491. localUsage.OutputTokens += textToken + audioToken
  492. localUsage.OutputTokenDetails.TextTokens += textToken
  493. localUsage.OutputTokenDetails.AudioTokens += audioToken
  494. }
  495. err = helper.WssString(c, clientConn, string(message))
  496. if err != nil {
  497. errChan <- fmt.Errorf("error writing to client: %v", err)
  498. return
  499. }
  500. select {
  501. case receiveChan <- message:
  502. default:
  503. }
  504. }
  505. }
  506. })
  507. select {
  508. case <-clientClosed:
  509. case <-targetClosed:
  510. case err := <-errChan:
  511. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  512. common.LogError(c, "realtime error: "+err.Error())
  513. case <-c.Done():
  514. }
  515. if usage.TotalTokens != 0 {
  516. _ = preConsumeUsage(c, info, usage, sumUsage)
  517. }
  518. if localUsage.TotalTokens != 0 {
  519. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  520. }
  521. // check usage total tokens, if 0, use local usage
  522. return nil, sumUsage
  523. }
  524. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  525. if usage == nil || totalUsage == nil {
  526. return fmt.Errorf("invalid usage pointer")
  527. }
  528. totalUsage.TotalTokens += usage.TotalTokens
  529. totalUsage.InputTokens += usage.InputTokens
  530. totalUsage.OutputTokens += usage.OutputTokens
  531. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  532. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  533. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  534. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  535. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  536. // clear usage
  537. err := service.PreWssConsumeQuota(ctx, info, usage)
  538. return err
  539. }
  540. func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  541. responseBody, err := io.ReadAll(resp.Body)
  542. if err != nil {
  543. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  544. }
  545. err = resp.Body.Close()
  546. if err != nil {
  547. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  548. }
  549. // Reset response body
  550. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  551. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  552. // And then we will have to send an error response, but in this case, the header has already been set.
  553. // So the httpClient will be confused by the response.
  554. // For example, Postman will report error, and we cannot check the response at all.
  555. for k, v := range resp.Header {
  556. c.Writer.Header().Set(k, v[0])
  557. }
  558. // reset content length
  559. c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
  560. c.Writer.WriteHeader(resp.StatusCode)
  561. _, err = io.Copy(c.Writer, resp.Body)
  562. if err != nil {
  563. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  564. }
  565. err = resp.Body.Close()
  566. if err != nil {
  567. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  568. }
  569. var usageResp dto.SimpleResponse
  570. err = json.Unmarshal(responseBody, &usageResp)
  571. if err != nil {
  572. return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
  573. }
  574. // format
  575. if usageResp.InputTokens > 0 {
  576. usageResp.PromptTokens += usageResp.InputTokens
  577. }
  578. if usageResp.OutputTokens > 0 {
  579. usageResp.CompletionTokens += usageResp.OutputTokens
  580. }
  581. if usageResp.InputTokensDetails != nil {
  582. usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
  583. usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
  584. }
  585. return nil, &usageResp.Usage
  586. }