// relay-openai.go
  1. package openai
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "strings"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/constant"
  9. "github.com/QuantumNous/new-api/dto"
  10. "github.com/QuantumNous/new-api/logger"
  11. "github.com/QuantumNous/new-api/relay/channel/openrouter"
  12. relaycommon "github.com/QuantumNous/new-api/relay/common"
  13. "github.com/QuantumNous/new-api/relay/helper"
  14. "github.com/QuantumNous/new-api/service"
  15. "github.com/QuantumNous/new-api/types"
  16. "github.com/bytedance/gopkg/util/gopool"
  17. "github.com/gin-gonic/gin"
  18. "github.com/gorilla/websocket"
  19. )
  20. func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  21. if data == "" {
  22. return nil
  23. }
  24. if !forceFormat && !thinkToContent {
  25. return helper.StringData(c, data)
  26. }
  27. var lastStreamResponse dto.ChatCompletionsStreamResponse
  28. if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
  29. return err
  30. }
  31. if !thinkToContent {
  32. return helper.ObjectData(c, lastStreamResponse)
  33. }
  34. hasThinkingContent := false
  35. hasContent := false
  36. var thinkingContent strings.Builder
  37. for _, choice := range lastStreamResponse.Choices {
  38. if len(choice.Delta.GetReasoningContent()) > 0 {
  39. hasThinkingContent = true
  40. thinkingContent.WriteString(choice.Delta.GetReasoningContent())
  41. }
  42. if len(choice.Delta.GetContentString()) > 0 {
  43. hasContent = true
  44. }
  45. }
  46. // Handle think to content conversion
  47. if info.ThinkingContentInfo.IsFirstThinkingContent {
  48. if hasThinkingContent {
  49. response := lastStreamResponse.Copy()
  50. for i := range response.Choices {
  51. // send `think` tag with thinking content
  52. response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
  53. response.Choices[i].Delta.ReasoningContent = nil
  54. response.Choices[i].Delta.Reasoning = nil
  55. }
  56. info.ThinkingContentInfo.IsFirstThinkingContent = false
  57. info.ThinkingContentInfo.HasSentThinkingContent = true
  58. return helper.ObjectData(c, response)
  59. }
  60. }
  61. if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
  62. return helper.ObjectData(c, lastStreamResponse)
  63. }
  64. // Process each choice
  65. for i, choice := range lastStreamResponse.Choices {
  66. // Handle transition from thinking to content
  67. // only send `</think>` tag when previous thinking content has been sent
  68. if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
  69. response := lastStreamResponse.Copy()
  70. for j := range response.Choices {
  71. response.Choices[j].Delta.SetContentString("\n</think>\n")
  72. response.Choices[j].Delta.ReasoningContent = nil
  73. response.Choices[j].Delta.Reasoning = nil
  74. }
  75. info.ThinkingContentInfo.SendLastThinkingContent = true
  76. helper.ObjectData(c, response)
  77. }
  78. // Convert reasoning content to regular content if any
  79. if len(choice.Delta.GetReasoningContent()) > 0 {
  80. lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
  81. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  82. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  83. } else if !hasThinkingContent && !hasContent {
  84. // flush thinking content
  85. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  86. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  87. }
  88. }
  89. return helper.ObjectData(c, lastStreamResponse)
  90. }
