
fix: unify usage mapping and include toolUsePromptTokenCount in input tokens

Seefs 2 weeks ago
parent
commit
4360393dc1

+ 8 - 6
dto/gemini.go

@@ -453,12 +453,14 @@ type GeminiChatResponse struct {
 }
 
 type GeminiUsageMetadata struct {
-	PromptTokenCount        int                         `json:"promptTokenCount"`
-	CandidatesTokenCount    int                         `json:"candidatesTokenCount"`
-	TotalTokenCount         int                         `json:"totalTokenCount"`
-	ThoughtsTokenCount      int                         `json:"thoughtsTokenCount"`
-	CachedContentTokenCount int                         `json:"cachedContentTokenCount"`
-	PromptTokensDetails     []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+	PromptTokenCount           int                         `json:"promptTokenCount"`
+	ToolUsePromptTokenCount    int                         `json:"toolUsePromptTokenCount"`
+	CandidatesTokenCount       int                         `json:"candidatesTokenCount"`
+	TotalTokenCount            int                         `json:"totalTokenCount"`
+	ThoughtsTokenCount         int                         `json:"thoughtsTokenCount"`
+	CachedContentTokenCount    int                         `json:"cachedContentTokenCount"`
+	PromptTokensDetails        []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+	ToolUsePromptTokensDetails []GeminiPromptTokensDetails `json:"toolUsePromptTokensDetails"`
 }
 
 type GeminiPromptTokensDetails struct {

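For context (not part of the commit): the two new fields mirror the Gemini API's usageMetadata, which reports tool-use prompt tokens separately from the regular prompt count. A minimal, self-contained Go sketch of how such a payload unmarshals, using a local mirror of the struct and sample values taken from the new tests:

package main

import (
	"encoding/json"
	"fmt"
)

// Local mirror of dto.GeminiUsageMetadata; json tags match the diff above.
type usageMetadata struct {
	PromptTokenCount        int `json:"promptTokenCount"`
	ToolUsePromptTokenCount int `json:"toolUsePromptTokenCount"`
	CandidatesTokenCount    int `json:"candidatesTokenCount"`
	ThoughtsTokenCount      int `json:"thoughtsTokenCount"`
	TotalTokenCount         int `json:"totalTokenCount"`
}

func main() {
	// Hypothetical upstream response fragment.
	raw := []byte(`{"promptTokenCount":151,"toolUsePromptTokenCount":18329,"candidatesTokenCount":1089,"thoughtsTokenCount":1120,"totalTokenCount":20689}`)

	var m usageMetadata
	if err := json.Unmarshal(raw, &m); err != nil {
		panic(err)
	}
	// Input tokens now include the tool-use portion: 151 + 18329 = 18480.
	fmt.Println(m.PromptTokenCount + m.ToolUsePromptTokenCount)
}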
+ 1 - 16
relay/channel/gemini/relay-gemini-native.go

@@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
 	}
 
 	// Compute usage (based on UsageMetadata)
-	usage := dto.Usage{
-		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
-		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
-		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
-	}
-
-	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-	usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-
-	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-		if detail.Modality == "AUDIO" {
-			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-		} else if detail.Modality == "TEXT" {
-			usage.PromptTokensDetails.TextTokens = detail.TokenCount
-		}
-	}
+	usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 	service.IOCopyBytesGracefully(c, resp, responseBody)
 

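The inlined mapping above now delegates to the shared buildUsageFromGeminiMetadata helper (added in relay-gemini.go below), passing info.GetEstimatePromptTokens() as a fallback. With the figures used in the new tests, the mapping works out as follows (a worked example, not part of the commit):

	promptTokens     = promptTokenCount + toolUsePromptTokenCount = 151 + 18329 = 18480
	completionTokens = candidatesTokenCount + thoughtsTokenCount  = 1089 + 1120 =  2209
	totalTokens      = 20689 (reported upstream; 18480 + 2209 = 20689, so the split is consistent)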
+ 44 - 49
relay/channel/gemini/relay-gemini.go

@@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
 	}
 }
 
+func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
+	promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
+	if promptTokens <= 0 && fallbackPromptTokens > 0 {
+		promptTokens = fallbackPromptTokens
+	}
+
+	usage := dto.Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
+		TotalTokens:      metadata.TotalTokenCount,
+	}
+	usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
+	usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
+
+	for _, detail := range metadata.PromptTokensDetails {
+		if detail.Modality == "AUDIO" {
+			usage.PromptTokensDetails.AudioTokens += detail.TokenCount
+		} else if detail.Modality == "TEXT" {
+			usage.PromptTokensDetails.TextTokens += detail.TokenCount
+		}
+	}
+	for _, detail := range metadata.ToolUsePromptTokensDetails {
+		if detail.Modality == "AUDIO" {
+			usage.PromptTokensDetails.AudioTokens += detail.TokenCount
+		} else if detail.Modality == "TEXT" {
+			usage.PromptTokensDetails.TextTokens += detail.TokenCount
+		}
+	}
+
+	if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
+		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+	}
+
+	if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
+		usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	}
+
+	return usage
+}
+
 func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      helper.GetResponseID(c),
@@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 
 		// Update usage statistics
 		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
-			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
-			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
-			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
-			usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-				if detail.Modality == "AUDIO" {
-					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-				} else if detail.Modality == "TEXT" {
-					usage.PromptTokensDetails.TextTokens = detail.TokenCount
-				}
-			}
+			mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
+			*usage = mappedUsage
 		}
 
 		return callback(data, &geminiResponse)
@@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 		}
 	}
 
-	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
-	if usage.TotalTokens > 0 {
-		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-	}
-
 	if usage.CompletionTokens <= 0 {
 		if info.ReceivedResponseCount > 0 {
 			usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
@@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		usage := dto.Usage{
-			PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
-		}
-		usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-		usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-		for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-			if detail.Modality == "AUDIO" {
-				usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-			} else if detail.Modality == "TEXT" {
-				usage.PromptTokensDetails.TextTokens = detail.TokenCount
-			}
-		}
-		if usage.PromptTokens <= 0 {
-			usage.PromptTokens = info.GetEstimatePromptTokens()
-		}
+		usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 		var newAPIError *types.NewAPIError
 		if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
@@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
 	fullTextResponse.Model = info.UpstreamModelName
-	usage := dto.Usage{
-		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
-		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
-		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
-	}
-
-	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-	usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
-	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-		if detail.Modality == "AUDIO" {
-			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-		} else if detail.Modality == "TEXT" {
-			usage.PromptTokensDetails.TextTokens = detail.TokenCount
-		}
-	}
+	usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 	fullTextResponse.Usage = usage
 

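A hypothetical in-package snippet (not part of the commit, imports of fmt and dto assumed) illustrating the fallback branch of buildUsageFromGeminiMetadata with the same values the missing-prompt tests below assert; the helper is unexported, so this would have to live in package gemini:

func ExampleBuildUsageFallback() {
	// Upstream omitted prompt counts entirely; only completion-side counts arrive.
	meta := dto.GeminiUsageMetadata{
		CandidatesTokenCount: 90,
		ThoughtsTokenCount:   10,
		TotalTokenCount:      110,
	}
	usage := buildUsageFromGeminiMetadata(meta, 20)
	fmt.Println(usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
	// Output: 20 100 110
}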
+ 333 - 0
relay/channel/gemini/relay_gemini_usage_test.go

@@ -0,0 +1,333 @@
+package gemini
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/types"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	info := &relaycommon.RelayInfo{
+		RelayFormat:     types.RelayFormatGemini,
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiChatHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldStreamingTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 300
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldStreamingTimeout
+	})
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	chunk := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "partial"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	chunkData, err := common.Marshal(chunk)
+	require.NoError(t, err)
+
+	streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(streamBody)),
+	}
+
+	usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
+		return true
+	})
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	info := &relaycommon.RelayInfo{
+		RelayFormat:     types.RelayFormatGemini,
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiChatHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}
+
+func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldStreamingTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 300
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldStreamingTimeout
+	})
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	chunk := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "partial"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	chunkData, err := common.Marshal(chunk)
+	require.NoError(t, err)
+
+	streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(streamBody)),
+	}
+
+	usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
+		return true
+	})
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}
+
+func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}
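
To run just the new tests locally, something like the following should work (the -run pattern is illustrative; all six test names contain one of the two substrings):

	go test ./relay/channel/gemini -run 'ToolUsePromptTokens|EstimatedPromptTokens' -v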