relay_gemini_usage_test.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. package gemini
  2. import (
  3. "bytes"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "testing"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/dto"
  11. relaycommon "github.com/QuantumNous/new-api/relay/common"
  12. "github.com/QuantumNous/new-api/types"
  13. "github.com/gin-gonic/gin"
  14. "github.com/stretchr/testify/require"
  15. )
  16. func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
  17. t.Parallel()
  18. gin.SetMode(gin.TestMode)
  19. c, _ := gin.CreateTestContext(httptest.NewRecorder())
  20. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  21. info := &relaycommon.RelayInfo{
  22. RelayFormat: types.RelayFormatGemini,
  23. OriginModelName: "gemini-3-flash-preview",
  24. ChannelMeta: &relaycommon.ChannelMeta{
  25. UpstreamModelName: "gemini-3-flash-preview",
  26. },
  27. }
  28. payload := dto.GeminiChatResponse{
  29. Candidates: []dto.GeminiChatCandidate{
  30. {
  31. Content: dto.GeminiChatContent{
  32. Role: "model",
  33. Parts: []dto.GeminiPart{
  34. {Text: "ok"},
  35. },
  36. },
  37. },
  38. },
  39. UsageMetadata: dto.GeminiUsageMetadata{
  40. PromptTokenCount: 151,
  41. ToolUsePromptTokenCount: 18329,
  42. CandidatesTokenCount: 1089,
  43. ThoughtsTokenCount: 1120,
  44. TotalTokenCount: 20689,
  45. },
  46. }
  47. body, err := common.Marshal(payload)
  48. require.NoError(t, err)
  49. resp := &http.Response{
  50. Body: io.NopCloser(bytes.NewReader(body)),
  51. }
  52. usage, newAPIError := GeminiChatHandler(c, info, resp)
  53. require.Nil(t, newAPIError)
  54. require.NotNil(t, usage)
  55. require.Equal(t, 18480, usage.PromptTokens)
  56. require.Equal(t, 2209, usage.CompletionTokens)
  57. require.Equal(t, 20689, usage.TotalTokens)
  58. require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
  59. }
  60. func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
  61. gin.SetMode(gin.TestMode)
  62. c, _ := gin.CreateTestContext(httptest.NewRecorder())
  63. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  64. oldStreamingTimeout := constant.StreamingTimeout
  65. constant.StreamingTimeout = 300
  66. t.Cleanup(func() {
  67. constant.StreamingTimeout = oldStreamingTimeout
  68. })
  69. info := &relaycommon.RelayInfo{
  70. OriginModelName: "gemini-3-flash-preview",
  71. ChannelMeta: &relaycommon.ChannelMeta{
  72. UpstreamModelName: "gemini-3-flash-preview",
  73. },
  74. }
  75. chunk := dto.GeminiChatResponse{
  76. Candidates: []dto.GeminiChatCandidate{
  77. {
  78. Content: dto.GeminiChatContent{
  79. Role: "model",
  80. Parts: []dto.GeminiPart{
  81. {Text: "partial"},
  82. },
  83. },
  84. },
  85. },
  86. UsageMetadata: dto.GeminiUsageMetadata{
  87. PromptTokenCount: 151,
  88. ToolUsePromptTokenCount: 18329,
  89. CandidatesTokenCount: 1089,
  90. ThoughtsTokenCount: 1120,
  91. TotalTokenCount: 20689,
  92. },
  93. }
  94. chunkData, err := common.Marshal(chunk)
  95. require.NoError(t, err)
  96. streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
  97. resp := &http.Response{
  98. Body: io.NopCloser(bytes.NewReader(streamBody)),
  99. }
  100. usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
  101. return true
  102. })
  103. require.Nil(t, newAPIError)
  104. require.NotNil(t, usage)
  105. require.Equal(t, 18480, usage.PromptTokens)
  106. require.Equal(t, 2209, usage.CompletionTokens)
  107. require.Equal(t, 20689, usage.TotalTokens)
  108. require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
  109. }
  110. func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
  111. t.Parallel()
  112. gin.SetMode(gin.TestMode)
  113. c, _ := gin.CreateTestContext(httptest.NewRecorder())
  114. c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
  115. info := &relaycommon.RelayInfo{
  116. OriginModelName: "gemini-3-flash-preview",
  117. ChannelMeta: &relaycommon.ChannelMeta{
  118. UpstreamModelName: "gemini-3-flash-preview",
  119. },
  120. }
  121. payload := dto.GeminiChatResponse{
  122. Candidates: []dto.GeminiChatCandidate{
  123. {
  124. Content: dto.GeminiChatContent{
  125. Role: "model",
  126. Parts: []dto.GeminiPart{
  127. {Text: "ok"},
  128. },
  129. },
  130. },
  131. },
  132. UsageMetadata: dto.GeminiUsageMetadata{
  133. PromptTokenCount: 151,
  134. ToolUsePromptTokenCount: 18329,
  135. CandidatesTokenCount: 1089,
  136. ThoughtsTokenCount: 1120,
  137. TotalTokenCount: 20689,
  138. },
  139. }
  140. body, err := common.Marshal(payload)
  141. require.NoError(t, err)
  142. resp := &http.Response{
  143. Body: io.NopCloser(bytes.NewReader(body)),
  144. }
  145. usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
  146. require.Nil(t, newAPIError)
  147. require.NotNil(t, usage)
  148. require.Equal(t, 18480, usage.PromptTokens)
  149. require.Equal(t, 2209, usage.CompletionTokens)
  150. require.Equal(t, 20689, usage.TotalTokens)
  151. require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
  152. }
  153. func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
  154. t.Parallel()
  155. gin.SetMode(gin.TestMode)
  156. c, _ := gin.CreateTestContext(httptest.NewRecorder())
  157. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  158. info := &relaycommon.RelayInfo{
  159. RelayFormat: types.RelayFormatGemini,
  160. OriginModelName: "gemini-3-flash-preview",
  161. ChannelMeta: &relaycommon.ChannelMeta{
  162. UpstreamModelName: "gemini-3-flash-preview",
  163. },
  164. }
  165. info.SetEstimatePromptTokens(20)
  166. payload := dto.GeminiChatResponse{
  167. Candidates: []dto.GeminiChatCandidate{
  168. {
  169. Content: dto.GeminiChatContent{
  170. Role: "model",
  171. Parts: []dto.GeminiPart{
  172. {Text: "ok"},
  173. },
  174. },
  175. },
  176. },
  177. UsageMetadata: dto.GeminiUsageMetadata{
  178. PromptTokenCount: 0,
  179. ToolUsePromptTokenCount: 0,
  180. CandidatesTokenCount: 90,
  181. ThoughtsTokenCount: 10,
  182. TotalTokenCount: 110,
  183. },
  184. }
  185. body, err := common.Marshal(payload)
  186. require.NoError(t, err)
  187. resp := &http.Response{
  188. Body: io.NopCloser(bytes.NewReader(body)),
  189. }
  190. usage, newAPIError := GeminiChatHandler(c, info, resp)
  191. require.Nil(t, newAPIError)
  192. require.NotNil(t, usage)
  193. require.Equal(t, 20, usage.PromptTokens)
  194. require.Equal(t, 100, usage.CompletionTokens)
  195. require.Equal(t, 110, usage.TotalTokens)
  196. }
  197. func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
  198. gin.SetMode(gin.TestMode)
  199. c, _ := gin.CreateTestContext(httptest.NewRecorder())
  200. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  201. oldStreamingTimeout := constant.StreamingTimeout
  202. constant.StreamingTimeout = 300
  203. t.Cleanup(func() {
  204. constant.StreamingTimeout = oldStreamingTimeout
  205. })
  206. info := &relaycommon.RelayInfo{
  207. OriginModelName: "gemini-3-flash-preview",
  208. ChannelMeta: &relaycommon.ChannelMeta{
  209. UpstreamModelName: "gemini-3-flash-preview",
  210. },
  211. }
  212. info.SetEstimatePromptTokens(20)
  213. chunk := dto.GeminiChatResponse{
  214. Candidates: []dto.GeminiChatCandidate{
  215. {
  216. Content: dto.GeminiChatContent{
  217. Role: "model",
  218. Parts: []dto.GeminiPart{
  219. {Text: "partial"},
  220. },
  221. },
  222. },
  223. },
  224. UsageMetadata: dto.GeminiUsageMetadata{
  225. PromptTokenCount: 0,
  226. ToolUsePromptTokenCount: 0,
  227. CandidatesTokenCount: 90,
  228. ThoughtsTokenCount: 10,
  229. TotalTokenCount: 110,
  230. },
  231. }
  232. chunkData, err := common.Marshal(chunk)
  233. require.NoError(t, err)
  234. streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
  235. resp := &http.Response{
  236. Body: io.NopCloser(bytes.NewReader(streamBody)),
  237. }
  238. usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
  239. return true
  240. })
  241. require.Nil(t, newAPIError)
  242. require.NotNil(t, usage)
  243. require.Equal(t, 20, usage.PromptTokens)
  244. require.Equal(t, 100, usage.CompletionTokens)
  245. require.Equal(t, 110, usage.TotalTokens)
  246. }
  247. func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
  248. t.Parallel()
  249. gin.SetMode(gin.TestMode)
  250. c, _ := gin.CreateTestContext(httptest.NewRecorder())
  251. c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
  252. info := &relaycommon.RelayInfo{
  253. OriginModelName: "gemini-3-flash-preview",
  254. ChannelMeta: &relaycommon.ChannelMeta{
  255. UpstreamModelName: "gemini-3-flash-preview",
  256. },
  257. }
  258. info.SetEstimatePromptTokens(20)
  259. payload := dto.GeminiChatResponse{
  260. Candidates: []dto.GeminiChatCandidate{
  261. {
  262. Content: dto.GeminiChatContent{
  263. Role: "model",
  264. Parts: []dto.GeminiPart{
  265. {Text: "ok"},
  266. },
  267. },
  268. },
  269. },
  270. UsageMetadata: dto.GeminiUsageMetadata{
  271. PromptTokenCount: 0,
  272. ToolUsePromptTokenCount: 0,
  273. CandidatesTokenCount: 90,
  274. ThoughtsTokenCount: 10,
  275. TotalTokenCount: 110,
  276. },
  277. }
  278. body, err := common.Marshal(payload)
  279. require.NoError(t, err)
  280. resp := &http.Response{
  281. Body: io.NopCloser(bytes.NewReader(body)),
  282. }
  283. usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
  284. require.Nil(t, newAPIError)
  285. require.NotNil(t, usage)
  286. require.Equal(t, 20, usage.PromptTokens)
  287. require.Equal(t, 100, usage.CompletionTokens)
  288. require.Equal(t, 110, usage.TotalTokens)
  289. }