| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- package gemini
- import (
- "bytes"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
- )
- func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- info := &relaycommon.RelayInfo{
- RelayFormat: types.RelayFormatGemini,
- OriginModelName: "gemini-3-flash-preview",
- ChannelMeta: &relaycommon.ChannelMeta{
- UpstreamModelName: "gemini-3-flash-preview",
- },
- }
- payload := dto.GeminiChatResponse{
- Candidates: []dto.GeminiChatCandidate{
- {
- Content: dto.GeminiChatContent{
- Role: "model",
- Parts: []dto.GeminiPart{
- {Text: "ok"},
- },
- },
- },
- },
- UsageMetadata: dto.GeminiUsageMetadata{
- PromptTokenCount: 151,
- ToolUsePromptTokenCount: 18329,
- CandidatesTokenCount: 1089,
- ThoughtsTokenCount: 1120,
- TotalTokenCount: 20689,
- },
- }
- body, err := common.Marshal(payload)
- require.NoError(t, err)
- resp := &http.Response{
- Body: io.NopCloser(bytes.NewReader(body)),
- }
- usage, newAPIError := GeminiChatHandler(c, info, resp)
- require.Nil(t, newAPIError)
- require.NotNil(t, usage)
- require.Equal(t, 18480, usage.PromptTokens)
- require.Equal(t, 2209, usage.CompletionTokens)
- require.Equal(t, 20689, usage.TotalTokens)
- require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
- }
- func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
- gin.SetMode(gin.TestMode)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- oldStreamingTimeout := constant.StreamingTimeout
- constant.StreamingTimeout = 300
- t.Cleanup(func() {
- constant.StreamingTimeout = oldStreamingTimeout
- })
- info := &relaycommon.RelayInfo{
- OriginModelName: "gemini-3-flash-preview",
- ChannelMeta: &relaycommon.ChannelMeta{
- UpstreamModelName: "gemini-3-flash-preview",
- },
- }
- chunk := dto.GeminiChatResponse{
- Candidates: []dto.GeminiChatCandidate{
- {
- Content: dto.GeminiChatContent{
- Role: "model",
- Parts: []dto.GeminiPart{
- {Text: "partial"},
- },
- },
- },
- },
- UsageMetadata: dto.GeminiUsageMetadata{
- PromptTokenCount: 151,
- ToolUsePromptTokenCount: 18329,
- CandidatesTokenCount: 1089,
- ThoughtsTokenCount: 1120,
- TotalTokenCount: 20689,
- },
- }
- chunkData, err := common.Marshal(chunk)
- require.NoError(t, err)
- streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
- resp := &http.Response{
- Body: io.NopCloser(bytes.NewReader(streamBody)),
- }
- usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
- return true
- })
- require.Nil(t, newAPIError)
- require.NotNil(t, usage)
- require.Equal(t, 18480, usage.PromptTokens)
- require.Equal(t, 2209, usage.CompletionTokens)
- require.Equal(t, 20689, usage.TotalTokens)
- require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
- }
- func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
- info := &relaycommon.RelayInfo{
- OriginModelName: "gemini-3-flash-preview",
- ChannelMeta: &relaycommon.ChannelMeta{
- UpstreamModelName: "gemini-3-flash-preview",
- },
- }
- payload := dto.GeminiChatResponse{
- Candidates: []dto.GeminiChatCandidate{
- {
- Content: dto.GeminiChatContent{
- Role: "model",
- Parts: []dto.GeminiPart{
- {Text: "ok"},
- },
- },
- },
- },
- UsageMetadata: dto.GeminiUsageMetadata{
- PromptTokenCount: 151,
- ToolUsePromptTokenCount: 18329,
- CandidatesTokenCount: 1089,
- ThoughtsTokenCount: 1120,
- TotalTokenCount: 20689,
- },
- }
- body, err := common.Marshal(payload)
- require.NoError(t, err)
- resp := &http.Response{
- Body: io.NopCloser(bytes.NewReader(body)),
- }
- usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
- require.Nil(t, newAPIError)
- require.NotNil(t, usage)
- require.Equal(t, 18480, usage.PromptTokens)
- require.Equal(t, 2209, usage.CompletionTokens)
- require.Equal(t, 20689, usage.TotalTokens)
- require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
- }
- func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- info := &relaycommon.RelayInfo{
- RelayFormat: types.RelayFormatGemini,
- OriginModelName: "gemini-3-flash-preview",
- ChannelMeta: &relaycommon.ChannelMeta{
- UpstreamModelName: "gemini-3-flash-preview",
- },
- }
- info.SetEstimatePromptTokens(20)
- payload := dto.GeminiChatResponse{
- Candidates: []dto.GeminiChatCandidate{
- {
- Content: dto.GeminiChatContent{
- Role: "model",
- Parts: []dto.GeminiPart{
- {Text: "ok"},
- },
- },
- },
- },
- UsageMetadata: dto.GeminiUsageMetadata{
- PromptTokenCount: 0,
- ToolUsePromptTokenCount: 0,
- CandidatesTokenCount: 90,
- ThoughtsTokenCount: 10,
- TotalTokenCount: 110,
- },
- }
- body, err := common.Marshal(payload)
- require.NoError(t, err)
- resp := &http.Response{
- Body: io.NopCloser(bytes.NewReader(body)),
- }
- usage, newAPIError := GeminiChatHandler(c, info, resp)
- require.Nil(t, newAPIError)
- require.NotNil(t, usage)
- require.Equal(t, 20, usage.PromptTokens)
- require.Equal(t, 100, usage.CompletionTokens)
- require.Equal(t, 110, usage.TotalTokens)
- }
- func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
- gin.SetMode(gin.TestMode)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- oldStreamingTimeout := constant.StreamingTimeout
- constant.StreamingTimeout = 300
- t.Cleanup(func() {
- constant.StreamingTimeout = oldStreamingTimeout
- })
- info := &relaycommon.RelayInfo{
- OriginModelName: "gemini-3-flash-preview",
- ChannelMeta: &relaycommon.ChannelMeta{
- UpstreamModelName: "gemini-3-flash-preview",
- },
- }
- info.SetEstimatePromptTokens(20)
- chunk := dto.GeminiChatResponse{
- Candidates: []dto.GeminiChatCandidate{
- {
- Content: dto.GeminiChatContent{
- Role: "model",
- Parts: []dto.GeminiPart{
- {Text: "partial"},
- },
- },
- },
- },
- UsageMetadata: dto.GeminiUsageMetadata{
- PromptTokenCount: 0,
- ToolUsePromptTokenCount: 0,
- CandidatesTokenCount: 90,
- ThoughtsTokenCount: 10,
- TotalTokenCount: 110,
- },
- }
- chunkData, err := common.Marshal(chunk)
- require.NoError(t, err)
- streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
- resp := &http.Response{
- Body: io.NopCloser(bytes.NewReader(streamBody)),
- }
- usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
- return true
- })
- require.Nil(t, newAPIError)
- require.NotNil(t, usage)
- require.Equal(t, 20, usage.PromptTokens)
- require.Equal(t, 100, usage.CompletionTokens)
- require.Equal(t, 110, usage.TotalTokens)
- }
- func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
- info := &relaycommon.RelayInfo{
- OriginModelName: "gemini-3-flash-preview",
- ChannelMeta: &relaycommon.ChannelMeta{
- UpstreamModelName: "gemini-3-flash-preview",
- },
- }
- info.SetEstimatePromptTokens(20)
- payload := dto.GeminiChatResponse{
- Candidates: []dto.GeminiChatCandidate{
- {
- Content: dto.GeminiChatContent{
- Role: "model",
- Parts: []dto.GeminiPart{
- {Text: "ok"},
- },
- },
- },
- },
- UsageMetadata: dto.GeminiUsageMetadata{
- PromptTokenCount: 0,
- ToolUsePromptTokenCount: 0,
- CandidatesTokenCount: 90,
- ThoughtsTokenCount: 10,
- TotalTokenCount: 110,
- },
- }
- body, err := common.Marshal(payload)
- require.NoError(t, err)
- resp := &http.Response{
- Body: io.NopCloser(bytes.NewReader(body)),
- }
- usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
- require.Nil(t, newAPIError)
- require.NotNil(t, usage)
- require.Equal(t, 20, usage.PromptTokens)
- require.Equal(t, 100, usage.CompletionTokens)
- require.Equal(t, 110, usage.TotalTokens)
- }
|