// relay-openai.go
  1. package openai
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "math"
  7. "mime/multipart"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. "one-api/logger"
  13. relaycommon "one-api/relay/common"
  14. "one-api/relay/helper"
  15. "one-api/service"
  16. "os"
  17. "path/filepath"
  18. "strings"
  19. "one-api/types"
  20. "github.com/bytedance/gopkg/util/gopool"
  21. "github.com/gin-gonic/gin"
  22. "github.com/gorilla/websocket"
  23. "github.com/pkg/errors"
  24. )
  25. func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  26. if data == "" {
  27. return nil
  28. }
  29. if !forceFormat && !thinkToContent {
  30. return helper.StringData(c, data)
  31. }
  32. var lastStreamResponse dto.ChatCompletionsStreamResponse
  33. if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
  34. return err
  35. }
  36. if !thinkToContent {
  37. return helper.ObjectData(c, lastStreamResponse)
  38. }
  39. hasThinkingContent := false
  40. hasContent := false
  41. var thinkingContent strings.Builder
  42. for _, choice := range lastStreamResponse.Choices {
  43. if len(choice.Delta.GetReasoningContent()) > 0 {
  44. hasThinkingContent = true
  45. thinkingContent.WriteString(choice.Delta.GetReasoningContent())
  46. }
  47. if len(choice.Delta.GetContentString()) > 0 {
  48. hasContent = true
  49. }
  50. }
  51. // Handle think to content conversion
  52. if info.ThinkingContentInfo.IsFirstThinkingContent {
  53. if hasThinkingContent {
  54. response := lastStreamResponse.Copy()
  55. for i := range response.Choices {
  56. // send `think` tag with thinking content
  57. response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
  58. response.Choices[i].Delta.ReasoningContent = nil
  59. response.Choices[i].Delta.Reasoning = nil
  60. }
  61. info.ThinkingContentInfo.IsFirstThinkingContent = false
  62. info.ThinkingContentInfo.HasSentThinkingContent = true
  63. return helper.ObjectData(c, response)
  64. }
  65. }
  66. if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
  67. return helper.ObjectData(c, lastStreamResponse)
  68. }
  69. // Process each choice
  70. for i, choice := range lastStreamResponse.Choices {
  71. // Handle transition from thinking to content
  72. // only send `</think>` tag when previous thinking content has been sent
  73. if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
  74. response := lastStreamResponse.Copy()
  75. for j := range response.Choices {
  76. response.Choices[j].Delta.SetContentString("\n</think>\n")
  77. response.Choices[j].Delta.ReasoningContent = nil
  78. response.Choices[j].Delta.Reasoning = nil
  79. }
  80. info.ThinkingContentInfo.SendLastThinkingContent = true
  81. helper.ObjectData(c, response)
  82. }
  83. // Convert reasoning content to regular content if any
  84. if len(choice.Delta.GetReasoningContent()) > 0 {
  85. lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
  86. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  87. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  88. } else if !hasThinkingContent && !hasContent {
  89. // flush thinking content
  90. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  91. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  92. }
  93. }
  94. return helper.ObjectData(c, lastStreamResponse)
  95. }
