Browse Source

feat: use audio token usage if return

feitianbubu 6 months ago
parent
commit
f7ae3621f4
1 changed files with 21 additions and 5 deletions
  1. 21 5
      relay/channel/openai/relay-openai.go

+ 21 - 5
relay/channel/openai/relay-openai.go

@@ -2,6 +2,7 @@ package openai
 
 import (
 	"bytes"
+	"encoding/json"
 	"fmt"
 	"io"
 	"math"
@@ -280,11 +281,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
 	defer service.CloseResponseBodyGracefully(resp)
 
-	// count tokens by audio file duration
-	audioTokens, err := countAudioTokens(c)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
-	}
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
@@ -292,6 +288,26 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	// 写入新的 response body
 	service.IOCopyBytesGracefully(c, resp, responseBody)
 
+	var responseData struct {
+		Usage *dto.Usage `json:"usage"`
+	}
+	if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
+		if responseData.Usage.TotalTokens > 0 {
+			usage := responseData.Usage
+			if usage.PromptTokens == 0 {
+				usage.PromptTokens = usage.InputTokens
+			}
+			if usage.CompletionTokens == 0 {
+				usage.CompletionTokens = usage.OutputTokens
+			}
+			return nil, usage
+		}
+	}
+
+	audioTokens, err := countAudioTokens(c)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
+	}
 	usage := &dto.Usage{}
 	usage.PromptTokens = audioTokens
 	usage.CompletionTokens = 0