// relay-openai.go (19 KB)
  1. package openai
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "math"
  7. "mime/multipart"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. relaycommon "one-api/relay/common"
  13. "one-api/relay/helper"
  14. "one-api/service"
  15. "os"
  16. "path/filepath"
  17. "strings"
  18. "one-api/types"
  19. "github.com/bytedance/gopkg/util/gopool"
  20. "github.com/gin-gonic/gin"
  21. "github.com/gorilla/websocket"
  22. "github.com/pkg/errors"
  23. )
  24. func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  25. if data == "" {
  26. return nil
  27. }
  28. if !forceFormat && !thinkToContent {
  29. return helper.StringData(c, data)
  30. }
  31. var lastStreamResponse dto.ChatCompletionsStreamResponse
  32. if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
  33. return err
  34. }
  35. if !thinkToContent {
  36. return helper.ObjectData(c, lastStreamResponse)
  37. }
  38. hasThinkingContent := false
  39. hasContent := false
  40. var thinkingContent strings.Builder
  41. for _, choice := range lastStreamResponse.Choices {
  42. if len(choice.Delta.GetReasoningContent()) > 0 {
  43. hasThinkingContent = true
  44. thinkingContent.WriteString(choice.Delta.GetReasoningContent())
  45. }
  46. if len(choice.Delta.GetContentString()) > 0 {
  47. hasContent = true
  48. }
  49. }
  50. // Handle think to content conversion
  51. if info.ThinkingContentInfo.IsFirstThinkingContent {
  52. if hasThinkingContent {
  53. response := lastStreamResponse.Copy()
  54. for i := range response.Choices {
  55. // send `think` tag with thinking content
  56. response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
  57. response.Choices[i].Delta.ReasoningContent = nil
  58. response.Choices[i].Delta.Reasoning = nil
  59. }
  60. info.ThinkingContentInfo.IsFirstThinkingContent = false
  61. info.ThinkingContentInfo.HasSentThinkingContent = true
  62. return helper.ObjectData(c, response)
  63. }
  64. }
  65. if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
  66. return helper.ObjectData(c, lastStreamResponse)
  67. }
  68. // Process each choice
  69. for i, choice := range lastStreamResponse.Choices {
  70. // Handle transition from thinking to content
  71. // only send `</think>` tag when previous thinking content has been sent
  72. if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
  73. response := lastStreamResponse.Copy()
  74. for j := range response.Choices {
  75. response.Choices[j].Delta.SetContentString("\n</think>\n")
  76. response.Choices[j].Delta.ReasoningContent = nil
  77. response.Choices[j].Delta.Reasoning = nil
  78. }
  79. info.ThinkingContentInfo.SendLastThinkingContent = true
  80. helper.ObjectData(c, response)
  81. }
  82. // Convert reasoning content to regular content if any
  83. if len(choice.Delta.GetReasoningContent()) > 0 {
  84. lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
  85. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  86. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  87. } else if !hasThinkingContent && !hasContent {
  88. // flush thinking content
  89. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  90. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  91. }
  92. }
  93. return helper.ObjectData(c, lastStreamResponse)
  94. }
  95. func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  96. if resp == nil || resp.Body == nil {
  97. common.LogError(c, "invalid response or response body")
  98. return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
  99. }
  100. defer common.CloseResponseBodyGracefully(resp)
  101. model := info.UpstreamModelName
  102. var responseId string
  103. var createAt int64 = 0
  104. var systemFingerprint string
  105. var containStreamUsage bool
  106. var responseTextBuilder strings.Builder
  107. var toolCount int
  108. var usage = &dto.Usage{}
  109. var streamItems []string // store stream items
  110. var lastStreamData string
  111. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  112. if lastStreamData != "" {
  113. err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  114. if err != nil {
  115. common.SysError("error handling stream format: " + err.Error())
  116. }
  117. }
  118. if len(data) > 0 {
  119. lastStreamData = data
  120. streamItems = append(streamItems, data)
  121. }
  122. return true
  123. })
  124. // 处理最后的响应
  125. shouldSendLastResp := true
  126. if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
  127. &containStreamUsage, info, &shouldSendLastResp); err != nil {
  128. common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
  129. }
  130. if info.RelayFormat == relaycommon.RelayFormatOpenAI {
  131. if shouldSendLastResp {
  132. _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  133. }
  134. }
  135. // 处理token计算
  136. if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
  137. common.LogError(c, "error processing tokens: "+err.Error())
  138. }
  139. if !containStreamUsage {
  140. usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  141. usage.CompletionTokens += toolCount * 7
  142. } else {
  143. if info.ChannelType == constant.ChannelTypeDeepSeek {
  144. if usage.PromptCacheHitTokens != 0 {
  145. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  146. }
  147. }
  148. }
  149. HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
  150. return usage, nil
  151. }
  152. func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  153. defer common.CloseResponseBodyGracefully(resp)
  154. var simpleResponse dto.OpenAITextResponse
  155. responseBody, err := io.ReadAll(resp.Body)
  156. if err != nil {
  157. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  158. }
  159. err = common.Unmarshal(responseBody, &simpleResponse)
  160. if err != nil {
  161. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  162. }
  163. if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
  164. return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
  165. }
  166. forceFormat := false
  167. if info.ChannelSetting.ForceFormat {
  168. forceFormat = true
  169. }
  170. if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
  171. completionTokens := 0
  172. for _, choice := range simpleResponse.Choices {
  173. ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
  174. completionTokens += ctkm
  175. }
  176. simpleResponse.Usage = dto.Usage{
  177. PromptTokens: info.PromptTokens,
  178. CompletionTokens: completionTokens,
  179. TotalTokens: info.PromptTokens + completionTokens,
  180. }
  181. }
  182. switch info.RelayFormat {
  183. case relaycommon.RelayFormatOpenAI:
  184. if forceFormat {
  185. responseBody, err = common.Marshal(simpleResponse)
  186. if err != nil {
  187. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  188. }
  189. } else {
  190. break
  191. }
  192. case relaycommon.RelayFormatClaude:
  193. claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
  194. claudeRespStr, err := common.Marshal(claudeResp)
  195. if err != nil {
  196. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  197. }
  198. responseBody = claudeRespStr
  199. case relaycommon.RelayFormatGemini:
  200. geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
  201. geminiRespStr, err := common.Marshal(geminiResp)
  202. if err != nil {
  203. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  204. }
  205. responseBody = geminiRespStr
  206. }
  207. common.IOCopyBytesGracefully(c, resp, responseBody)
  208. return &simpleResponse.Usage, nil
  209. }
  210. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
  211. // the status code has been judged before, if there is a body reading failure,
  212. // it should be regarded as a non-recoverable error, so it should not return err for external retry.
  213. // Analogous to nginx's load balancing, it will only retry if it can't be requested or
  214. // if the upstream returns a specific status code, once the upstream has already written the header,
  215. // the subsequent failure of the response body should be regarded as a non-recoverable error,
  216. // and can be terminated directly.
  217. defer common.CloseResponseBodyGracefully(resp)
  218. usage := &dto.Usage{}
  219. usage.PromptTokens = info.PromptTokens
  220. usage.TotalTokens = info.PromptTokens
  221. for k, v := range resp.Header {
  222. c.Writer.Header().Set(k, v[0])
  223. }
  224. c.Writer.WriteHeader(resp.StatusCode)
  225. c.Writer.WriteHeaderNow()
  226. _, err := io.Copy(c.Writer, resp.Body)
  227. if err != nil {
  228. common.LogError(c, err.Error())
  229. }
  230. return usage
  231. }
  232. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
  233. defer common.CloseResponseBodyGracefully(resp)
  234. // count tokens by audio file duration
  235. audioTokens, err := countAudioTokens(c)
  236. if err != nil {
  237. return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
  238. }
  239. responseBody, err := io.ReadAll(resp.Body)
  240. if err != nil {
  241. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  242. }
  243. // 写入新的 response body
  244. common.IOCopyBytesGracefully(c, resp, responseBody)
  245. usage := &dto.Usage{}
  246. usage.PromptTokens = audioTokens
  247. usage.CompletionTokens = 0
  248. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  249. return nil, usage
  250. }
  251. func countAudioTokens(c *gin.Context) (int, error) {
  252. body, err := common.GetRequestBody(c)
  253. if err != nil {
  254. return 0, errors.WithStack(err)
  255. }
  256. var reqBody struct {
  257. File *multipart.FileHeader `form:"file" binding:"required"`
  258. }
  259. c.Request.Body = io.NopCloser(bytes.NewReader(body))
  260. if err = c.ShouldBind(&reqBody); err != nil {
  261. return 0, errors.WithStack(err)
  262. }
  263. ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
  264. reqFp, err := reqBody.File.Open()
  265. if err != nil {
  266. return 0, errors.WithStack(err)
  267. }
  268. defer reqFp.Close()
  269. tmpFp, err := os.CreateTemp("", "audio-*"+ext)
  270. if err != nil {
  271. return 0, errors.WithStack(err)
  272. }
  273. defer os.Remove(tmpFp.Name())
  274. _, err = io.Copy(tmpFp, reqFp)
  275. if err != nil {
  276. return 0, errors.WithStack(err)
  277. }
  278. if err = tmpFp.Close(); err != nil {
  279. return 0, errors.WithStack(err)
  280. }
  281. duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
  282. if err != nil {
  283. return 0, errors.WithStack(err)
  284. }
  285. return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
  286. }
  287. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
  288. if info == nil || info.ClientWs == nil || info.TargetWs == nil {
  289. return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
  290. }
  291. info.IsStream = true
  292. clientConn := info.ClientWs
  293. targetConn := info.TargetWs
  294. clientClosed := make(chan struct{})
  295. targetClosed := make(chan struct{})
  296. sendChan := make(chan []byte, 100)
  297. receiveChan := make(chan []byte, 100)
  298. errChan := make(chan error, 2)
  299. usage := &dto.RealtimeUsage{}
  300. localUsage := &dto.RealtimeUsage{}
  301. sumUsage := &dto.RealtimeUsage{}
  302. gopool.Go(func() {
  303. defer func() {
  304. if r := recover(); r != nil {
  305. errChan <- fmt.Errorf("panic in client reader: %v", r)
  306. }
  307. }()
  308. for {
  309. select {
  310. case <-c.Done():
  311. return
  312. default:
  313. _, message, err := clientConn.ReadMessage()
  314. if err != nil {
  315. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  316. errChan <- fmt.Errorf("error reading from client: %v", err)
  317. }
  318. close(clientClosed)
  319. return
  320. }
  321. realtimeEvent := &dto.RealtimeEvent{}
  322. err = common.Unmarshal(message, realtimeEvent)
  323. if err != nil {
  324. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  325. return
  326. }
  327. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  328. if realtimeEvent.Session != nil {
  329. if realtimeEvent.Session.Tools != nil {
  330. info.RealtimeTools = realtimeEvent.Session.Tools
  331. }
  332. }
  333. }
  334. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  335. if err != nil {
  336. errChan <- fmt.Errorf("error counting text token: %v", err)
  337. return
  338. }
  339. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  340. localUsage.TotalTokens += textToken + audioToken
  341. localUsage.InputTokens += textToken + audioToken
  342. localUsage.InputTokenDetails.TextTokens += textToken
  343. localUsage.InputTokenDetails.AudioTokens += audioToken
  344. err = helper.WssString(c, targetConn, string(message))
  345. if err != nil {
  346. errChan <- fmt.Errorf("error writing to target: %v", err)
  347. return
  348. }
  349. select {
  350. case sendChan <- message:
  351. default:
  352. }
  353. }
  354. }
  355. })
  356. gopool.Go(func() {
  357. defer func() {
  358. if r := recover(); r != nil {
  359. errChan <- fmt.Errorf("panic in target reader: %v", r)
  360. }
  361. }()
  362. for {
  363. select {
  364. case <-c.Done():
  365. return
  366. default:
  367. _, message, err := targetConn.ReadMessage()
  368. if err != nil {
  369. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  370. errChan <- fmt.Errorf("error reading from target: %v", err)
  371. }
  372. close(targetClosed)
  373. return
  374. }
  375. info.SetFirstResponseTime()
  376. realtimeEvent := &dto.RealtimeEvent{}
  377. err = common.Unmarshal(message, realtimeEvent)
  378. if err != nil {
  379. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  380. return
  381. }
  382. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  383. realtimeUsage := realtimeEvent.Response.Usage
  384. if realtimeUsage != nil {
  385. usage.TotalTokens += realtimeUsage.TotalTokens
  386. usage.InputTokens += realtimeUsage.InputTokens
  387. usage.OutputTokens += realtimeUsage.OutputTokens
  388. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  389. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  390. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  391. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  392. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  393. err := preConsumeUsage(c, info, usage, sumUsage)
  394. if err != nil {
  395. errChan <- fmt.Errorf("error consume usage: %v", err)
  396. return
  397. }
  398. // 本次计费完成,清除
  399. usage = &dto.RealtimeUsage{}
  400. localUsage = &dto.RealtimeUsage{}
  401. } else {
  402. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  403. if err != nil {
  404. errChan <- fmt.Errorf("error counting text token: %v", err)
  405. return
  406. }
  407. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  408. localUsage.TotalTokens += textToken + audioToken
  409. info.IsFirstRequest = false
  410. localUsage.InputTokens += textToken + audioToken
  411. localUsage.InputTokenDetails.TextTokens += textToken
  412. localUsage.InputTokenDetails.AudioTokens += audioToken
  413. err = preConsumeUsage(c, info, localUsage, sumUsage)
  414. if err != nil {
  415. errChan <- fmt.Errorf("error consume usage: %v", err)
  416. return
  417. }
  418. // 本次计费完成,清除
  419. localUsage = &dto.RealtimeUsage{}
  420. // print now usage
  421. }
  422. common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  423. common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  424. common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  425. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  426. realtimeSession := realtimeEvent.Session
  427. if realtimeSession != nil {
  428. // update audio format
  429. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  430. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  431. }
  432. } else {
  433. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  434. if err != nil {
  435. errChan <- fmt.Errorf("error counting text token: %v", err)
  436. return
  437. }
  438. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  439. localUsage.TotalTokens += textToken + audioToken
  440. localUsage.OutputTokens += textToken + audioToken
  441. localUsage.OutputTokenDetails.TextTokens += textToken
  442. localUsage.OutputTokenDetails.AudioTokens += audioToken
  443. }
  444. err = helper.WssString(c, clientConn, string(message))
  445. if err != nil {
  446. errChan <- fmt.Errorf("error writing to client: %v", err)
  447. return
  448. }
  449. select {
  450. case receiveChan <- message:
  451. default:
  452. }
  453. }
  454. }
  455. })
  456. select {
  457. case <-clientClosed:
  458. case <-targetClosed:
  459. case err := <-errChan:
  460. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  461. common.LogError(c, "realtime error: "+err.Error())
  462. case <-c.Done():
  463. }
  464. if usage.TotalTokens != 0 {
  465. _ = preConsumeUsage(c, info, usage, sumUsage)
  466. }
  467. if localUsage.TotalTokens != 0 {
  468. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  469. }
  470. // check usage total tokens, if 0, use local usage
  471. return nil, sumUsage
  472. }
  473. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  474. if usage == nil || totalUsage == nil {
  475. return fmt.Errorf("invalid usage pointer")
  476. }
  477. totalUsage.TotalTokens += usage.TotalTokens
  478. totalUsage.InputTokens += usage.InputTokens
  479. totalUsage.OutputTokens += usage.OutputTokens
  480. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  481. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  482. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  483. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  484. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  485. // clear usage
  486. err := service.PreWssConsumeQuota(ctx, info, usage)
  487. return err
  488. }
  489. func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  490. defer common.CloseResponseBodyGracefully(resp)
  491. responseBody, err := io.ReadAll(resp.Body)
  492. if err != nil {
  493. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  494. }
  495. var usageResp dto.SimpleResponse
  496. err = common.Unmarshal(responseBody, &usageResp)
  497. if err != nil {
  498. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  499. }
  500. // 写入新的 response body
  501. common.IOCopyBytesGracefully(c, resp, responseBody)
  502. // Once we've written to the client, we should not return errors anymore
  503. // because the upstream has already consumed resources and returned content
  504. // We should still perform billing even if parsing fails
  505. // format
  506. if usageResp.InputTokens > 0 {
  507. usageResp.PromptTokens += usageResp.InputTokens
  508. }
  509. if usageResp.OutputTokens > 0 {
  510. usageResp.CompletionTokens += usageResp.OutputTokens
  511. }
  512. if usageResp.InputTokensDetails != nil {
  513. usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
  514. usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
  515. }
  516. return &usageResp.Usage, nil
  517. }