// OaiStreamHandler relays an OpenAI-compatible SSE stream to the client and
// accumulates usage for billing.
//
// Chunks are forwarded lagged by one: each chunk is sent only when its
// successor arrives, so the final chunk can be parsed for usage/metadata (and
// possibly suppressed) before being flushed via HandleFinalResponse.
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	if resp == nil || resp.Body == nil {
		logger.LogError(c, "invalid response or response body")
		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
	}
	defer service.CloseResponseBodyGracefully(resp)
	model := info.UpstreamModelName
	var responseId string
	var createAt int64 = 0
	var systemFingerprint string
	var containStreamUsage bool
	var responseTextBuilder strings.Builder
	var toolCount int
	var usage = &dto.Usage{}
	var streamItems []string // store stream items
	// lastStreamData holds the most recent chunk, not yet forwarded.
	var lastStreamData string
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		// Forward the previous chunk now that a newer one has arrived.
		if lastStreamData != "" {
			err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
			if err != nil {
				common.SysLog("error handling stream format: " + err.Error())
			}
		}
		if len(data) > 0 {
			lastStreamData = data
			streamItems = append(streamItems, data)
		}
		return true
	})
	// Handle the final chunk: extract usage/id/fingerprint from it and decide
	// whether it still needs to be sent to the client.
	shouldSendLastResp := true
	if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
		&containStreamUsage, info, &shouldSendLastResp); err != nil {
		logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
	}
	if info.RelayFormat == types.RelayFormatOpenAI {
		if shouldSendLastResp {
			// Best-effort: a send failure here must not abort billing below.
			_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
		}
	}
	// Token accounting: reconstruct completion text / tool calls from the
	// buffered chunks.
	if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
		logger.LogError(c, "error processing tokens: "+err.Error())
	}
	if !containStreamUsage {
		// Upstream reported no usage: estimate from the reconstructed text,
		// plus a flat 7-token surcharge per tool call.
		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
		usage.CompletionTokens += toolCount * 7
	} else {
		if info.ChannelType == constant.ChannelTypeDeepSeek {
			// DeepSeek reports cache hits in its own field; map it onto the
			// standard cached-tokens detail.
			if usage.PromptCacheHitTokens != 0 {
				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
			}
		}
	}
	HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
	return usage, nil
}
  153. func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  154. defer service.CloseResponseBodyGracefully(resp)
  155. var simpleResponse dto.OpenAITextResponse
  156. responseBody, err := io.ReadAll(resp.Body)
  157. if err != nil {
  158. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  159. }
  160. if common.DebugEnabled {
  161. println("upstream response body:", string(responseBody))
  162. }
  163. err = common.Unmarshal(responseBody, &simpleResponse)
  164. if err != nil {
  165. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  166. }
  167. if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
  168. return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
  169. }
  170. forceFormat := false
  171. if info.ChannelSetting.ForceFormat {
  172. forceFormat = true
  173. }
  174. usageModified := false
  175. if simpleResponse.Usage.PromptTokens == 0 {
  176. completionTokens := simpleResponse.Usage.CompletionTokens
  177. if completionTokens == 0 {
  178. for _, choice := range simpleResponse.Choices {
  179. ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
  180. completionTokens += ctkm
  181. }
  182. }
  183. simpleResponse.Usage = dto.Usage{
  184. PromptTokens: info.PromptTokens,
  185. CompletionTokens: completionTokens,
  186. TotalTokens: info.PromptTokens + completionTokens,
  187. }
  188. usageModified = true
  189. }
  190. switch info.RelayFormat {
  191. case types.RelayFormatOpenAI:
  192. if usageModified {
  193. var bodyMap map[string]interface{}
  194. err = common.Unmarshal(responseBody, &bodyMap)
  195. if err != nil {
  196. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  197. }
  198. bodyMap["usage"] = simpleResponse.Usage
  199. responseBody, _ = common.Marshal(bodyMap)
  200. }
  201. if forceFormat {
  202. responseBody, err = common.Marshal(simpleResponse)
  203. if err != nil {
  204. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  205. }
  206. } else {
  207. break
  208. }
  209. case types.RelayFormatClaude:
  210. claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
  211. claudeRespStr, err := common.Marshal(claudeResp)
  212. if err != nil {
  213. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  214. }
  215. responseBody = claudeRespStr
  216. case types.RelayFormatGemini:
  217. geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
  218. geminiRespStr, err := common.Marshal(geminiResp)
  219. if err != nil {
  220. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  221. }
  222. responseBody = geminiRespStr
  223. }
  224. service.IOCopyBytesGracefully(c, resp, responseBody)
  225. return &simpleResponse.Usage, nil
  226. }
  227. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
  228. // the status code has been judged before, if there is a body reading failure,
  229. // it should be regarded as a non-recoverable error, so it should not return err for external retry.
  230. // Analogous to nginx's load balancing, it will only retry if it can't be requested or
  231. // if the upstream returns a specific status code, once the upstream has already written the header,
  232. // the subsequent failure of the response body should be regarded as a non-recoverable error,
  233. // and can be terminated directly.
  234. defer service.CloseResponseBodyGracefully(resp)
  235. usage := &dto.Usage{}
  236. usage.PromptTokens = info.PromptTokens
  237. usage.TotalTokens = info.PromptTokens
  238. for k, v := range resp.Header {
  239. c.Writer.Header().Set(k, v[0])
  240. }
  241. c.Writer.WriteHeader(resp.StatusCode)
  242. c.Writer.WriteHeaderNow()
  243. _, err := io.Copy(c.Writer, resp.Body)
  244. if err != nil {
  245. logger.LogError(c, err.Error())
  246. }
  247. return usage
  248. }
  249. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
  250. defer service.CloseResponseBodyGracefully(resp)
  251. // count tokens by audio file duration
  252. audioTokens, err := countAudioTokens(c)
  253. if err != nil {
  254. return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
  255. }
  256. responseBody, err := io.ReadAll(resp.Body)
  257. if err != nil {
  258. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  259. }
  260. // 写入新的 response body
  261. service.IOCopyBytesGracefully(c, resp, responseBody)
  262. usage := &dto.Usage{}
  263. usage.PromptTokens = audioTokens
  264. usage.CompletionTokens = 0
  265. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  266. return nil, usage
  267. }
  268. func countAudioTokens(c *gin.Context) (int, error) {
  269. body, err := common.GetRequestBody(c)
  270. if err != nil {
  271. return 0, errors.WithStack(err)
  272. }
  273. var reqBody struct {
  274. File *multipart.FileHeader `form:"file" binding:"required"`
  275. }
  276. c.Request.Body = io.NopCloser(bytes.NewReader(body))
  277. if err = c.ShouldBind(&reqBody); err != nil {
  278. return 0, errors.WithStack(err)
  279. }
  280. ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
  281. reqFp, err := reqBody.File.Open()
  282. if err != nil {
  283. return 0, errors.WithStack(err)
  284. }
  285. defer reqFp.Close()
  286. tmpFp, err := os.CreateTemp("", "audio-*"+ext)
  287. if err != nil {
  288. return 0, errors.WithStack(err)
  289. }
  290. defer os.Remove(tmpFp.Name())
  291. _, err = io.Copy(tmpFp, reqFp)
  292. if err != nil {
  293. return 0, errors.WithStack(err)
  294. }
  295. if err = tmpFp.Close(); err != nil {
  296. return 0, errors.WithStack(err)
  297. }
  298. duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
  299. if err != nil {
  300. return 0, errors.WithStack(err)
  301. }
  302. return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
  303. }
  304. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
  305. if info == nil || info.ClientWs == nil || info.TargetWs == nil {
  306. return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
  307. }
  308. info.IsStream = true
  309. clientConn := info.ClientWs
  310. targetConn := info.TargetWs
  311. clientClosed := make(chan struct{})
  312. targetClosed := make(chan struct{})
  313. sendChan := make(chan []byte, 100)
  314. receiveChan := make(chan []byte, 100)
  315. errChan := make(chan error, 2)
  316. usage := &dto.RealtimeUsage{}
  317. localUsage := &dto.RealtimeUsage{}
  318. sumUsage := &dto.RealtimeUsage{}
  319. gopool.Go(func() {
  320. defer func() {
  321. if r := recover(); r != nil {
  322. errChan <- fmt.Errorf("panic in client reader: %v", r)
  323. }
  324. }()
  325. for {
  326. select {
  327. case <-c.Done():
  328. return
  329. default:
  330. _, message, err := clientConn.ReadMessage()
  331. if err != nil {
  332. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  333. errChan <- fmt.Errorf("error reading from client: %v", err)
  334. }
  335. close(clientClosed)
  336. return
  337. }
  338. realtimeEvent := &dto.RealtimeEvent{}
  339. err = common.Unmarshal(message, realtimeEvent)
  340. if err != nil {
  341. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  342. return
  343. }
  344. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  345. if realtimeEvent.Session != nil {
  346. if realtimeEvent.Session.Tools != nil {
  347. info.RealtimeTools = realtimeEvent.Session.Tools
  348. }
  349. }
  350. }
  351. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  352. if err != nil {
  353. errChan <- fmt.Errorf("error counting text token: %v", err)
  354. return
  355. }
  356. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  357. localUsage.TotalTokens += textToken + audioToken
  358. localUsage.InputTokens += textToken + audioToken
  359. localUsage.InputTokenDetails.TextTokens += textToken
  360. localUsage.InputTokenDetails.AudioTokens += audioToken
  361. err = helper.WssString(c, targetConn, string(message))
  362. if err != nil {
  363. errChan <- fmt.Errorf("error writing to target: %v", err)
  364. return
  365. }
  366. select {
  367. case sendChan <- message:
  368. default:
  369. }
  370. }
  371. }
  372. })
  373. gopool.Go(func() {
  374. defer func() {
  375. if r := recover(); r != nil {
  376. errChan <- fmt.Errorf("panic in target reader: %v", r)
  377. }
  378. }()
  379. for {
  380. select {
  381. case <-c.Done():
  382. return
  383. default:
  384. _, message, err := targetConn.ReadMessage()
  385. if err != nil {
  386. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  387. errChan <- fmt.Errorf("error reading from target: %v", err)
  388. }
  389. close(targetClosed)
  390. return
  391. }
  392. info.SetFirstResponseTime()
  393. realtimeEvent := &dto.RealtimeEvent{}
  394. err = common.Unmarshal(message, realtimeEvent)
  395. if err != nil {
  396. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  397. return
  398. }
  399. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  400. realtimeUsage := realtimeEvent.Response.Usage
  401. if realtimeUsage != nil {
  402. usage.TotalTokens += realtimeUsage.TotalTokens
  403. usage.InputTokens += realtimeUsage.InputTokens
  404. usage.OutputTokens += realtimeUsage.OutputTokens
  405. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  406. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  407. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  408. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  409. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  410. err := preConsumeUsage(c, info, usage, sumUsage)
  411. if err != nil {
  412. errChan <- fmt.Errorf("error consume usage: %v", err)
  413. return
  414. }
  415. // 本次计费完成,清除
  416. usage = &dto.RealtimeUsage{}
  417. localUsage = &dto.RealtimeUsage{}
  418. } else {
  419. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  420. if err != nil {
  421. errChan <- fmt.Errorf("error counting text token: %v", err)
  422. return
  423. }
  424. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  425. localUsage.TotalTokens += textToken + audioToken
  426. info.IsFirstRequest = false
  427. localUsage.InputTokens += textToken + audioToken
  428. localUsage.InputTokenDetails.TextTokens += textToken
  429. localUsage.InputTokenDetails.AudioTokens += audioToken
  430. err = preConsumeUsage(c, info, localUsage, sumUsage)
  431. if err != nil {
  432. errChan <- fmt.Errorf("error consume usage: %v", err)
  433. return
  434. }
  435. // 本次计费完成,清除
  436. localUsage = &dto.RealtimeUsage{}
  437. // print now usage
  438. }
  439. logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  440. logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  441. logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  442. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  443. realtimeSession := realtimeEvent.Session
  444. if realtimeSession != nil {
  445. // update audio format
  446. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  447. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  448. }
  449. } else {
  450. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  451. if err != nil {
  452. errChan <- fmt.Errorf("error counting text token: %v", err)
  453. return
  454. }
  455. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  456. localUsage.TotalTokens += textToken + audioToken
  457. localUsage.OutputTokens += textToken + audioToken
  458. localUsage.OutputTokenDetails.TextTokens += textToken
  459. localUsage.OutputTokenDetails.AudioTokens += audioToken
  460. }
  461. err = helper.WssString(c, clientConn, string(message))
  462. if err != nil {
  463. errChan <- fmt.Errorf("error writing to client: %v", err)
  464. return
  465. }
  466. select {
  467. case receiveChan <- message:
  468. default:
  469. }
  470. }
  471. }
  472. })
  473. select {
  474. case <-clientClosed:
  475. case <-targetClosed:
  476. case err := <-errChan:
  477. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  478. logger.LogError(c, "realtime error: "+err.Error())
  479. case <-c.Done():
  480. }
  481. if usage.TotalTokens != 0 {
  482. _ = preConsumeUsage(c, info, usage, sumUsage)
  483. }
  484. if localUsage.TotalTokens != 0 {
  485. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  486. }
  487. // check usage total tokens, if 0, use local usage
  488. return nil, sumUsage
  489. }
  490. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  491. if usage == nil || totalUsage == nil {
  492. return fmt.Errorf("invalid usage pointer")
  493. }
  494. totalUsage.TotalTokens += usage.TotalTokens
  495. totalUsage.InputTokens += usage.InputTokens
  496. totalUsage.OutputTokens += usage.OutputTokens
  497. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  498. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  499. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  500. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  501. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  502. // clear usage
  503. err := service.PreWssConsumeQuota(ctx, info, usage)
  504. return err
  505. }
  506. func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  507. defer service.CloseResponseBodyGracefully(resp)
  508. responseBody, err := io.ReadAll(resp.Body)
  509. if err != nil {
  510. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  511. }
  512. var usageResp dto.SimpleResponse
  513. err = common.Unmarshal(responseBody, &usageResp)
  514. if err != nil {
  515. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  516. }
  517. // 写入新的 response body
  518. service.IOCopyBytesGracefully(c, resp, responseBody)
  519. // Once we've written to the client, we should not return errors anymore
  520. // because the upstream has already consumed resources and returned content
  521. // We should still perform billing even if parsing fails
  522. // format
  523. if usageResp.InputTokens > 0 {
  524. usageResp.PromptTokens += usageResp.InputTokens
  525. }
  526. if usageResp.OutputTokens > 0 {
  527. usageResp.CompletionTokens += usageResp.OutputTokens
  528. }
  529. if usageResp.InputTokensDetails != nil {
  530. usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
  531. usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
  532. }
  533. return &usageResp.Usage, nil
  534. }