package openai

import (
	"bytes"
	"fmt"
	"io"
	"math"
	"net/http"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/relay/helper"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)

// OpenaiTTSHandler relays an upstream text-to-speech response to the client
// and returns the usage computed for it.
//
// The upstream status code and the first value of each upstream response
// header are copied onto the client response. What happens next depends on
// info.IsStream:
//
//   - Streaming: every chunk delivered by helper.StreamScannerHandler is passed
//     to helper.StringData. When service.SundaySearch reports "usage" in a
//     chunk and the chunk decodes to a non-zero total, the returned usage is
//     overwritten with the decoded values.
//   - Non-streaming: the body is fully buffered and written to the client.
//     Completion tokens are then derived from the audio duration, or from the
//     body size when the duration cannot be determined.
//
// The function deliberately has no error return. The status code has been
// judged before this handler runs, so a failure while reading the body is
// treated as non-recoverable and must not trigger an external retry. This is
// analogous to nginx load balancing: a retry only happens when the upstream
// cannot be reached or answers with a specific status code; once the upstream
// has written its header, a later body failure simply terminates the request.
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
	defer service.CloseResponseBodyGracefully(resp)

	// Start from the estimated prompt tokens; the branches below refine this.
	usage := &dto.Usage{}
	usage.PromptTokens = info.GetEstimatePromptTokens()
	usage.TotalTokens = info.GetEstimatePromptTokens()

	for name, values := range resp.Header {
		// An entry with no values would make values[0] panic; skip it.
		if len(values) == 0 {
			continue
		}
		// Only the first value of each header is forwarded.
		c.Writer.Header().Set(name, values[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	if info.IsStream {
		helper.StreamScannerHandler(c, resp, info, func(data string) bool {
			if service.SundaySearch(data, "usage") {
				var simpleResponse dto.SimpleResponse
				if err := common.Unmarshal([]byte(data), &simpleResponse); err != nil {
					logger.LogError(c, err.Error())
				}
				// A zero total is treated as "no usage reported" and keeps
				// the values collected so far.
				if simpleResponse.Usage.TotalTokens != 0 {
					usage.PromptTokens = simpleResponse.Usage.InputTokens
					usage.CompletionTokens = simpleResponse.OutputTokens
					usage.TotalTokens = simpleResponse.TotalTokens
				}
			}
			// NOTE(review): the error from StringData is discarded on purpose
			// here, presumably because nothing can be done about a failed
			// write once streaming has begun - confirm.
			_ = helper.StringData(c, data)
			return true
		})
		return usage
	}

	common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)

	// Read the whole response body into a buffer.
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err))
		c.Writer.WriteHeaderNow()
		return usage
	}

	// Write the response to the client.
	c.Writer.WriteHeaderNow()
	if _, err = c.Writer.Write(bodyBytes); err != nil {
		logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err))
	}

	// Work out the audio duration and update usage accordingly.
	audioFormat := "mp3" // default format
	if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" {
		audioFormat = audioReq.ResponseFormat
	}

	var (
		duration    float64
		durationErr error
	)
	if audioFormat == "pcm" {
		// PCM has no file header, so the duration is computed from the
		// OpenAI TTS PCM parameters:
		// sample rate 24000 Hz, bit depth 16-bit (2 bytes), 1 channel.
		const (
			sampleRate     = 24000
			bytesPerSample = 2
			channels       = 1
		)
		duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels)
	} else {
		duration, durationErr = common.GetAudioDuration(c.Request.Context(), bytes.NewReader(bodyBytes), "."+audioFormat)
	}

	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
	switch {
	case durationErr != nil:
		logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr))
		// The duration is unavailable, so fall back to a floor for
		// CompletionTokens based on the body size: roughly 1 token per KB
		// (1000 bytes), rounded up.
		sizeInKB := float64(len(bodyBytes)) / 1000.0
		estimatedTokens := int(math.Ceil(sizeInKB))
		usage.CompletionTokens = estimatedTokens
		usage.CompletionTokenDetails.AudioTokens = estimatedTokens
	case duration > 0:
		// Tokens = ceil(duration) / 60.0 * 1000, i.e. 1000 tokens per minute.
		completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000))
		usage.CompletionTokens = completionTokens
		usage.CompletionTokenDetails.AudioTokens = completionTokens
	}
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens

	return usage
}

// OpenaiSTTHandler relays an upstream speech-to-text response to the client
// and returns the usage for it.
//
// The body is read in full and handed to service.IOCopyBytesGracefully
// together with c and resp. If the body decodes as JSON carrying a "usage"
// object whose TotalTokens is positive, that usage is returned, with
// PromptTokens and CompletionTokens filled from InputTokens and OutputTokens
// when they are zero. Otherwise the usage falls back to the estimated prompt
// tokens with zero completion tokens.
//
// A non-nil error is returned only when reading the upstream body fails; in
// that case the usage is nil. The responseFormat parameter is not used by
// this function.
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
	defer service.CloseResponseBodyGracefully(resp)

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
	}

	// Write the buffered bytes out as the new response body.
	service.IOCopyBytesGracefully(c, resp, responseBody)

	var responseData struct {
		Usage *dto.Usage `json:"usage"`
	}
	// A decode failure is not an error here: non-JSON bodies simply fall
	// through to the estimated usage below.
	if err := common.Unmarshal(responseBody, &responseData); err == nil &&
		responseData.Usage != nil && responseData.Usage.TotalTokens > 0 {
		usage := responseData.Usage
		if usage.PromptTokens == 0 {
			usage.PromptTokens = usage.InputTokens
		}
		if usage.CompletionTokens == 0 {
			usage.CompletionTokens = usage.OutputTokens
		}
		return nil, usage
	}

	// No usable upstream usage: bill the estimated prompt tokens only.
	promptTokens := info.GetEstimatePromptTokens()
	return nil, &dto.Usage{
		PromptTokens: promptTokens,
		TotalTokens:  promptTokens,
	}
}