relay-openai.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "math"
  8. "mime/multipart"
  9. "net/http"
  10. "one-api/common"
  11. "one-api/constant"
  12. "one-api/dto"
  13. "one-api/logger"
  14. "one-api/relay/channel/openrouter"
  15. relaycommon "one-api/relay/common"
  16. "one-api/relay/helper"
  17. "one-api/service"
  18. "os"
  19. "path/filepath"
  20. "strings"
  21. "one-api/types"
  22. "github.com/bytedance/gopkg/util/gopool"
  23. "github.com/gin-gonic/gin"
  24. "github.com/gorilla/websocket"
  25. "github.com/pkg/errors"
  26. )
  27. func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  28. if data == "" {
  29. return nil
  30. }
  31. if !forceFormat && !thinkToContent {
  32. return helper.StringData(c, data)
  33. }
  34. var lastStreamResponse dto.ChatCompletionsStreamResponse
  35. if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
  36. return err
  37. }
  38. if !thinkToContent {
  39. return helper.ObjectData(c, lastStreamResponse)
  40. }
  41. hasThinkingContent := false
  42. hasContent := false
  43. var thinkingContent strings.Builder
  44. for _, choice := range lastStreamResponse.Choices {
  45. if len(choice.Delta.GetReasoningContent()) > 0 {
  46. hasThinkingContent = true
  47. thinkingContent.WriteString(choice.Delta.GetReasoningContent())
  48. }
  49. if len(choice.Delta.GetContentString()) > 0 {
  50. hasContent = true
  51. }
  52. }
  53. // Handle think to content conversion
  54. if info.ThinkingContentInfo.IsFirstThinkingContent {
  55. if hasThinkingContent {
  56. response := lastStreamResponse.Copy()
  57. for i := range response.Choices {
  58. // send `think` tag with thinking content
  59. response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
  60. response.Choices[i].Delta.ReasoningContent = nil
  61. response.Choices[i].Delta.Reasoning = nil
  62. }
  63. info.ThinkingContentInfo.IsFirstThinkingContent = false
  64. info.ThinkingContentInfo.HasSentThinkingContent = true
  65. return helper.ObjectData(c, response)
  66. }
  67. }
  68. if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
  69. return helper.ObjectData(c, lastStreamResponse)
  70. }
  71. // Process each choice
  72. for i, choice := range lastStreamResponse.Choices {
  73. // Handle transition from thinking to content
  74. // only send `</think>` tag when previous thinking content has been sent
  75. if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
  76. response := lastStreamResponse.Copy()
  77. for j := range response.Choices {
  78. response.Choices[j].Delta.SetContentString("\n</think>\n")
  79. response.Choices[j].Delta.ReasoningContent = nil
  80. response.Choices[j].Delta.Reasoning = nil
  81. }
  82. info.ThinkingContentInfo.SendLastThinkingContent = true
  83. helper.ObjectData(c, response)
  84. }
  85. // Convert reasoning content to regular content if any
  86. if len(choice.Delta.GetReasoningContent()) > 0 {
  87. lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
  88. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  89. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  90. } else if !hasThinkingContent && !hasContent {
  91. // flush thinking content
  92. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  93. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  94. }
  95. }
  96. return helper.ObjectData(c, lastStreamResponse)
  97. }
  98. func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  99. if resp == nil || resp.Body == nil {
  100. logger.LogError(c, "invalid response or response body")
  101. return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
  102. }
  103. defer service.CloseResponseBodyGracefully(resp)
  104. model := info.UpstreamModelName
  105. var responseId string
  106. var createAt int64 = 0
  107. var systemFingerprint string
  108. var containStreamUsage bool
  109. var responseTextBuilder strings.Builder
  110. var toolCount int
  111. var usage = &dto.Usage{}
  112. var streamItems []string // store stream items
  113. var lastStreamData string
  114. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  115. if lastStreamData != "" {
  116. err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  117. if err != nil {
  118. common.SysLog("error handling stream format: " + err.Error())
  119. }
  120. }
  121. if len(data) > 0 {
  122. lastStreamData = data
  123. streamItems = append(streamItems, data)
  124. }
  125. return true
  126. })
  127. // 处理最后的响应
  128. shouldSendLastResp := true
  129. if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
  130. &containStreamUsage, info, &shouldSendLastResp); err != nil {
  131. logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
  132. }
  133. if info.RelayFormat == types.RelayFormatOpenAI {
  134. if shouldSendLastResp {
  135. _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  136. }
  137. }
  138. // 处理token计算
  139. if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
  140. logger.LogError(c, "error processing tokens: "+err.Error())
  141. }
  142. if !containStreamUsage {
  143. usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  144. usage.CompletionTokens += toolCount * 7
  145. } else {
  146. if info.ChannelType == constant.ChannelTypeDeepSeek {
  147. if usage.PromptCacheHitTokens != 0 {
  148. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  149. }
  150. }
  151. }
  152. HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
  153. return usage, nil
  154. }
  155. func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  156. defer service.CloseResponseBodyGracefully(resp)
  157. var simpleResponse dto.OpenAITextResponse
  158. responseBody, err := io.ReadAll(resp.Body)
  159. if err != nil {
  160. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  161. }
  162. if common.DebugEnabled {
  163. println("upstream response body:", string(responseBody))
  164. }
  165. // Unmarshal to simpleResponse
  166. if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
  167. // 尝试解析为 openrouter enterprise
  168. var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
  169. err = common.Unmarshal(responseBody, &enterpriseResponse)
  170. if err != nil {
  171. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  172. }
  173. if enterpriseResponse.Success {
  174. responseBody = enterpriseResponse.Data
  175. } else {
  176. logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
  177. return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  178. }
  179. }
  180. err = common.Unmarshal(responseBody, &simpleResponse)
  181. if err != nil {
  182. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  183. }
  184. if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
  185. return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
  186. }
  187. forceFormat := false
  188. if info.ChannelSetting.ForceFormat {
  189. forceFormat = true
  190. }
  191. usageModified := false
  192. if simpleResponse.Usage.PromptTokens == 0 {
  193. completionTokens := simpleResponse.Usage.CompletionTokens
  194. if completionTokens == 0 {
  195. for _, choice := range simpleResponse.Choices {
  196. ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
  197. completionTokens += ctkm
  198. }
  199. }
  200. simpleResponse.Usage = dto.Usage{
  201. PromptTokens: info.PromptTokens,
  202. CompletionTokens: completionTokens,
  203. TotalTokens: info.PromptTokens + completionTokens,
  204. }
  205. usageModified = true
  206. }
  207. switch info.RelayFormat {
  208. case types.RelayFormatOpenAI:
  209. if usageModified {
  210. var bodyMap map[string]interface{}
  211. err = common.Unmarshal(responseBody, &bodyMap)
  212. if err != nil {
  213. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  214. }
  215. bodyMap["usage"] = simpleResponse.Usage
  216. responseBody, _ = common.Marshal(bodyMap)
  217. }
  218. if forceFormat {
  219. responseBody, err = common.Marshal(simpleResponse)
  220. if err != nil {
  221. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  222. }
  223. } else {
  224. break
  225. }
  226. case types.RelayFormatClaude:
  227. claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
  228. claudeRespStr, err := common.Marshal(claudeResp)
  229. if err != nil {
  230. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  231. }
  232. responseBody = claudeRespStr
  233. case types.RelayFormatGemini:
  234. geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
  235. geminiRespStr, err := common.Marshal(geminiResp)
  236. if err != nil {
  237. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  238. }
  239. responseBody = geminiRespStr
  240. }
  241. service.IOCopyBytesGracefully(c, resp, responseBody)
  242. return &simpleResponse.Usage, nil
  243. }
  244. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
  245. // the status code has been judged before, if there is a body reading failure,
  246. // it should be regarded as a non-recoverable error, so it should not return err for external retry.
  247. // Analogous to nginx's load balancing, it will only retry if it can't be requested or
  248. // if the upstream returns a specific status code, once the upstream has already written the header,
  249. // the subsequent failure of the response body should be regarded as a non-recoverable error,
  250. // and can be terminated directly.
  251. defer service.CloseResponseBodyGracefully(resp)
  252. usage := &dto.Usage{}
  253. usage.PromptTokens = info.PromptTokens
  254. usage.TotalTokens = info.PromptTokens
  255. for k, v := range resp.Header {
  256. c.Writer.Header().Set(k, v[0])
  257. }
  258. c.Writer.WriteHeader(resp.StatusCode)
  259. c.Writer.WriteHeaderNow()
  260. _, err := io.Copy(c.Writer, resp.Body)
  261. if err != nil {
  262. logger.LogError(c, err.Error())
  263. }
  264. return usage
  265. }
  266. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
  267. defer service.CloseResponseBodyGracefully(resp)
  268. responseBody, err := io.ReadAll(resp.Body)
  269. if err != nil {
  270. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  271. }
  272. // 写入新的 response body
  273. service.IOCopyBytesGracefully(c, resp, responseBody)
  274. var responseData struct {
  275. Usage *dto.Usage `json:"usage"`
  276. }
  277. if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
  278. if responseData.Usage.TotalTokens > 0 {
  279. usage := responseData.Usage
  280. if usage.PromptTokens == 0 {
  281. usage.PromptTokens = usage.InputTokens
  282. }
  283. if usage.CompletionTokens == 0 {
  284. usage.CompletionTokens = usage.OutputTokens
  285. }
  286. return nil, usage
  287. }
  288. }
  289. audioTokens, err := countAudioTokens(c)
  290. if err != nil {
  291. return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
  292. }
  293. usage := &dto.Usage{}
  294. usage.PromptTokens = audioTokens
  295. usage.CompletionTokens = 0
  296. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  297. return nil, usage
  298. }
  299. func countAudioTokens(c *gin.Context) (int, error) {
  300. body, err := common.GetRequestBody(c)
  301. if err != nil {
  302. return 0, errors.WithStack(err)
  303. }
  304. var reqBody struct {
  305. File *multipart.FileHeader `form:"file" binding:"required"`
  306. }
  307. c.Request.Body = io.NopCloser(bytes.NewReader(body))
  308. if err = c.ShouldBind(&reqBody); err != nil {
  309. return 0, errors.WithStack(err)
  310. }
  311. ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
  312. reqFp, err := reqBody.File.Open()
  313. if err != nil {
  314. return 0, errors.WithStack(err)
  315. }
  316. defer reqFp.Close()
  317. tmpFp, err := os.CreateTemp("", "audio-*"+ext)
  318. if err != nil {
  319. return 0, errors.WithStack(err)
  320. }
  321. defer os.Remove(tmpFp.Name())
  322. _, err = io.Copy(tmpFp, reqFp)
  323. if err != nil {
  324. return 0, errors.WithStack(err)
  325. }
  326. if err = tmpFp.Close(); err != nil {
  327. return 0, errors.WithStack(err)
  328. }
  329. duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
  330. if err != nil {
  331. return 0, errors.WithStack(err)
  332. }
  333. return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
  334. }
  335. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
  336. if info == nil || info.ClientWs == nil || info.TargetWs == nil {
  337. return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
  338. }
  339. info.IsStream = true
  340. clientConn := info.ClientWs
  341. targetConn := info.TargetWs
  342. clientClosed := make(chan struct{})
  343. targetClosed := make(chan struct{})
  344. sendChan := make(chan []byte, 100)
  345. receiveChan := make(chan []byte, 100)
  346. errChan := make(chan error, 2)
  347. usage := &dto.RealtimeUsage{}
  348. localUsage := &dto.RealtimeUsage{}
  349. sumUsage := &dto.RealtimeUsage{}
  350. gopool.Go(func() {
  351. defer func() {
  352. if r := recover(); r != nil {
  353. errChan <- fmt.Errorf("panic in client reader: %v", r)
  354. }
  355. }()
  356. for {
  357. select {
  358. case <-c.Done():
  359. return
  360. default:
  361. _, message, err := clientConn.ReadMessage()
  362. if err != nil {
  363. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  364. errChan <- fmt.Errorf("error reading from client: %v", err)
  365. }
  366. close(clientClosed)
  367. return
  368. }
  369. realtimeEvent := &dto.RealtimeEvent{}
  370. err = common.Unmarshal(message, realtimeEvent)
  371. if err != nil {
  372. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  373. return
  374. }
  375. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  376. if realtimeEvent.Session != nil {
  377. if realtimeEvent.Session.Tools != nil {
  378. info.RealtimeTools = realtimeEvent.Session.Tools
  379. }
  380. }
  381. }
  382. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  383. if err != nil {
  384. errChan <- fmt.Errorf("error counting text token: %v", err)
  385. return
  386. }
  387. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  388. localUsage.TotalTokens += textToken + audioToken
  389. localUsage.InputTokens += textToken + audioToken
  390. localUsage.InputTokenDetails.TextTokens += textToken
  391. localUsage.InputTokenDetails.AudioTokens += audioToken
  392. err = helper.WssString(c, targetConn, string(message))
  393. if err != nil {
  394. errChan <- fmt.Errorf("error writing to target: %v", err)
  395. return
  396. }
  397. select {
  398. case sendChan <- message:
  399. default:
  400. }
  401. }
  402. }
  403. })
  404. gopool.Go(func() {
  405. defer func() {
  406. if r := recover(); r != nil {
  407. errChan <- fmt.Errorf("panic in target reader: %v", r)
  408. }
  409. }()
  410. for {
  411. select {
  412. case <-c.Done():
  413. return
  414. default:
  415. _, message, err := targetConn.ReadMessage()
  416. if err != nil {
  417. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  418. errChan <- fmt.Errorf("error reading from target: %v", err)
  419. }
  420. close(targetClosed)
  421. return
  422. }
  423. info.SetFirstResponseTime()
  424. realtimeEvent := &dto.RealtimeEvent{}
  425. err = common.Unmarshal(message, realtimeEvent)
  426. if err != nil {
  427. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  428. return
  429. }
  430. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  431. realtimeUsage := realtimeEvent.Response.Usage
  432. if realtimeUsage != nil {
  433. usage.TotalTokens += realtimeUsage.TotalTokens
  434. usage.InputTokens += realtimeUsage.InputTokens
  435. usage.OutputTokens += realtimeUsage.OutputTokens
  436. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  437. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  438. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  439. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  440. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  441. err := preConsumeUsage(c, info, usage, sumUsage)
  442. if err != nil {
  443. errChan <- fmt.Errorf("error consume usage: %v", err)
  444. return
  445. }
  446. // 本次计费完成,清除
  447. usage = &dto.RealtimeUsage{}
  448. localUsage = &dto.RealtimeUsage{}
  449. } else {
  450. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  451. if err != nil {
  452. errChan <- fmt.Errorf("error counting text token: %v", err)
  453. return
  454. }
  455. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  456. localUsage.TotalTokens += textToken + audioToken
  457. info.IsFirstRequest = false
  458. localUsage.InputTokens += textToken + audioToken
  459. localUsage.InputTokenDetails.TextTokens += textToken
  460. localUsage.InputTokenDetails.AudioTokens += audioToken
  461. err = preConsumeUsage(c, info, localUsage, sumUsage)
  462. if err != nil {
  463. errChan <- fmt.Errorf("error consume usage: %v", err)
  464. return
  465. }
  466. // 本次计费完成,清除
  467. localUsage = &dto.RealtimeUsage{}
  468. // print now usage
  469. }
  470. logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  471. logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  472. logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  473. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  474. realtimeSession := realtimeEvent.Session
  475. if realtimeSession != nil {
  476. // update audio format
  477. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  478. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  479. }
  480. } else {
  481. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  482. if err != nil {
  483. errChan <- fmt.Errorf("error counting text token: %v", err)
  484. return
  485. }
  486. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  487. localUsage.TotalTokens += textToken + audioToken
  488. localUsage.OutputTokens += textToken + audioToken
  489. localUsage.OutputTokenDetails.TextTokens += textToken
  490. localUsage.OutputTokenDetails.AudioTokens += audioToken
  491. }
  492. err = helper.WssString(c, clientConn, string(message))
  493. if err != nil {
  494. errChan <- fmt.Errorf("error writing to client: %v", err)
  495. return
  496. }
  497. select {
  498. case receiveChan <- message:
  499. default:
  500. }
  501. }
  502. }
  503. })
  504. select {
  505. case <-clientClosed:
  506. case <-targetClosed:
  507. case err := <-errChan:
  508. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  509. logger.LogError(c, "realtime error: "+err.Error())
  510. case <-c.Done():
  511. }
  512. if usage.TotalTokens != 0 {
  513. _ = preConsumeUsage(c, info, usage, sumUsage)
  514. }
  515. if localUsage.TotalTokens != 0 {
  516. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  517. }
  518. // check usage total tokens, if 0, use local usage
  519. return nil, sumUsage
  520. }
  521. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  522. if usage == nil || totalUsage == nil {
  523. return fmt.Errorf("invalid usage pointer")
  524. }
  525. totalUsage.TotalTokens += usage.TotalTokens
  526. totalUsage.InputTokens += usage.InputTokens
  527. totalUsage.OutputTokens += usage.OutputTokens
  528. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  529. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  530. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  531. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  532. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  533. // clear usage
  534. err := service.PreWssConsumeQuota(ctx, info, usage)
  535. return err
  536. }
  537. func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  538. defer service.CloseResponseBodyGracefully(resp)
  539. responseBody, err := io.ReadAll(resp.Body)
  540. if err != nil {
  541. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  542. }
  543. var usageResp dto.SimpleResponse
  544. err = common.Unmarshal(responseBody, &usageResp)
  545. if err != nil {
  546. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  547. }
  548. // 写入新的 response body
  549. service.IOCopyBytesGracefully(c, resp, responseBody)
  550. // Once we've written to the client, we should not return errors anymore
  551. // because the upstream has already consumed resources and returned content
  552. // We should still perform billing even if parsing fails
  553. // format
  554. if usageResp.InputTokens > 0 {
  555. usageResp.PromptTokens += usageResp.InputTokens
  556. }
  557. if usageResp.OutputTokens > 0 {
  558. usageResp.CompletionTokens += usageResp.OutputTokens
  559. }
  560. if usageResp.InputTokensDetails != nil {
  561. usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
  562. usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
  563. }
  564. return &usageResp.Usage, nil
  565. }