// relay-openai.go
  1. package openai
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "math"
  7. "mime/multipart"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. relaycommon "one-api/relay/common"
  13. "one-api/relay/helper"
  14. "one-api/service"
  15. "os"
  16. "path/filepath"
  17. "strings"
  18. "github.com/bytedance/gopkg/util/gopool"
  19. "github.com/gin-gonic/gin"
  20. "github.com/gorilla/websocket"
  21. "github.com/pkg/errors"
  22. )
// sendStreamData writes a single SSE chunk to the client.
//
// Fast path: when neither forceFormat nor thinkToContent is set, the raw data
// string is forwarded verbatim. Otherwise the chunk is decoded so it can be
// re-emitted as a structured object; with thinkToContent enabled, reasoning
// ("thinking") deltas are rewritten into regular content wrapped in
// <think>...</think> tags. Per-stream conversion state is tracked in
// info.ThinkingContentInfo and mutated here as tags are emitted.
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
	if data == "" {
		return nil
	}
	// No transformation requested: pass the chunk through untouched.
	if !forceFormat && !thinkToContent {
		return helper.StringData(c, data)
	}
	var lastStreamResponse dto.ChatCompletionsStreamResponse
	if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
		return err
	}
	// forceFormat only: re-serialize the parsed object (normalizes the payload).
	if !thinkToContent {
		return helper.ObjectData(c, lastStreamResponse)
	}
	// Scan all choices to learn whether this chunk carries reasoning deltas
	// and/or regular content deltas.
	hasThinkingContent := false
	hasContent := false
	var thinkingContent strings.Builder
	for _, choice := range lastStreamResponse.Choices {
		if len(choice.Delta.GetReasoningContent()) > 0 {
			hasThinkingContent = true
			thinkingContent.WriteString(choice.Delta.GetReasoningContent())
		}
		if len(choice.Delta.GetContentString()) > 0 {
			hasContent = true
		}
	}
	// Handle think to content conversion
	if info.ThinkingContentInfo.IsFirstThinkingContent {
		if hasThinkingContent {
			response := lastStreamResponse.Copy()
			for i := range response.Choices {
				// send `think` tag with thinking content
				response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
				response.Choices[i].Delta.ReasoningContent = nil
				response.Choices[i].Delta.Reasoning = nil
			}
			info.ThinkingContentInfo.IsFirstThinkingContent = false
			info.ThinkingContentInfo.HasSentThinkingContent = true
			return helper.ObjectData(c, response)
		}
	}
	// Chunks with no choices (e.g. usage-only frames) are forwarded as-is.
	if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
		return helper.ObjectData(c, lastStreamResponse)
	}
	// Process each choice
	for i, choice := range lastStreamResponse.Choices {
		// Handle transition from thinking to content
		// only send `</think>` tag when previous thinking content has been sent
		if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
			response := lastStreamResponse.Copy()
			for j := range response.Choices {
				response.Choices[j].Delta.SetContentString("\n</think>\n")
				response.Choices[j].Delta.ReasoningContent = nil
				response.Choices[j].Delta.Reasoning = nil
			}
			info.ThinkingContentInfo.SendLastThinkingContent = true
			// NOTE(review): the write error is deliberately ignored here so the
			// regular payload below is still emitted even if the tag write fails.
			helper.ObjectData(c, response)
		}
		// Convert reasoning content to regular content if any
		if len(choice.Delta.GetReasoningContent()) > 0 {
			lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
			lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
			lastStreamResponse.Choices[i].Delta.Reasoning = nil
		} else if !hasThinkingContent && !hasContent {
			// flush thinking content
			lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
			lastStreamResponse.Choices[i].Delta.Reasoning = nil
		}
	}
	return helper.ObjectData(c, lastStreamResponse)
}
// OaiStreamHandler consumes an OpenAI-compatible SSE stream, forwarding each
// chunk to the client and computing usage for billing.
//
// Chunks are forwarded one step behind the scanner (via lastStreamData) so
// that the final chunk can be inspected first: handleLastResponse extracts
// id/created/model/system_fingerprint and any stream usage from it, and
// decides whether that last chunk should still be sent to the client.
// Returns the computed usage; an OpenAI error is returned only for an
// invalid response object.
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	if resp == nil || resp.Body == nil {
		common.LogError(c, "invalid response or response body")
		return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
	}
	defer common.CloseResponseBodyGracefully(resp)
	model := info.UpstreamModelName
	var responseId string
	var createAt int64 = 0
	var systemFingerprint string
	var containStreamUsage bool
	var responseTextBuilder strings.Builder
	var toolCount int
	var usage = &dto.Usage{}
	var streamItems []string // store stream items
	// Per-channel toggles: force re-encoding of each chunk, and/or convert
	// reasoning deltas into <think>-tagged content.
	var forceFormat bool
	var thinkToContent bool
	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
		forceFormat = forceFmt
	}
	if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
		thinkToContent = think2Content
	}
	var (
		lastStreamData string
	)
	// Forward one chunk behind: the chunk received on iteration N is sent on
	// iteration N+1, which leaves the very last chunk un-sent when the scan ends.
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		if lastStreamData != "" {
			err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
			if err != nil {
				common.SysError("error handling stream format: " + err.Error())
			}
		}
		lastStreamData = data
		streamItems = append(streamItems, data)
		return true
	})
	// Handle the final chunk: extract usage/metadata and decide whether it
	// should still be forwarded to the client.
	shouldSendLastResp := true
	if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
		&containStreamUsage, info, &shouldSendLastResp); err != nil {
		common.SysError("error handling last response: " + err.Error())
	}
	if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
		_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
	}
	// Token accounting over all collected chunks.
	if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
		common.SysError("error processing tokens: " + err.Error())
	}
	if !containStreamUsage {
		// No usage reported in the stream: estimate from the concatenated
		// response text, surcharging each tool call a flat 7 completion tokens.
		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
		usage.CompletionTokens += toolCount * 7
	} else {
		if info.ChannelType == constant.ChannelTypeDeepSeek {
			// DeepSeek reports cache hits in its own field; map it onto the
			// standard cached-tokens detail.
			if usage.PromptCacheHitTokens != 0 {
				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
			}
		}
	}
	handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
	return nil, usage
}
  157. func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  158. defer common.CloseResponseBodyGracefully(resp)
  159. var simpleResponse dto.OpenAITextResponse
  160. responseBody, err := io.ReadAll(resp.Body)
  161. if err != nil {
  162. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  163. }
  164. err = common.UnmarshalJson(responseBody, &simpleResponse)
  165. if err != nil {
  166. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  167. }
  168. if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
  169. return &dto.OpenAIErrorWithStatusCode{
  170. Error: *simpleResponse.Error,
  171. StatusCode: resp.StatusCode,
  172. }, nil
  173. }
  174. forceFormat := false
  175. if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
  176. forceFormat = forceFmt
  177. }
  178. if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
  179. completionTokens := 0
  180. for _, choice := range simpleResponse.Choices {
  181. ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
  182. completionTokens += ctkm
  183. }
  184. simpleResponse.Usage = dto.Usage{
  185. PromptTokens: info.PromptTokens,
  186. CompletionTokens: completionTokens,
  187. TotalTokens: info.PromptTokens + completionTokens,
  188. }
  189. }
  190. switch info.RelayFormat {
  191. case relaycommon.RelayFormatOpenAI:
  192. if forceFormat {
  193. responseBody, err = common.EncodeJson(simpleResponse)
  194. if err != nil {
  195. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  196. }
  197. } else {
  198. break
  199. }
  200. case relaycommon.RelayFormatClaude:
  201. claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
  202. claudeRespStr, err := common.EncodeJson(claudeResp)
  203. if err != nil {
  204. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  205. }
  206. responseBody = claudeRespStr
  207. }
  208. common.IOCopyBytesGracefully(c, resp, responseBody)
  209. return nil, &simpleResponse.Usage
  210. }
  211. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  212. // the status code has been judged before, if there is a body reading failure,
  213. // it should be regarded as a non-recoverable error, so it should not return err for external retry.
  214. // Analogous to nginx's load balancing, it will only retry if it can't be requested or
  215. // if the upstream returns a specific status code, once the upstream has already written the header,
  216. // the subsequent failure of the response body should be regarded as a non-recoverable error,
  217. // and can be terminated directly.
  218. defer common.CloseResponseBodyGracefully(resp)
  219. usage := &dto.Usage{}
  220. usage.PromptTokens = info.PromptTokens
  221. usage.TotalTokens = info.PromptTokens
  222. for k, v := range resp.Header {
  223. c.Writer.Header().Set(k, v[0])
  224. }
  225. c.Writer.WriteHeader(resp.StatusCode)
  226. c.Writer.WriteHeaderNow()
  227. _, err := io.Copy(c.Writer, resp.Body)
  228. if err != nil {
  229. common.LogError(c, err.Error())
  230. }
  231. return nil, usage
  232. }
  233. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  234. defer common.CloseResponseBodyGracefully(resp)
  235. // count tokens by audio file duration
  236. audioTokens, err := countAudioTokens(c)
  237. if err != nil {
  238. return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
  239. }
  240. responseBody, err := io.ReadAll(resp.Body)
  241. if err != nil {
  242. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  243. }
  244. // 写入新的 response body
  245. common.IOCopyBytesGracefully(c, resp, responseBody)
  246. usage := &dto.Usage{}
  247. usage.PromptTokens = audioTokens
  248. usage.CompletionTokens = 0
  249. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  250. return nil, usage
  251. }
  252. func countAudioTokens(c *gin.Context) (int, error) {
  253. body, err := common.GetRequestBody(c)
  254. if err != nil {
  255. return 0, errors.WithStack(err)
  256. }
  257. var reqBody struct {
  258. File *multipart.FileHeader `form:"file" binding:"required"`
  259. }
  260. c.Request.Body = io.NopCloser(bytes.NewReader(body))
  261. if err = c.ShouldBind(&reqBody); err != nil {
  262. return 0, errors.WithStack(err)
  263. }
  264. ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
  265. reqFp, err := reqBody.File.Open()
  266. if err != nil {
  267. return 0, errors.WithStack(err)
  268. }
  269. defer reqFp.Close()
  270. tmpFp, err := os.CreateTemp("", "audio-*"+ext)
  271. if err != nil {
  272. return 0, errors.WithStack(err)
  273. }
  274. defer os.Remove(tmpFp.Name())
  275. _, err = io.Copy(tmpFp, reqFp)
  276. if err != nil {
  277. return 0, errors.WithStack(err)
  278. }
  279. if err = tmpFp.Close(); err != nil {
  280. return 0, errors.WithStack(err)
  281. }
  282. duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
  283. if err != nil {
  284. return 0, errors.WithStack(err)
  285. }
  286. return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
  287. }
  288. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
  289. if info == nil || info.ClientWs == nil || info.TargetWs == nil {
  290. return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
  291. }
  292. info.IsStream = true
  293. clientConn := info.ClientWs
  294. targetConn := info.TargetWs
  295. clientClosed := make(chan struct{})
  296. targetClosed := make(chan struct{})
  297. sendChan := make(chan []byte, 100)
  298. receiveChan := make(chan []byte, 100)
  299. errChan := make(chan error, 2)
  300. usage := &dto.RealtimeUsage{}
  301. localUsage := &dto.RealtimeUsage{}
  302. sumUsage := &dto.RealtimeUsage{}
  303. gopool.Go(func() {
  304. defer func() {
  305. if r := recover(); r != nil {
  306. errChan <- fmt.Errorf("panic in client reader: %v", r)
  307. }
  308. }()
  309. for {
  310. select {
  311. case <-c.Done():
  312. return
  313. default:
  314. _, message, err := clientConn.ReadMessage()
  315. if err != nil {
  316. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  317. errChan <- fmt.Errorf("error reading from client: %v", err)
  318. }
  319. close(clientClosed)
  320. return
  321. }
  322. realtimeEvent := &dto.RealtimeEvent{}
  323. err = common.UnmarshalJson(message, realtimeEvent)
  324. if err != nil {
  325. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  326. return
  327. }
  328. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  329. if realtimeEvent.Session != nil {
  330. if realtimeEvent.Session.Tools != nil {
  331. info.RealtimeTools = realtimeEvent.Session.Tools
  332. }
  333. }
  334. }
  335. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  336. if err != nil {
  337. errChan <- fmt.Errorf("error counting text token: %v", err)
  338. return
  339. }
  340. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  341. localUsage.TotalTokens += textToken + audioToken
  342. localUsage.InputTokens += textToken + audioToken
  343. localUsage.InputTokenDetails.TextTokens += textToken
  344. localUsage.InputTokenDetails.AudioTokens += audioToken
  345. err = helper.WssString(c, targetConn, string(message))
  346. if err != nil {
  347. errChan <- fmt.Errorf("error writing to target: %v", err)
  348. return
  349. }
  350. select {
  351. case sendChan <- message:
  352. default:
  353. }
  354. }
  355. }
  356. })
  357. gopool.Go(func() {
  358. defer func() {
  359. if r := recover(); r != nil {
  360. errChan <- fmt.Errorf("panic in target reader: %v", r)
  361. }
  362. }()
  363. for {
  364. select {
  365. case <-c.Done():
  366. return
  367. default:
  368. _, message, err := targetConn.ReadMessage()
  369. if err != nil {
  370. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  371. errChan <- fmt.Errorf("error reading from target: %v", err)
  372. }
  373. close(targetClosed)
  374. return
  375. }
  376. info.SetFirstResponseTime()
  377. realtimeEvent := &dto.RealtimeEvent{}
  378. err = common.UnmarshalJson(message, realtimeEvent)
  379. if err != nil {
  380. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  381. return
  382. }
  383. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  384. realtimeUsage := realtimeEvent.Response.Usage
  385. if realtimeUsage != nil {
  386. usage.TotalTokens += realtimeUsage.TotalTokens
  387. usage.InputTokens += realtimeUsage.InputTokens
  388. usage.OutputTokens += realtimeUsage.OutputTokens
  389. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  390. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  391. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  392. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  393. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  394. err := preConsumeUsage(c, info, usage, sumUsage)
  395. if err != nil {
  396. errChan <- fmt.Errorf("error consume usage: %v", err)
  397. return
  398. }
  399. // 本次计费完成,清除
  400. usage = &dto.RealtimeUsage{}
  401. localUsage = &dto.RealtimeUsage{}
  402. } else {
  403. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  404. if err != nil {
  405. errChan <- fmt.Errorf("error counting text token: %v", err)
  406. return
  407. }
  408. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  409. localUsage.TotalTokens += textToken + audioToken
  410. info.IsFirstRequest = false
  411. localUsage.InputTokens += textToken + audioToken
  412. localUsage.InputTokenDetails.TextTokens += textToken
  413. localUsage.InputTokenDetails.AudioTokens += audioToken
  414. err = preConsumeUsage(c, info, localUsage, sumUsage)
  415. if err != nil {
  416. errChan <- fmt.Errorf("error consume usage: %v", err)
  417. return
  418. }
  419. // 本次计费完成,清除
  420. localUsage = &dto.RealtimeUsage{}
  421. // print now usage
  422. }
  423. common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  424. common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  425. common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  426. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  427. realtimeSession := realtimeEvent.Session
  428. if realtimeSession != nil {
  429. // update audio format
  430. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  431. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  432. }
  433. } else {
  434. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  435. if err != nil {
  436. errChan <- fmt.Errorf("error counting text token: %v", err)
  437. return
  438. }
  439. common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  440. localUsage.TotalTokens += textToken + audioToken
  441. localUsage.OutputTokens += textToken + audioToken
  442. localUsage.OutputTokenDetails.TextTokens += textToken
  443. localUsage.OutputTokenDetails.AudioTokens += audioToken
  444. }
  445. err = helper.WssString(c, clientConn, string(message))
  446. if err != nil {
  447. errChan <- fmt.Errorf("error writing to client: %v", err)
  448. return
  449. }
  450. select {
  451. case receiveChan <- message:
  452. default:
  453. }
  454. }
  455. }
  456. })
  457. select {
  458. case <-clientClosed:
  459. case <-targetClosed:
  460. case err := <-errChan:
  461. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  462. common.LogError(c, "realtime error: "+err.Error())
  463. case <-c.Done():
  464. }
  465. if usage.TotalTokens != 0 {
  466. _ = preConsumeUsage(c, info, usage, sumUsage)
  467. }
  468. if localUsage.TotalTokens != 0 {
  469. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  470. }
  471. // check usage total tokens, if 0, use local usage
  472. return nil, sumUsage
  473. }
  474. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  475. if usage == nil || totalUsage == nil {
  476. return fmt.Errorf("invalid usage pointer")
  477. }
  478. totalUsage.TotalTokens += usage.TotalTokens
  479. totalUsage.InputTokens += usage.InputTokens
  480. totalUsage.OutputTokens += usage.OutputTokens
  481. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  482. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  483. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  484. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  485. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  486. // clear usage
  487. err := service.PreWssConsumeQuota(ctx, info, usage)
  488. return err
  489. }
  490. func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  491. defer common.CloseResponseBodyGracefully(resp)
  492. responseBody, err := io.ReadAll(resp.Body)
  493. if err != nil {
  494. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  495. }
  496. var usageResp dto.SimpleResponse
  497. err = common.UnmarshalJson(responseBody, &usageResp)
  498. if err != nil {
  499. return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
  500. }
  501. // 写入新的 response body
  502. common.IOCopyBytesGracefully(c, resp, responseBody)
  503. // Once we've written to the client, we should not return errors anymore
  504. // because the upstream has already consumed resources and returned content
  505. // We should still perform billing even if parsing fails
  506. // format
  507. if usageResp.InputTokens > 0 {
  508. usageResp.PromptTokens += usageResp.InputTokens
  509. }
  510. if usageResp.OutputTokens > 0 {
  511. usageResp.CompletionTokens += usageResp.OutputTokens
  512. }
  513. if usageResp.InputTokensDetails != nil {
  514. usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
  515. usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
  516. }
  517. return nil, &usageResp.Usage
  518. }