| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- package openai
- import (
- "bytes"
- "fmt"
- "io"
- "math"
- "net/http"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/logger"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/relay/helper"
- "github.com/QuantumNous/new-api/service"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- )
- func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
- // the status code has been judged before, if there is a body reading failure,
- // it should be regarded as a non-recoverable error, so it should not return err for external retry.
- // Analogous to nginx's load balancing, it will only retry if it can't be requested or
- // if the upstream returns a specific status code, once the upstream has already written the header,
- // the subsequent failure of the response body should be regarded as a non-recoverable error,
- // and can be terminated directly.
- defer service.CloseResponseBodyGracefully(resp)
- usage := &dto.Usage{}
- usage.PromptTokens = info.GetEstimatePromptTokens()
- usage.TotalTokens = info.GetEstimatePromptTokens()
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- if info.IsStream {
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- if service.SundaySearch(data, "usage") {
- var simpleResponse dto.SimpleResponse
- err := common.Unmarshal([]byte(data), &simpleResponse)
- if err != nil {
- logger.LogError(c, err.Error())
- }
- if simpleResponse.Usage.TotalTokens != 0 {
- usage.PromptTokens = simpleResponse.Usage.InputTokens
- usage.CompletionTokens = simpleResponse.OutputTokens
- usage.TotalTokens = simpleResponse.TotalTokens
- }
- }
- _ = helper.StringData(c, data)
- return true
- })
- } else {
- common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
- // 读取响应体到缓冲区
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err))
- c.Writer.WriteHeaderNow()
- return usage
- }
- // 写入响应到客户端
- c.Writer.WriteHeaderNow()
- _, err = c.Writer.Write(bodyBytes)
- if err != nil {
- logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err))
- }
- // 计算音频时长并更新 usage
- audioFormat := "mp3" // 默认格式
- if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" {
- audioFormat = audioReq.ResponseFormat
- }
- var duration float64
- var durationErr error
- if audioFormat == "pcm" {
- // PCM 格式没有文件头,根据 OpenAI TTS 的 PCM 参数计算时长
- // 采样率: 24000 Hz, 位深度: 16-bit (2 bytes), 声道数: 1
- const sampleRate = 24000
- const bytesPerSample = 2
- const channels = 1
- duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels)
- } else {
- ext := "." + audioFormat
- reader := bytes.NewReader(bodyBytes)
- duration, durationErr = common.GetAudioDuration(c.Request.Context(), reader, ext)
- }
- usage.PromptTokensDetails.TextTokens = usage.PromptTokens
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- if durationErr != nil {
- logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr))
- // 如果无法获取时长,则设置保底的 CompletionTokens,根据body大小计算
- sizeInKB := float64(len(bodyBytes)) / 1000.0
- estimatedTokens := int(math.Ceil(sizeInKB)) // 粗略估算每KB约等于1 token
- usage.CompletionTokens = estimatedTokens
- usage.CompletionTokenDetails.AudioTokens = estimatedTokens
- } else if duration > 0 {
- // 计算 token: ceil(duration) / 60.0 * 1000,即每分钟 1000 tokens
- completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000))
- usage.CompletionTokens = completionTokens
- usage.CompletionTokenDetails.AudioTokens = completionTokens
- }
- }
- return usage
- }
- func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
- defer service.CloseResponseBodyGracefully(resp)
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
- }
- // 写入新的 response body
- service.IOCopyBytesGracefully(c, resp, responseBody)
- var responseData struct {
- Usage *dto.Usage `json:"usage"`
- }
- if err := common.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
- if responseData.Usage.TotalTokens > 0 {
- usage := responseData.Usage
- if usage.PromptTokens == 0 {
- usage.PromptTokens = usage.InputTokens
- }
- if usage.CompletionTokens == 0 {
- usage.CompletionTokens = usage.OutputTokens
- }
- return nil, usage
- }
- }
- usage := &dto.Usage{}
- usage.PromptTokens = info.GetEstimatePromptTokens()
- usage.CompletionTokens = 0
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- return nil, usage
- }
|