relay-openai.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "math"
  8. "mime/multipart"
  9. "net/http"
  10. "one-api/common"
  11. "one-api/constant"
  12. "one-api/dto"
  13. relaycommon "one-api/relay/common"
  14. "one-api/relay/helper"
  15. "one-api/service"
  16. "os"
  17. "strings"
  18. "github.com/bytedance/gopkg/util/gopool"
  19. "github.com/gin-gonic/gin"
  20. "github.com/gorilla/websocket"
  21. "github.com/pkg/errors"
  22. )
  23. func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  24. if data == "" {
  25. return nil
  26. }
  27. if !forceFormat && !thinkToContent {
  28. return helper.StringData(c, data)
  29. }
  30. var lastStreamResponse dto.ChatCompletionsStreamResponse
  31. if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
  32. return err
  33. }
  34. if !thinkToContent {
  35. return helper.ObjectData(c, lastStreamResponse)
  36. }
  37. hasThinkingContent := false
  38. hasContent := false
  39. var thinkingContent strings.Builder
  40. for _, choice := range lastStreamResponse.Choices {
  41. if len(choice.Delta.GetReasoningContent()) > 0 {
  42. hasThinkingContent = true
  43. thinkingContent.WriteString(choice.Delta.GetReasoningContent())
  44. }
  45. if len(choice.Delta.GetContentString()) > 0 {
  46. hasContent = true
  47. }
  48. }
  49. // Handle think to content conversion
  50. if info.ThinkingContentInfo.IsFirstThinkingContent {
  51. if hasThinkingContent {
  52. response := lastStreamResponse.Copy()
  53. for i := range response.Choices {
  54. // send `think` tag with thinking content
  55. response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
  56. response.Choices[i].Delta.ReasoningContent = nil
  57. response.Choices[i].Delta.Reasoning = nil
  58. }
  59. info.ThinkingContentInfo.IsFirstThinkingContent = false
  60. info.ThinkingContentInfo.HasSentThinkingContent = true
  61. return helper.ObjectData(c, response)
  62. }
  63. }
  64. if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
  65. return helper.ObjectData(c, lastStreamResponse)
  66. }
  67. // Process each choice
  68. for i, choice := range lastStreamResponse.Choices {
  69. // Handle transition from thinking to content
  70. // only send `</think>` tag when previous thinking content has been sent
  71. if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
  72. response := lastStreamResponse.Copy()
  73. for j := range response.Choices {
  74. response.Choices[j].Delta.SetContentString("\n</think>\n")
  75. response.Choices[j].Delta.ReasoningContent = nil
  76. response.Choices[j].Delta.Reasoning = nil
  77. }
  78. info.ThinkingContentInfo.SendLastThinkingContent = true
  79. helper.ObjectData(c, response)
  80. }
  81. // Convert reasoning content to regular content if any
  82. if len(choice.Delta.GetReasoningContent()) > 0 {
  83. lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
  84. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  85. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  86. } else if !hasThinkingContent && !hasContent {
  87. // flush thinking content
  88. lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
  89. lastStreamResponse.Choices[i].Delta.Reasoning = nil
  90. }
  91. }
  92. return helper.ObjectData(c, lastStreamResponse)
  93. }
  94. func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  95. if resp == nil || resp.Body == nil {
  96. common.LogError(c, "invalid response or response body")
  97. return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
  98. }
  99. containStreamUsage := false
  100. var responseId string
  101. var createAt int64 = 0
  102. var systemFingerprint string
  103. model := info.UpstreamModelName
  104. var responseTextBuilder strings.Builder
  105. var toolCount int
  106. var usage = &dto.Usage{}
  107. var streamItems []string // store stream items
  108. var forceFormat bool
  109. var thinkToContent bool
  110. if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
  111. forceFormat = forceFmt
  112. }
  113. if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
  114. thinkToContent = think2Content
  115. }
  116. var (
  117. lastStreamData string
  118. )
  119. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  120. if lastStreamData != "" {
  121. err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
  122. if err != nil {
  123. common.SysError("error handling stream format: " + err.Error())
  124. }
  125. }
  126. lastStreamData = data
  127. streamItems = append(streamItems, data)
  128. return true
  129. })
  130. shouldSendLastResp := true
  131. var lastStreamResponse dto.ChatCompletionsStreamResponse
  132. err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
  133. if err == nil {
  134. responseId = lastStreamResponse.Id
  135. createAt = lastStreamResponse.Created
  136. systemFingerprint = lastStreamResponse.GetSystemFingerprint()
  137. model = lastStreamResponse.Model
  138. if service.ValidUsage(lastStreamResponse.Usage) {
  139. containStreamUsage = true
  140. usage = lastStreamResponse.Usage
  141. if !info.ShouldIncludeUsage {
  142. shouldSendLastResp = false
  143. }
  144. }
  145. for _, choice := range lastStreamResponse.Choices {
  146. if choice.FinishReason != nil {
  147. shouldSendLastResp = true
  148. }
  149. }
  150. }
  151. if shouldSendLastResp {
  152. sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
  153. //err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
  154. }
  155. // 处理token计算
  156. if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
  157. common.SysError("error processing tokens: " + err.Error())
  158. }
  159. if !containStreamUsage {
  160. usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  161. usage.CompletionTokens += toolCount * 7
  162. } else {
  163. if info.ChannelType == common.ChannelTypeDeepSeek {
  164. if usage.PromptCacheHitTokens != 0 {
  165. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  166. }
  167. }
  168. }
  169. handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
  170. return nil, usage
  171. }
  172. func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  173. var simpleResponse dto.OpenAITextResponse
  174. responseBody, err := io.ReadAll(resp.Body)
  175. if err != nil {
  176. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  177. }
  178. err = resp.Body.Close()
  179. if err != nil {
  180. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  181. }
  182. err = common.DecodeJson(responseBody, &simpleResponse)
  183. if err != nil {
  184. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  185. }
  186. if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
  187. return &dto.OpenAIErrorWithStatusCode{
  188. Error: *simpleResponse.Error,
  189. StatusCode: resp.StatusCode,
  190. }, nil
  191. }
  192. switch info.RelayFormat {
  193. case relaycommon.RelayFormatOpenAI:
  194. break
  195. case relaycommon.RelayFormatClaude:
  196. claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
  197. claudeRespStr, err := json.Marshal(claudeResp)
  198. if err != nil {
  199. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  200. }
  201. responseBody = claudeRespStr
  202. }
  203. // Reset response body
  204. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  205. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  206. // And then we will have to send an error response, but in this case, the header has already been set.
  207. // So the httpClient will be confused by the response.
  208. // For example, Postman will report error, and we cannot check the response at all.
  209. for k, v := range resp.Header {
  210. c.Writer.Header().Set(k, v[0])
  211. }
  212. c.Writer.WriteHeader(resp.StatusCode)
  213. _, err = io.Copy(c.Writer, resp.Body)
  214. if err != nil {
  215. //return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  216. common.SysError("error copying response body: " + err.Error())
  217. }
  218. resp.Body.Close()
  219. if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
  220. completionTokens := 0
  221. for _, choice := range simpleResponse.Choices {
  222. ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
  223. completionTokens += ctkm
  224. }
  225. simpleResponse.Usage = dto.Usage{
  226. PromptTokens: info.PromptTokens,
  227. CompletionTokens: completionTokens,
  228. TotalTokens: info.PromptTokens + completionTokens,
  229. }
  230. }
  231. return nil, &simpleResponse.Usage
  232. }
  233. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  234. responseBody, err := io.ReadAll(resp.Body)
  235. if err != nil {
  236. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  237. }
  238. err = resp.Body.Close()
  239. if err != nil {
  240. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  241. }
  242. // Reset response body
  243. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  244. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  245. // And then we will have to send an error response, but in this case, the header has already been set.
  246. // So the httpClient will be confused by the response.
  247. // For example, Postman will report error, and we cannot check the response at all.
  248. for k, v := range resp.Header {
  249. c.Writer.Header().Set(k, v[0])
  250. }
  251. c.Writer.WriteHeader(resp.StatusCode)
  252. _, err = io.Copy(c.Writer, resp.Body)
  253. if err != nil {
  254. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  255. }
  256. err = resp.Body.Close()
  257. if err != nil {
  258. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  259. }
  260. usage := &dto.Usage{}
  261. usage.PromptTokens = info.PromptTokens
  262. usage.TotalTokens = info.PromptTokens
  263. return nil, usage
  264. }
  265. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  266. // count tokens by audio file duration
  267. audioTokens, err := countAudioTokens(c)
  268. if err != nil {
  269. return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
  270. }
  271. responseBody, err := io.ReadAll(resp.Body)
  272. if err != nil {
  273. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  274. }
  275. err = resp.Body.Close()
  276. if err != nil {
  277. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  278. }
  279. // Reset response body
  280. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  281. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  282. // And then we will have to send an error response, but in this case, the header has already been set.
  283. // So the httpClient will be confused by the response.
  284. // For example, Postman will report error, and we cannot check the response at all.
  285. for k, v := range resp.Header {
  286. c.Writer.Header().Set(k, v[0])
  287. }
  288. c.Writer.WriteHeader(resp.StatusCode)
  289. _, err = io.Copy(c.Writer, resp.Body)
  290. if err != nil {
  291. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  292. }
  293. resp.Body.Close()
  294. usage := &dto.Usage{}
  295. usage.PromptTokens = audioTokens
  296. usage.CompletionTokens = 0
  297. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  298. return nil, usage
  299. }
  300. func countAudioTokens(c *gin.Context) (int, error) {
  301. body, err := common.GetRequestBody(c)
  302. if err != nil {
  303. return 0, errors.WithStack(err)
  304. }
  305. var reqBody struct {
  306. File *multipart.FileHeader `form:"file" binding:"required"`
  307. }
  308. c.Request.Body = io.NopCloser(bytes.NewReader(body))
  309. if err = c.ShouldBind(&reqBody); err != nil {
  310. return 0, errors.WithStack(err)
  311. }
  312. reqFp, err := reqBody.File.Open()
  313. if err != nil {
  314. return 0, errors.WithStack(err)
  315. }
  316. tmpFp, err := os.CreateTemp("", "audio-*")
  317. if err != nil {
  318. return 0, errors.WithStack(err)
  319. }
  320. defer os.Remove(tmpFp.Name())
  321. _, err = io.Copy(tmpFp, reqFp)
  322. if err != nil {
  323. return 0, errors.WithStack(err)
  324. }
  325. if err = tmpFp.Close(); err != nil {
  326. return 0, errors.WithStack(err)
  327. }
  328. duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
  329. if err != nil {
  330. return 0, errors.WithStack(err)
  331. }
  332. return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
  333. }