// OaiStreamHandler consumes an OpenAI-compatible SSE stream, forwarding each
// chunk to the client while buffering raw chunks so usage can be
// reconstructed afterwards.
//
// Forwarding is intentionally delayed by one chunk: the final chunk often
// carries usage/finish data that handleLastResponse may need to rewrite or
// suppress before it is sent. For models whose name contains "audio", usage
// is read from the second-to-last chunk instead.
//
// Returns the usage (taken from the stream when present, otherwise estimated
// from the emitted text) or a wrapped error when the response is unusable.
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	if resp == nil || resp.Body == nil {
		logger.LogError(c, "invalid response or response body")
		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
	}
	defer service.CloseResponseBodyGracefully(resp)
	model := info.UpstreamModelName
	var responseId string
	var createAt int64 = 0
	var systemFingerprint string
	var containStreamUsage bool
	var responseTextBuilder strings.Builder
	var toolCount int
	var usage = &dto.Usage{}
	var streamItems []string // store stream items
	var lastStreamData string
	var secondLastStreamData string // second-to-last stream chunk, used for audio models
	// Check whether this is an audio model.
	isAudioModel := strings.Contains(strings.ToLower(model), "audio")
	helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
		// Forward the previously buffered chunk (one-chunk delay, see doc above).
		if lastStreamData != "" {
			if err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent); err != nil {
				common.SysLog("error handling stream format: " + err.Error())
				sr.Error(err)
			}
		}
		if len(data) > 0 {
			// For audio models, remember the second-to-last stream chunk.
			if isAudioModel && lastStreamData != "" {
				secondLastStreamData = lastStreamData
			}
			lastStreamData = data
			streamItems = append(streamItems, data)
		}
	})
	// For audio models, extract usage info from the second-to-last stream chunk.
	if isAudioModel && secondLastStreamData != "" {
		var streamResp struct {
			Usage *dto.Usage `json:"usage"`
		}
		err := common.Unmarshal([]byte(secondLastStreamData), &streamResp)
		if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
			usage = streamResp.Usage
			containStreamUsage = true
			if common.DebugEnabled {
				logger.LogDebug(c, fmt.Sprintf("Audio model usage extracted from second last SSE: PromptTokens=%d, CompletionTokens=%d, TotalTokens=%d, InputTokens=%d, OutputTokens=%d",
					usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens,
					usage.InputTokens, usage.OutputTokens))
			}
		}
	}
	// Handle the final buffered chunk; it may adjust id/model/usage or decide
	// that the chunk should not be forwarded at all.
	shouldSendLastResp := true
	if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
		&containStreamUsage, info, &shouldSendLastResp); err != nil {
		logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
	}
	if info.RelayFormat == types.RelayFormatOpenAI {
		if shouldSendLastResp {
			_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
		}
	}
	// Token accounting over all buffered chunks.
	if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
		logger.LogError(c, "error processing tokens: "+err.Error())
	}
	if !containStreamUsage {
		// No usage in the stream: estimate from the emitted text.
		// NOTE(review): each tool call is billed as a flat 7 completion tokens
		// here — confirm this constant matches current billing policy.
		usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
		usage.CompletionTokens += toolCount * 7
	}
	applyUsagePostProcessing(info, usage, common.StringToByteSlice(lastStreamData))
	HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
	return usage, nil
}
  165. func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  166. defer service.CloseResponseBodyGracefully(resp)
  167. var simpleResponse dto.OpenAITextResponse
  168. responseBody, err := io.ReadAll(resp.Body)
  169. if err != nil {
  170. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  171. }
  172. if common.DebugEnabled {
  173. println("upstream response body:", string(responseBody))
  174. }
  175. // Unmarshal to simpleResponse
  176. if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
  177. // 尝试解析为 openrouter enterprise
  178. var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
  179. err = common.Unmarshal(responseBody, &enterpriseResponse)
  180. if err != nil {
  181. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  182. }
  183. if enterpriseResponse.Success {
  184. responseBody = enterpriseResponse.Data
  185. } else {
  186. logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
  187. return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  188. }
  189. }
  190. err = common.Unmarshal(responseBody, &simpleResponse)
  191. if err != nil {
  192. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  193. }
  194. if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
  195. return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
  196. }
  197. for _, choice := range simpleResponse.Choices {
  198. if choice.FinishReason == constant.FinishReasonContentFilter {
  199. common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "openai_finish_reason=content_filter")
  200. break
  201. }
  202. }
  203. forceFormat := false
  204. if info.ChannelSetting.ForceFormat {
  205. forceFormat = true
  206. }
  207. usageModified := false
  208. if simpleResponse.Usage.PromptTokens == 0 {
  209. completionTokens := simpleResponse.Usage.CompletionTokens
  210. if completionTokens == 0 {
  211. for _, choice := range simpleResponse.Choices {
  212. ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.GetReasoningContent(), info.UpstreamModelName)
  213. completionTokens += ctkm
  214. }
  215. }
  216. simpleResponse.Usage = dto.Usage{
  217. PromptTokens: info.GetEstimatePromptTokens(),
  218. CompletionTokens: completionTokens,
  219. TotalTokens: info.GetEstimatePromptTokens() + completionTokens,
  220. }
  221. usageModified = true
  222. }
  223. applyUsagePostProcessing(info, &simpleResponse.Usage, responseBody)
  224. switch info.RelayFormat {
  225. case types.RelayFormatOpenAI:
  226. if usageModified {
  227. var bodyMap map[string]interface{}
  228. err = common.Unmarshal(responseBody, &bodyMap)
  229. if err != nil {
  230. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  231. }
  232. bodyMap["usage"] = simpleResponse.Usage
  233. responseBody, _ = common.Marshal(bodyMap)
  234. }
  235. if forceFormat {
  236. responseBody, err = common.Marshal(simpleResponse)
  237. if err != nil {
  238. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  239. }
  240. } else {
  241. break
  242. }
  243. case types.RelayFormatClaude:
  244. claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
  245. claudeRespStr, err := common.Marshal(claudeResp)
  246. if err != nil {
  247. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  248. }
  249. responseBody = claudeRespStr
  250. case types.RelayFormatGemini:
  251. geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
  252. geminiRespStr, err := common.Marshal(geminiResp)
  253. if err != nil {
  254. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  255. }
  256. responseBody = geminiRespStr
  257. }
  258. service.IOCopyBytesGracefully(c, resp, responseBody)
  259. return &simpleResponse.Usage, nil
  260. }
  261. func streamTTSResponse(c *gin.Context, resp *http.Response) {
  262. c.Writer.WriteHeaderNow()
  263. flusher, ok := c.Writer.(http.Flusher)
  264. if !ok {
  265. logger.LogWarn(c, "streaming not supported")
  266. _, err := io.Copy(c.Writer, resp.Body)
  267. if err != nil {
  268. logger.LogWarn(c, err.Error())
  269. }
  270. return
  271. }
  272. buffer := make([]byte, 4096)
  273. for {
  274. n, err := resp.Body.Read(buffer)
  275. //logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
  276. if n > 0 {
  277. if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
  278. logger.LogError(c, writeErr.Error())
  279. break
  280. }
  281. flusher.Flush()
  282. }
  283. if err != nil {
  284. if err != io.EOF {
  285. logger.LogError(c, err.Error())
  286. }
  287. break
  288. }
  289. }
  290. }
  291. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
  292. if info == nil || info.ClientWs == nil || info.TargetWs == nil {
  293. return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
  294. }
  295. info.IsStream = true
  296. clientConn := info.ClientWs
  297. targetConn := info.TargetWs
  298. clientClosed := make(chan struct{})
  299. targetClosed := make(chan struct{})
  300. sendChan := make(chan []byte, 100)
  301. receiveChan := make(chan []byte, 100)
  302. errChan := make(chan error, 2)
  303. usage := &dto.RealtimeUsage{}
  304. localUsage := &dto.RealtimeUsage{}
  305. sumUsage := &dto.RealtimeUsage{}
  306. gopool.Go(func() {
  307. defer func() {
  308. if r := recover(); r != nil {
  309. errChan <- fmt.Errorf("panic in client reader: %v", r)
  310. }
  311. }()
  312. for {
  313. select {
  314. case <-c.Done():
  315. return
  316. default:
  317. _, message, err := clientConn.ReadMessage()
  318. if err != nil {
  319. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  320. errChan <- fmt.Errorf("error reading from client: %v", err)
  321. }
  322. close(clientClosed)
  323. return
  324. }
  325. realtimeEvent := &dto.RealtimeEvent{}
  326. err = common.Unmarshal(message, realtimeEvent)
  327. if err != nil {
  328. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  329. return
  330. }
  331. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  332. if realtimeEvent.Session != nil {
  333. if realtimeEvent.Session.Tools != nil {
  334. info.RealtimeTools = realtimeEvent.Session.Tools
  335. }
  336. }
  337. }
  338. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  339. if err != nil {
  340. errChan <- fmt.Errorf("error counting text token: %v", err)
  341. return
  342. }
  343. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  344. localUsage.TotalTokens += textToken + audioToken
  345. localUsage.InputTokens += textToken + audioToken
  346. localUsage.InputTokenDetails.TextTokens += textToken
  347. localUsage.InputTokenDetails.AudioTokens += audioToken
  348. err = helper.WssString(c, targetConn, string(message))
  349. if err != nil {
  350. errChan <- fmt.Errorf("error writing to target: %v", err)
  351. return
  352. }
  353. select {
  354. case sendChan <- message:
  355. default:
  356. }
  357. }
  358. }
  359. })
  360. gopool.Go(func() {
  361. defer func() {
  362. if r := recover(); r != nil {
  363. errChan <- fmt.Errorf("panic in target reader: %v", r)
  364. }
  365. }()
  366. for {
  367. select {
  368. case <-c.Done():
  369. return
  370. default:
  371. _, message, err := targetConn.ReadMessage()
  372. if err != nil {
  373. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  374. errChan <- fmt.Errorf("error reading from target: %v", err)
  375. }
  376. close(targetClosed)
  377. return
  378. }
  379. info.SetFirstResponseTime()
  380. realtimeEvent := &dto.RealtimeEvent{}
  381. err = common.Unmarshal(message, realtimeEvent)
  382. if err != nil {
  383. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  384. return
  385. }
  386. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  387. realtimeUsage := realtimeEvent.Response.Usage
  388. if realtimeUsage != nil {
  389. usage.TotalTokens += realtimeUsage.TotalTokens
  390. usage.InputTokens += realtimeUsage.InputTokens
  391. usage.OutputTokens += realtimeUsage.OutputTokens
  392. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  393. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  394. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  395. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  396. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  397. err := preConsumeUsage(c, info, usage, sumUsage)
  398. if err != nil {
  399. errChan <- fmt.Errorf("error consume usage: %v", err)
  400. return
  401. }
  402. // 本次计费完成,清除
  403. usage = &dto.RealtimeUsage{}
  404. localUsage = &dto.RealtimeUsage{}
  405. } else {
  406. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  407. if err != nil {
  408. errChan <- fmt.Errorf("error counting text token: %v", err)
  409. return
  410. }
  411. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  412. localUsage.TotalTokens += textToken + audioToken
  413. info.IsFirstRequest = false
  414. localUsage.InputTokens += textToken + audioToken
  415. localUsage.InputTokenDetails.TextTokens += textToken
  416. localUsage.InputTokenDetails.AudioTokens += audioToken
  417. err = preConsumeUsage(c, info, localUsage, sumUsage)
  418. if err != nil {
  419. errChan <- fmt.Errorf("error consume usage: %v", err)
  420. return
  421. }
  422. // 本次计费完成,清除
  423. localUsage = &dto.RealtimeUsage{}
  424. // print now usage
  425. }
  426. logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  427. logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  428. logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  429. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  430. realtimeSession := realtimeEvent.Session
  431. if realtimeSession != nil {
  432. // update audio format
  433. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  434. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  435. }
  436. } else {
  437. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  438. if err != nil {
  439. errChan <- fmt.Errorf("error counting text token: %v", err)
  440. return
  441. }
  442. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  443. localUsage.TotalTokens += textToken + audioToken
  444. localUsage.OutputTokens += textToken + audioToken
  445. localUsage.OutputTokenDetails.TextTokens += textToken
  446. localUsage.OutputTokenDetails.AudioTokens += audioToken
  447. }
  448. err = helper.WssString(c, clientConn, string(message))
  449. if err != nil {
  450. errChan <- fmt.Errorf("error writing to client: %v", err)
  451. return
  452. }
  453. select {
  454. case receiveChan <- message:
  455. default:
  456. }
  457. }
  458. }
  459. })
  460. select {
  461. case <-clientClosed:
  462. case <-targetClosed:
  463. case err := <-errChan:
  464. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  465. logger.LogError(c, "realtime error: "+err.Error())
  466. case <-c.Done():
  467. }
  468. if usage.TotalTokens != 0 {
  469. _ = preConsumeUsage(c, info, usage, sumUsage)
  470. }
  471. if localUsage.TotalTokens != 0 {
  472. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  473. }
  474. // check usage total tokens, if 0, use local usage
  475. return nil, sumUsage
  476. }
  477. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  478. if usage == nil || totalUsage == nil {
  479. return fmt.Errorf("invalid usage pointer")
  480. }
  481. totalUsage.TotalTokens += usage.TotalTokens
  482. totalUsage.InputTokens += usage.InputTokens
  483. totalUsage.OutputTokens += usage.OutputTokens
  484. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  485. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  486. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  487. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  488. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  489. // clear usage
  490. err := service.PreWssConsumeQuota(ctx, info, usage)
  491. return err
  492. }
  493. func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  494. defer service.CloseResponseBodyGracefully(resp)
  495. responseBody, err := io.ReadAll(resp.Body)
  496. if err != nil {
  497. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  498. }
  499. var usageResp dto.SimpleResponse
  500. err = common.Unmarshal(responseBody, &usageResp)
  501. if err != nil {
  502. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  503. }
  504. // 写入新的 response body
  505. service.IOCopyBytesGracefully(c, resp, responseBody)
  506. // Once we've written to the client, we should not return errors anymore
  507. // because the upstream has already consumed resources and returned content
  508. // We should still perform billing even if parsing fails
  509. // format
  510. if usageResp.InputTokens > 0 {
  511. usageResp.PromptTokens += usageResp.InputTokens
  512. }
  513. if usageResp.OutputTokens > 0 {
  514. usageResp.CompletionTokens += usageResp.OutputTokens
  515. }
  516. if usageResp.InputTokensDetails != nil {
  517. usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
  518. usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
  519. }
  520. applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
  521. return &usageResp.Usage, nil
  522. }
  523. func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
  524. if info == nil || usage == nil {
  525. return
  526. }
  527. switch info.ChannelType {
  528. case constant.ChannelTypeDeepSeek:
  529. if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
  530. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  531. }
  532. case constant.ChannelTypeZhipu_v4:
  533. // 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
  534. if usage.PromptTokensDetails.CachedTokens == 0 {
  535. if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
  536. usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
  537. } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
  538. usage.PromptTokensDetails.CachedTokens = cachedTokens
  539. } else if usage.PromptCacheHitTokens > 0 {
  540. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  541. }
  542. }
  543. case constant.ChannelTypeMoonshot:
  544. // Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
  545. if usage.PromptTokensDetails.CachedTokens == 0 {
  546. if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
  547. usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
  548. } else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
  549. usage.PromptTokensDetails.CachedTokens = cachedTokens
  550. } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
  551. usage.PromptTokensDetails.CachedTokens = cachedTokens
  552. } else if usage.PromptCacheHitTokens > 0 {
  553. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  554. }
  555. }
  556. case constant.ChannelTypeOpenAI:
  557. if usage.PromptTokensDetails.CachedTokens == 0 {
  558. if cachedTokens, ok := extractLlamaCachedTokensFromBody(responseBody); ok {
  559. usage.PromptTokensDetails.CachedTokens = cachedTokens
  560. }
  561. }
  562. }
  563. }
  564. func extractCachedTokensFromBody(body []byte) (int, bool) {
  565. if len(body) == 0 {
  566. return 0, false
  567. }
  568. var payload struct {
  569. Usage struct {
  570. PromptTokensDetails struct {
  571. CachedTokens *int `json:"cached_tokens"`
  572. } `json:"prompt_tokens_details"`
  573. CachedTokens *int `json:"cached_tokens"`
  574. PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
  575. } `json:"usage"`
  576. }
  577. if err := common.Unmarshal(body, &payload); err != nil {
  578. return 0, false
  579. }
  580. if payload.Usage.PromptTokensDetails.CachedTokens != nil {
  581. return *payload.Usage.PromptTokensDetails.CachedTokens, true
  582. }
  583. if payload.Usage.CachedTokens != nil {
  584. return *payload.Usage.CachedTokens, true
  585. }
  586. if payload.Usage.PromptCacheHitTokens != nil {
  587. return *payload.Usage.PromptCacheHitTokens, true
  588. }
  589. return 0, false
  590. }
  591. // extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
  592. // Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
  593. func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
  594. if len(body) == 0 {
  595. return 0, false
  596. }
  597. var payload struct {
  598. Choices []struct {
  599. Usage struct {
  600. CachedTokens *int `json:"cached_tokens"`
  601. } `json:"usage"`
  602. } `json:"choices"`
  603. }
  604. if err := common.Unmarshal(body, &payload); err != nil {
  605. return 0, false
  606. }
  607. // 遍历choices查找cached_tokens
  608. for _, choice := range payload.Choices {
  609. if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
  610. return *choice.Usage.CachedTokens, true
  611. }
  612. }
  613. return 0, false
  614. }
  615. // extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
  616. func extractLlamaCachedTokensFromBody(body []byte) (int, bool) {
  617. if len(body) == 0 {
  618. return 0, false
  619. }
  620. var payload struct {
  621. Timings struct {
  622. CachedTokens *int `json:"cache_n"`
  623. } `json:"timings"`
  624. }
  625. if err := common.Unmarshal(body, &payload); err != nil {
  626. return 0, false
  627. }
  628. if payload.Timings.CachedTokens == nil {
  629. return 0, false
  630. }
  631. return *payload.Timings.CachedTokens, true
  632. }