audio.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package openai
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "math"
  7. "net/http"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/logger"
  12. relaycommon "github.com/QuantumNous/new-api/relay/common"
  13. "github.com/QuantumNous/new-api/relay/helper"
  14. "github.com/QuantumNous/new-api/service"
  15. "github.com/QuantumNous/new-api/types"
  16. "github.com/gin-gonic/gin"
  17. )
  18. func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
  19. // the status code has been judged before, if there is a body reading failure,
  20. // it should be regarded as a non-recoverable error, so it should not return err for external retry.
  21. // Analogous to nginx's load balancing, it will only retry if it can't be requested or
  22. // if the upstream returns a specific status code, once the upstream has already written the header,
  23. // the subsequent failure of the response body should be regarded as a non-recoverable error,
  24. // and can be terminated directly.
  25. defer service.CloseResponseBodyGracefully(resp)
  26. usage := &dto.Usage{}
  27. usage.PromptTokens = info.GetEstimatePromptTokens()
  28. usage.TotalTokens = info.GetEstimatePromptTokens()
  29. for k, v := range resp.Header {
  30. c.Writer.Header().Set(k, v[0])
  31. }
  32. c.Writer.WriteHeader(resp.StatusCode)
  33. if info.IsStream {
  34. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  35. if service.SundaySearch(data, "usage") {
  36. var simpleResponse dto.SimpleResponse
  37. err := common.Unmarshal([]byte(data), &simpleResponse)
  38. if err != nil {
  39. logger.LogError(c, err.Error())
  40. }
  41. if simpleResponse.Usage.TotalTokens != 0 {
  42. usage.PromptTokens = simpleResponse.Usage.InputTokens
  43. usage.CompletionTokens = simpleResponse.OutputTokens
  44. usage.TotalTokens = simpleResponse.TotalTokens
  45. }
  46. }
  47. _ = helper.StringData(c, data)
  48. return true
  49. })
  50. } else {
  51. common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
  52. // 读取响应体到缓冲区
  53. bodyBytes, err := io.ReadAll(resp.Body)
  54. if err != nil {
  55. logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err))
  56. c.Writer.WriteHeaderNow()
  57. return usage
  58. }
  59. // 写入响应到客户端
  60. c.Writer.WriteHeaderNow()
  61. _, err = c.Writer.Write(bodyBytes)
  62. if err != nil {
  63. logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err))
  64. }
  65. // 计算音频时长并更新 usage
  66. audioFormat := "mp3" // 默认格式
  67. if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" {
  68. audioFormat = audioReq.ResponseFormat
  69. }
  70. var duration float64
  71. var durationErr error
  72. if audioFormat == "pcm" {
  73. // PCM 格式没有文件头,根据 OpenAI TTS 的 PCM 参数计算时长
  74. // 采样率: 24000 Hz, 位深度: 16-bit (2 bytes), 声道数: 1
  75. const sampleRate = 24000
  76. const bytesPerSample = 2
  77. const channels = 1
  78. duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels)
  79. } else {
  80. ext := "." + audioFormat
  81. reader := bytes.NewReader(bodyBytes)
  82. duration, durationErr = common.GetAudioDuration(c.Request.Context(), reader, ext)
  83. }
  84. usage.PromptTokensDetails.TextTokens = usage.PromptTokens
  85. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  86. if durationErr != nil {
  87. logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr))
  88. // 如果无法获取时长,则设置保底的 CompletionTokens,根据body大小计算
  89. sizeInKB := float64(len(bodyBytes)) / 1000.0
  90. estimatedTokens := int(math.Ceil(sizeInKB)) // 粗略估算每KB约等于1 token
  91. usage.CompletionTokens = estimatedTokens
  92. usage.CompletionTokenDetails.AudioTokens = estimatedTokens
  93. } else if duration > 0 {
  94. // 计算 token: ceil(duration) / 60.0 * 1000,即每分钟 1000 tokens
  95. completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000))
  96. usage.CompletionTokens = completionTokens
  97. usage.CompletionTokenDetails.AudioTokens = completionTokens
  98. }
  99. }
  100. return usage
  101. }
  102. func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
  103. defer service.CloseResponseBodyGracefully(resp)
  104. responseBody, err := io.ReadAll(resp.Body)
  105. if err != nil {
  106. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  107. }
  108. // 写入新的 response body
  109. service.IOCopyBytesGracefully(c, resp, responseBody)
  110. var responseData struct {
  111. Usage *dto.Usage `json:"usage"`
  112. }
  113. if err := common.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
  114. if responseData.Usage.TotalTokens > 0 {
  115. usage := responseData.Usage
  116. if usage.PromptTokens == 0 {
  117. usage.PromptTokens = usage.InputTokens
  118. }
  119. if usage.CompletionTokens == 0 {
  120. usage.CompletionTokens = usage.OutputTokens
  121. }
  122. return nil, usage
  123. }
  124. }
  125. usage := &dto.Usage{}
  126. usage.PromptTokens = info.GetEstimatePromptTokens()
  127. usage.CompletionTokens = 0
  128. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  129. return nil, usage
  130. }