package gemini

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/require"
)

// Shared fixture values for the usage-accounting tests in this file.
const (
	geminiUsageTestModel               = "gemini-3-flash-preview"
	geminiUsageTestChatPath            = "/v1/chat/completions"
	geminiUsageTestGenerateContentPath = "/v1beta/models/gemini-3-flash-preview:generateContent"
)

// geminiUsageTestGinModeOnce guards the single gin.SetMode call made by the
// tests in this file. gin.SetMode mutates process-wide state, and several of
// these tests run under t.Parallel, so the mode is set once instead of being
// rewritten concurrently by every test.
var geminiUsageTestGinModeOnce sync.Once

// newGeminiUsageTestContext returns a gin test context carrying a POST request
// for path, after making sure gin has been switched to test mode.
func newGeminiUsageTestContext(path string) *gin.Context {
	geminiUsageTestGinModeOnce.Do(func() { gin.SetMode(gin.TestMode) })

	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Request = httptest.NewRequest(http.MethodPost, path, nil)
	return c
}

// newGeminiUsageTestRelayInfo returns a RelayInfo whose origin and upstream
// model names are both the fixture model. RelayFormat is left at its zero
// value; tests that need a specific format set it themselves.
func newGeminiUsageTestRelayInfo() *relaycommon.RelayInfo {
	return &relaycommon.RelayInfo{
		OriginModelName: geminiUsageTestModel,
		ChannelMeta: &relaycommon.ChannelMeta{
			UpstreamModelName: geminiUsageTestModel,
		},
	}
}

// newGeminiUsageTestPayload builds a single-candidate Gemini response with one
// text part from the "model" role and the given usage metadata.
func newGeminiUsageTestPayload(text string, usage dto.GeminiUsageMetadata) dto.GeminiChatResponse {
	return dto.GeminiChatResponse{
		Candidates: []dto.GeminiChatCandidate{
			{
				Content: dto.GeminiChatContent{
					Role: "model",
					Parts: []dto.GeminiPart{
						{Text: text},
					},
				},
			},
		},
		UsageMetadata: usage,
	}
}

// geminiToolUseUsageMetadata is the upstream usage fixture that includes
// tool-use prompt tokens: 151 + 18329 prompt-side, 1089 + 1120 output-side.
func geminiToolUseUsageMetadata() dto.GeminiUsageMetadata {
	return dto.GeminiUsageMetadata{
		PromptTokenCount:        151,
		ToolUsePromptTokenCount: 18329,
		CandidatesTokenCount:    1089,
		ThoughtsTokenCount:      1120,
		TotalTokenCount:         20689,
	}
}

// geminiMissingPromptUsageMetadata is the upstream usage fixture in which both
// prompt-side counters are zero while output and total counters are populated.
func geminiMissingPromptUsageMetadata() dto.GeminiUsageMetadata {
	return dto.GeminiUsageMetadata{
		PromptTokenCount:        0,
		ToolUsePromptTokenCount: 0,
		CandidatesTokenCount:    90,
		ThoughtsTokenCount:      10,
		TotalTokenCount:         110,
	}
}

// marshalGeminiUsageTestPayload serializes payload with common.Marshal and
// fails the test immediately if serialization returns an error.
func marshalGeminiUsageTestPayload(t *testing.T, payload dto.GeminiChatResponse) []byte {
	t.Helper()

	data, err := common.Marshal(payload)
	require.NoError(t, err)
	return data
}

// newGeminiUsageTestHTTPResponse wraps body in an *http.Response whose only
// populated field is Body.
func newGeminiUsageTestHTTPResponse(body []byte) *http.Response {
	return &http.Response{
		Body: io.NopCloser(bytes.NewReader(body)),
	}
}

// geminiUsageTestStreamBody frames chunk as one "data:" line followed by the
// "data: [DONE]" terminator line.
func geminiUsageTestStreamBody(chunk []byte) []byte {
	return []byte("data: " + string(chunk) + "\n" + "data: [DONE]\n")
}

// setGeminiUsageTestStreamingTimeout sets constant.StreamingTimeout to 300 for
// the duration of the test and restores the previous value on cleanup. Because
// it mutates a package-level variable, callers must not use t.Parallel.
func setGeminiUsageTestStreamingTimeout(t *testing.T) {
	t.Helper()

	previous := constant.StreamingTimeout
	constant.StreamingTimeout = 300
	t.Cleanup(func() {
		constant.StreamingTimeout = previous
	})
}

// TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens checks that
// GeminiChatHandler reports 18480 prompt tokens (151 + 18329) and 2209
// completion tokens (1089 + 1120) for the tool-use fixture, i.e. the tool-use
// prompt tokens are counted on the prompt side, not the completion side.
func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
	t.Parallel()

	c := newGeminiUsageTestContext(geminiUsageTestChatPath)
	info := newGeminiUsageTestRelayInfo()
	info.RelayFormat = types.RelayFormatGemini

	body := marshalGeminiUsageTestPayload(t, newGeminiUsageTestPayload("ok", geminiToolUseUsageMetadata()))
	resp := newGeminiUsageTestHTTPResponse(body)

	usage, newAPIError := GeminiChatHandler(c, info, resp)

	require.Nil(t, newAPIError)
	require.NotNil(t, usage)
	require.Equal(t, 18480, usage.PromptTokens)
	require.Equal(t, 2209, usage.CompletionTokens)
	require.Equal(t, 20689, usage.TotalTokens)
	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
}

// TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens runs the
// tool-use fixture through geminiStreamHandler as a single streamed chunk and
// expects the same prompt/completion split as the non-streaming handlers.
// It is intentionally not parallel: it overrides constant.StreamingTimeout.
func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
	c := newGeminiUsageTestContext(geminiUsageTestChatPath)
	setGeminiUsageTestStreamingTimeout(t)
	info := newGeminiUsageTestRelayInfo()

	chunkData := marshalGeminiUsageTestPayload(t, newGeminiUsageTestPayload("partial", geminiToolUseUsageMetadata()))
	resp := newGeminiUsageTestHTTPResponse(geminiUsageTestStreamBody(chunkData))

	usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
		return true
	})

	require.Nil(t, newAPIError)
	require.NotNil(t, usage)
	require.Equal(t, 18480, usage.PromptTokens)
	require.Equal(t, 2209, usage.CompletionTokens)
	require.Equal(t, 20689, usage.TotalTokens)
	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
}

// TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens checks
// that GeminiTextGenerationHandler folds the tool-use prompt tokens into the
// prompt total (18480) and keeps them out of the completion total (2209).
func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
	t.Parallel()

	c := newGeminiUsageTestContext(geminiUsageTestGenerateContentPath)
	info := newGeminiUsageTestRelayInfo()

	body := marshalGeminiUsageTestPayload(t, newGeminiUsageTestPayload("ok", geminiToolUseUsageMetadata()))
	resp := newGeminiUsageTestHTTPResponse(body)

	usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)

	require.Nil(t, newAPIError)
	require.NotNil(t, usage)
	require.Equal(t, 18480, usage.PromptTokens)
	require.Equal(t, 2209, usage.CompletionTokens)
	require.Equal(t, 20689, usage.TotalTokens)
	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
}

// TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing checks
// that when the upstream usage has zero prompt-side counters, GeminiChatHandler
// reports the estimate set on the RelayInfo (20) as prompt tokens, 100
// completion tokens (90 + 10), and the upstream total of 110.
func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
	t.Parallel()

	c := newGeminiUsageTestContext(geminiUsageTestChatPath)
	info := newGeminiUsageTestRelayInfo()
	info.RelayFormat = types.RelayFormatGemini
	info.SetEstimatePromptTokens(20)

	body := marshalGeminiUsageTestPayload(t, newGeminiUsageTestPayload("ok", geminiMissingPromptUsageMetadata()))
	resp := newGeminiUsageTestHTTPResponse(body)

	usage, newAPIError := GeminiChatHandler(c, info, resp)

	require.Nil(t, newAPIError)
	require.NotNil(t, usage)
	require.Equal(t, 20, usage.PromptTokens)
	require.Equal(t, 100, usage.CompletionTokens)
	require.Equal(t, 110, usage.TotalTokens)
}

// TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing runs
// the missing-prompt fixture through geminiStreamHandler as a single streamed
// chunk and expects the estimated prompt tokens (20) to be reported.
// It is intentionally not parallel: it overrides constant.StreamingTimeout.
func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
	c := newGeminiUsageTestContext(geminiUsageTestChatPath)
	setGeminiUsageTestStreamingTimeout(t)
	info := newGeminiUsageTestRelayInfo()
	info.SetEstimatePromptTokens(20)

	chunkData := marshalGeminiUsageTestPayload(t, newGeminiUsageTestPayload("partial", geminiMissingPromptUsageMetadata()))
	resp := newGeminiUsageTestHTTPResponse(geminiUsageTestStreamBody(chunkData))

	usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
		return true
	})

	require.Nil(t, newAPIError)
	require.NotNil(t, usage)
	require.Equal(t, 20, usage.PromptTokens)
	require.Equal(t, 100, usage.CompletionTokens)
	require.Equal(t, 110, usage.TotalTokens)
}

// TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing
// checks that GeminiTextGenerationHandler reports the estimated prompt tokens
// (20) when the upstream usage has zero prompt-side counters.
func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
	t.Parallel()

	c := newGeminiUsageTestContext(geminiUsageTestGenerateContentPath)
	info := newGeminiUsageTestRelayInfo()
	info.SetEstimatePromptTokens(20)

	body := marshalGeminiUsageTestPayload(t, newGeminiUsageTestPayload("ok", geminiMissingPromptUsageMetadata()))
	resp := newGeminiUsageTestHTTPResponse(body)

	usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)

	require.Nil(t, newAPIError)
	require.NotNil(t, usage)
	require.Equal(t, 20, usage.PromptTokens)
	require.Equal(t, 100, usage.CompletionTokens)
	require.Equal(t, 110, usage.TotalTokens)
}