// OpenaiRealtimeHandler proxies a realtime (WebSocket) session between the
// client connection and the upstream target connection, counting tokens and
// pre-charging quota as events flow in both directions.
//
// Two goroutines pump messages: client -> target (billed as input tokens) and
// target -> client (output tokens, or upstream-reported usage on
// "response.done"). The handler blocks until either side closes, an error is
// reported, or the request context is done, then flushes any uncharged usage
// and returns the accumulated sumUsage.
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
	if info == nil || info.ClientWs == nil || info.TargetWs == nil {
		return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
	}
	info.IsStream = true
	clientConn := info.ClientWs
	targetConn := info.TargetWs
	// Closed by the respective reader goroutine when its connection ends.
	clientClosed := make(chan struct{})
	targetClosed := make(chan struct{})
	// Best-effort mirrors of the traffic; sends are non-blocking and drop when
	// full (no consumer is visible in this file — presumably used elsewhere or
	// vestigial; verify before removing).
	sendChan := make(chan []byte, 100)
	receiveChan := make(chan []byte, 100)
	// Buffered for the two reader goroutines so neither blocks on error report.
	errChan := make(chan error, 2)
	// usage: upstream-reported usage for the current response (authoritative).
	// localUsage: locally estimated usage since the last charge.
	// sumUsage: running total of everything already charged.
	usage := &dto.RealtimeUsage{}
	localUsage := &dto.RealtimeUsage{}
	sumUsage := &dto.RealtimeUsage{}
	// Client -> target pump: count input tokens and forward each message.
	gopool.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				errChan <- fmt.Errorf("panic in client reader: %v", r)
			}
		}()
		for {
			select {
			case <-c.Done():
				return
			default:
				_, message, err := clientConn.ReadMessage()
				if err != nil {
					// Normal closure is not an error; anything else is reported.
					if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
						errChan <- fmt.Errorf("error reading from client: %v", err)
					}
					close(clientClosed)
					return
				}
				realtimeEvent := &dto.RealtimeEvent{}
				err = json.Unmarshal(message, realtimeEvent)
				if err != nil {
					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
					return
				}
				// Remember the tools declared in session.update for token counting.
				if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
					if realtimeEvent.Session != nil {
						if realtimeEvent.Session.Tools != nil {
							info.RealtimeTools = realtimeEvent.Session.Tools
						}
					}
				}
				textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
				if err != nil {
					errChan <- fmt.Errorf("error counting text token: %v", err)
					return
				}
				common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
				localUsage.TotalTokens += textToken + audioToken
				localUsage.InputTokens += textToken + audioToken
				localUsage.InputTokenDetails.TextTokens += textToken
				localUsage.InputTokenDetails.AudioTokens += audioToken
				err = helper.WssString(c, targetConn, string(message))
				if err != nil {
					errChan <- fmt.Errorf("error writing to target: %v", err)
					return
				}
				// Non-blocking mirror; dropped when the buffer is full.
				select {
				case sendChan <- message:
				default:
				}
			}
		}
	})
	// Target -> client pump: count output tokens / consume upstream usage and
	// forward each message.
	gopool.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				errChan <- fmt.Errorf("panic in target reader: %v", r)
			}
		}()
		for {
			select {
			case <-c.Done():
				return
			default:
				_, message, err := targetConn.ReadMessage()
				if err != nil {
					if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
						errChan <- fmt.Errorf("error reading from target: %v", err)
					}
					close(targetClosed)
					return
				}
				info.SetFirstResponseTime()
				realtimeEvent := &dto.RealtimeEvent{}
				err = json.Unmarshal(message, realtimeEvent)
				if err != nil {
					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
					return
				}
				if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
					// Prefer the upstream-reported usage when present.
					realtimeUsage := realtimeEvent.Response.Usage
					if realtimeUsage != nil {
						usage.TotalTokens += realtimeUsage.TotalTokens
						usage.InputTokens += realtimeUsage.InputTokens
						usage.OutputTokens += realtimeUsage.OutputTokens
						usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
						usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
						usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
						usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
						usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
						err := preConsumeUsage(c, info, usage, sumUsage)
						if err != nil {
							errChan <- fmt.Errorf("error consume usage: %v", err)
							return
						}
						// Billing for this interval is done; reset the counters.
						usage = &dto.RealtimeUsage{}
						localUsage = &dto.RealtimeUsage{}
					} else {
						// No upstream usage on response.done: fall back to the
						// local estimate and charge that instead.
						textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
						if err != nil {
							errChan <- fmt.Errorf("error counting text token: %v", err)
							return
						}
						common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
						localUsage.TotalTokens += textToken + audioToken
						info.IsFirstRequest = false
						localUsage.InputTokens += textToken + audioToken
						localUsage.InputTokenDetails.TextTokens += textToken
						localUsage.InputTokenDetails.AudioTokens += audioToken
						err = preConsumeUsage(c, info, localUsage, sumUsage)
						if err != nil {
							errChan <- fmt.Errorf("error consume usage: %v", err)
							return
						}
						// Billing for this interval is done; reset the counter.
						localUsage = &dto.RealtimeUsage{}
						// print now usage
					}
					//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
					//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
					//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
				} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
					realtimeSession := realtimeEvent.Session
					if realtimeSession != nil {
						// update audio format
						info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
						info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
					}
				} else {
					// Any other server event: estimate and accumulate as output.
					textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
					if err != nil {
						errChan <- fmt.Errorf("error counting text token: %v", err)
						return
					}
					common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
					localUsage.TotalTokens += textToken + audioToken
					localUsage.OutputTokens += textToken + audioToken
					localUsage.OutputTokenDetails.TextTokens += textToken
					localUsage.OutputTokenDetails.AudioTokens += audioToken
				}
				err = helper.WssString(c, clientConn, string(message))
				if err != nil {
					errChan <- fmt.Errorf("error writing to client: %v", err)
					return
				}
				// Non-blocking mirror; dropped when the buffer is full.
				select {
				case receiveChan <- message:
				default:
				}
			}
		}
	})
	// Wait for either side to close, an error, or request cancellation.
	select {
	case <-clientClosed:
	case <-targetClosed:
	case err := <-errChan:
		//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
		common.LogError(c, "realtime error: "+err.Error())
	case <-c.Done():
	}
	// Flush any usage accumulated since the last charge.
	if usage.TotalTokens != 0 {
		_ = preConsumeUsage(c, info, usage, sumUsage)
	}
	if localUsage.TotalTokens != 0 {
		_ = preConsumeUsage(c, info, localUsage, sumUsage)
	}
	// check usage total tokens, if 0, use local usage
	return nil, sumUsage
}
  520. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  521. if usage == nil || totalUsage == nil {
  522. return fmt.Errorf("invalid usage pointer")
  523. }
  524. totalUsage.TotalTokens += usage.TotalTokens
  525. totalUsage.InputTokens += usage.InputTokens
  526. totalUsage.OutputTokens += usage.OutputTokens
  527. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  528. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  529. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  530. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  531. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  532. // clear usage
  533. err := service.PreWssConsumeQuota(ctx, info, usage)
  534. return err
  535. }
  536. func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  537. responseBody, err := io.ReadAll(resp.Body)
  538. if err != nil {
  539. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  540. }
  541. err = resp.Body.Close()
  542. if err != nil {
  543. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  544. }
  545. // Reset response body
  546. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  547. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  548. // And then we will have to send an error response, but in this case, the header has already been set.
  549. // So the httpClient will be confused by the response.
  550. // For example, Postman will report error, and we cannot check the response at all.
  551. for k, v := range resp.Header {
  552. c.Writer.Header().Set(k, v[0])
  553. }
  554. // reset content length
  555. c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
  556. c.Writer.WriteHeader(resp.StatusCode)
  557. _, err = io.Copy(c.Writer, resp.Body)
  558. if err != nil {
  559. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  560. }
  561. err = resp.Body.Close()
  562. if err != nil {
  563. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  564. }
  565. var usageResp dto.SimpleResponse
  566. err = json.Unmarshal(responseBody, &usageResp)
  567. if err != nil {
  568. return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
  569. }
  570. // format
  571. if usageResp.InputTokens > 0 {
  572. usageResp.PromptTokens += usageResp.InputTokens
  573. }
  574. if usageResp.OutputTokens > 0 {
  575. usageResp.CompletionTokens += usageResp.OutputTokens
  576. }
  577. if usageResp.InputTokensDetails != nil {
  578. usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
  579. usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
  580. }
  581. return nil, &usageResp.Usage
  582. }
  583. func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  584. // read response body
  585. var responsesResponse dto.OpenAIResponsesResponse
  586. responseBody, err := io.ReadAll(resp.Body)
  587. if err != nil {
  588. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  589. }
  590. err = resp.Body.Close()
  591. if err != nil {
  592. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  593. }
  594. err = common.DecodeJson(responseBody, &responsesResponse)
  595. if err != nil {
  596. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  597. }
  598. if responsesResponse.Error != nil {
  599. return &dto.OpenAIErrorWithStatusCode{
  600. Error: dto.OpenAIError{
  601. Message: responsesResponse.Error.Message,
  602. Type: "openai_error",
  603. Code: responsesResponse.Error.Code,
  604. },
  605. StatusCode: resp.StatusCode,
  606. }, nil
  607. }
  608. // reset response body
  609. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  610. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  611. // And then we will have to send an error response, but in this case, the header has already been set.
  612. // So the httpClient will be confused by the response.
  613. // For example, Postman will report error, and we cannot check the response at all.
  614. for k, v := range resp.Header {
  615. c.Writer.Header().Set(k, v[0])
  616. }
  617. c.Writer.WriteHeader(resp.StatusCode)
  618. // copy response body
  619. _, err = io.Copy(c.Writer, resp.Body)
  620. if err != nil {
  621. common.SysError("error copying response body: " + err.Error())
  622. }
  623. resp.Body.Close()
  624. // compute usage
  625. usage := dto.Usage{}
  626. usage.PromptTokens = responsesResponse.Usage.InputTokens
  627. usage.CompletionTokens = responsesResponse.Usage.OutputTokens
  628. usage.TotalTokens = responsesResponse.Usage.TotalTokens
  629. return nil, &usage
  630. }
  631. func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  632. if resp == nil || resp.Body == nil {
  633. common.LogError(c, "invalid response or response body")
  634. return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
  635. }
  636. var usage = &dto.Usage{}
  637. var responseTextBuilder strings.Builder
  638. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  639. // 检查当前数据是否包含 completed 状态和 usage 信息
  640. var streamResponse dto.ResponsesStreamResponse
  641. if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
  642. sendResponsesStreamData(c, streamResponse, data)
  643. switch streamResponse.Type {
  644. case "response.completed":
  645. usage.PromptTokens = streamResponse.Response.Usage.InputTokens
  646. usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
  647. usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
  648. case "response.output_text.delta":
  649. // 处理输出文本
  650. responseTextBuilder.WriteString(streamResponse.Delta)
  651. }
  652. }
  653. return true
  654. })
  655. helper.Done(c)
  656. if usage.CompletionTokens == 0 {
  657. // 计算输出文本的 token 数量
  658. tempStr := responseTextBuilder.String()
  659. if len(tempStr) > 0 {
  660. // 非正常结束,使用输出文本的 token 数量
  661. completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
  662. usage.CompletionTokens = completionTokens
  663. }
  664. }
  665. return nil, usage
  666. }
  667. func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
  668. if data == "" {
  669. return
  670. }
  671. helper.ResponseChunkData(c, streamResponse, data)
  672. }