|
@@ -0,0 +1,333 @@
|
|
|
|
|
+package gemini
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "bytes"
|
|
|
|
|
+ "io"
|
|
|
|
|
+ "net/http"
|
|
|
|
|
+ "net/http/httptest"
|
|
|
|
|
+ "testing"
|
|
|
|
|
+
|
|
|
|
|
+ "github.com/QuantumNous/new-api/common"
|
|
|
|
|
+ "github.com/QuantumNous/new-api/constant"
|
|
|
|
|
+ "github.com/QuantumNous/new-api/dto"
|
|
|
|
|
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
|
|
+ "github.com/QuantumNous/new-api/types"
|
|
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
|
|
+ "github.com/stretchr/testify/require"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
|
|
|
|
|
+ t.Parallel()
|
|
|
|
|
+
|
|
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
|
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
|
|
|
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
|
|
|
+
|
|
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
|
|
+ RelayFormat: types.RelayFormatGemini,
|
|
|
|
|
+ OriginModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
|
|
+ UpstreamModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ payload := dto.GeminiChatResponse{
|
|
|
|
|
+ Candidates: []dto.GeminiChatCandidate{
|
|
|
|
|
+ {
|
|
|
|
|
+ Content: dto.GeminiChatContent{
|
|
|
|
|
+ Role: "model",
|
|
|
|
|
+ Parts: []dto.GeminiPart{
|
|
|
|
|
+ {Text: "ok"},
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ UsageMetadata: dto.GeminiUsageMetadata{
|
|
|
|
|
+ PromptTokenCount: 151,
|
|
|
|
|
+ ToolUsePromptTokenCount: 18329,
|
|
|
|
|
+ CandidatesTokenCount: 1089,
|
|
|
|
|
+ ThoughtsTokenCount: 1120,
|
|
|
|
|
+ TotalTokenCount: 20689,
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ body, err := common.Marshal(payload)
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
+
|
|
|
|
|
+ resp := &http.Response{
|
|
|
|
|
+ Body: io.NopCloser(bytes.NewReader(body)),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ usage, newAPIError := GeminiChatHandler(c, info, resp)
|
|
|
|
|
+ require.Nil(t, newAPIError)
|
|
|
|
|
+ require.NotNil(t, usage)
|
|
|
|
|
+ require.Equal(t, 18480, usage.PromptTokens)
|
|
|
|
|
+ require.Equal(t, 2209, usage.CompletionTokens)
|
|
|
|
|
+ require.Equal(t, 20689, usage.TotalTokens)
|
|
|
|
|
+ require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
|
|
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
|
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
|
|
|
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
|
|
|
+
|
|
|
|
|
+ oldStreamingTimeout := constant.StreamingTimeout
|
|
|
|
|
+ constant.StreamingTimeout = 300
|
|
|
|
|
+ t.Cleanup(func() {
|
|
|
|
|
+ constant.StreamingTimeout = oldStreamingTimeout
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
|
|
+ OriginModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
|
|
+ UpstreamModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ chunk := dto.GeminiChatResponse{
|
|
|
|
|
+ Candidates: []dto.GeminiChatCandidate{
|
|
|
|
|
+ {
|
|
|
|
|
+ Content: dto.GeminiChatContent{
|
|
|
|
|
+ Role: "model",
|
|
|
|
|
+ Parts: []dto.GeminiPart{
|
|
|
|
|
+ {Text: "partial"},
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ UsageMetadata: dto.GeminiUsageMetadata{
|
|
|
|
|
+ PromptTokenCount: 151,
|
|
|
|
|
+ ToolUsePromptTokenCount: 18329,
|
|
|
|
|
+ CandidatesTokenCount: 1089,
|
|
|
|
|
+ ThoughtsTokenCount: 1120,
|
|
|
|
|
+ TotalTokenCount: 20689,
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ chunkData, err := common.Marshal(chunk)
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
+
|
|
|
|
|
+ streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
|
|
|
|
|
+ resp := &http.Response{
|
|
|
|
|
+ Body: io.NopCloser(bytes.NewReader(streamBody)),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
|
|
|
|
|
+ return true
|
|
|
|
|
+ })
|
|
|
|
|
+ require.Nil(t, newAPIError)
|
|
|
|
|
+ require.NotNil(t, usage)
|
|
|
|
|
+ require.Equal(t, 18480, usage.PromptTokens)
|
|
|
|
|
+ require.Equal(t, 2209, usage.CompletionTokens)
|
|
|
|
|
+ require.Equal(t, 20689, usage.TotalTokens)
|
|
|
|
|
+ require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
|
|
|
|
|
+ t.Parallel()
|
|
|
|
|
+
|
|
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
|
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
|
|
|
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
|
|
|
|
|
+
|
|
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
|
|
+ OriginModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
|
|
+ UpstreamModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ payload := dto.GeminiChatResponse{
|
|
|
|
|
+ Candidates: []dto.GeminiChatCandidate{
|
|
|
|
|
+ {
|
|
|
|
|
+ Content: dto.GeminiChatContent{
|
|
|
|
|
+ Role: "model",
|
|
|
|
|
+ Parts: []dto.GeminiPart{
|
|
|
|
|
+ {Text: "ok"},
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ UsageMetadata: dto.GeminiUsageMetadata{
|
|
|
|
|
+ PromptTokenCount: 151,
|
|
|
|
|
+ ToolUsePromptTokenCount: 18329,
|
|
|
|
|
+ CandidatesTokenCount: 1089,
|
|
|
|
|
+ ThoughtsTokenCount: 1120,
|
|
|
|
|
+ TotalTokenCount: 20689,
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ body, err := common.Marshal(payload)
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
+
|
|
|
|
|
+ resp := &http.Response{
|
|
|
|
|
+ Body: io.NopCloser(bytes.NewReader(body)),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
|
|
|
|
|
+ require.Nil(t, newAPIError)
|
|
|
|
|
+ require.NotNil(t, usage)
|
|
|
|
|
+ require.Equal(t, 18480, usage.PromptTokens)
|
|
|
|
|
+ require.Equal(t, 2209, usage.CompletionTokens)
|
|
|
|
|
+ require.Equal(t, 20689, usage.TotalTokens)
|
|
|
|
|
+ require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
|
|
|
|
+ t.Parallel()
|
|
|
|
|
+
|
|
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
|
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
|
|
|
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
|
|
|
+
|
|
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
|
|
+ RelayFormat: types.RelayFormatGemini,
|
|
|
|
|
+ OriginModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
|
|
+ UpstreamModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+ info.SetEstimatePromptTokens(20)
|
|
|
|
|
+
|
|
|
|
|
+ payload := dto.GeminiChatResponse{
|
|
|
|
|
+ Candidates: []dto.GeminiChatCandidate{
|
|
|
|
|
+ {
|
|
|
|
|
+ Content: dto.GeminiChatContent{
|
|
|
|
|
+ Role: "model",
|
|
|
|
|
+ Parts: []dto.GeminiPart{
|
|
|
|
|
+ {Text: "ok"},
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ UsageMetadata: dto.GeminiUsageMetadata{
|
|
|
|
|
+ PromptTokenCount: 0,
|
|
|
|
|
+ ToolUsePromptTokenCount: 0,
|
|
|
|
|
+ CandidatesTokenCount: 90,
|
|
|
|
|
+ ThoughtsTokenCount: 10,
|
|
|
|
|
+ TotalTokenCount: 110,
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ body, err := common.Marshal(payload)
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
+
|
|
|
|
|
+ resp := &http.Response{
|
|
|
|
|
+ Body: io.NopCloser(bytes.NewReader(body)),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ usage, newAPIError := GeminiChatHandler(c, info, resp)
|
|
|
|
|
+ require.Nil(t, newAPIError)
|
|
|
|
|
+ require.NotNil(t, usage)
|
|
|
|
|
+ require.Equal(t, 20, usage.PromptTokens)
|
|
|
|
|
+ require.Equal(t, 100, usage.CompletionTokens)
|
|
|
|
|
+ require.Equal(t, 110, usage.TotalTokens)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
|
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
|
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
|
|
|
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
|
|
|
+
|
|
|
|
|
+ oldStreamingTimeout := constant.StreamingTimeout
|
|
|
|
|
+ constant.StreamingTimeout = 300
|
|
|
|
|
+ t.Cleanup(func() {
|
|
|
|
|
+ constant.StreamingTimeout = oldStreamingTimeout
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
|
|
+ OriginModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
|
|
+ UpstreamModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+ info.SetEstimatePromptTokens(20)
|
|
|
|
|
+
|
|
|
|
|
+ chunk := dto.GeminiChatResponse{
|
|
|
|
|
+ Candidates: []dto.GeminiChatCandidate{
|
|
|
|
|
+ {
|
|
|
|
|
+ Content: dto.GeminiChatContent{
|
|
|
|
|
+ Role: "model",
|
|
|
|
|
+ Parts: []dto.GeminiPart{
|
|
|
|
|
+ {Text: "partial"},
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ UsageMetadata: dto.GeminiUsageMetadata{
|
|
|
|
|
+ PromptTokenCount: 0,
|
|
|
|
|
+ ToolUsePromptTokenCount: 0,
|
|
|
|
|
+ CandidatesTokenCount: 90,
|
|
|
|
|
+ ThoughtsTokenCount: 10,
|
|
|
|
|
+ TotalTokenCount: 110,
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ chunkData, err := common.Marshal(chunk)
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
+
|
|
|
|
|
+ streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
|
|
|
|
|
+ resp := &http.Response{
|
|
|
|
|
+ Body: io.NopCloser(bytes.NewReader(streamBody)),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
|
|
|
|
|
+ return true
|
|
|
|
|
+ })
|
|
|
|
|
+ require.Nil(t, newAPIError)
|
|
|
|
|
+ require.NotNil(t, usage)
|
|
|
|
|
+ require.Equal(t, 20, usage.PromptTokens)
|
|
|
|
|
+ require.Equal(t, 100, usage.CompletionTokens)
|
|
|
|
|
+ require.Equal(t, 110, usage.TotalTokens)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
|
|
|
|
+ t.Parallel()
|
|
|
|
|
+
|
|
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
|
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
|
|
|
|
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
|
|
|
|
|
+
|
|
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
|
|
+ OriginModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
|
|
+ UpstreamModelName: "gemini-3-flash-preview",
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+ info.SetEstimatePromptTokens(20)
|
|
|
|
|
+
|
|
|
|
|
+ payload := dto.GeminiChatResponse{
|
|
|
|
|
+ Candidates: []dto.GeminiChatCandidate{
|
|
|
|
|
+ {
|
|
|
|
|
+ Content: dto.GeminiChatContent{
|
|
|
|
|
+ Role: "model",
|
|
|
|
|
+ Parts: []dto.GeminiPart{
|
|
|
|
|
+ {Text: "ok"},
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ UsageMetadata: dto.GeminiUsageMetadata{
|
|
|
|
|
+ PromptTokenCount: 0,
|
|
|
|
|
+ ToolUsePromptTokenCount: 0,
|
|
|
|
|
+ CandidatesTokenCount: 90,
|
|
|
|
|
+ ThoughtsTokenCount: 10,
|
|
|
|
|
+ TotalTokenCount: 110,
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ body, err := common.Marshal(payload)
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
+
|
|
|
|
|
+ resp := &http.Response{
|
|
|
|
|
+ Body: io.NopCloser(bytes.NewReader(body)),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
|
|
|
|
|
+ require.Nil(t, newAPIError)
|
|
|
|
|
+ require.NotNil(t, usage)
|
|
|
|
|
+ require.Equal(t, 20, usage.PromptTokens)
|
|
|
|
|
+ require.Equal(t, 100, usage.CompletionTokens)
|
|
|
|
|
+ require.Equal(t, 110, usage.TotalTokens)
|
|
|
|
|
+}
|