gemini_generation_config_test.go 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. package dto
  2. import (
  3. "testing"
  4. "github.com/QuantumNous/new-api/common"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/stretchr/testify/require"
  7. )
  8. func TestGeminiChatGenerationConfigPreservesExplicitZeroValuesCamelCase(t *testing.T) {
  9. raw := []byte(`{
  10. "contents":[{"role":"user","parts":[{"text":"hello"}]}],
  11. "generationConfig":{
  12. "topP":0,
  13. "topK":0,
  14. "maxOutputTokens":0,
  15. "candidateCount":0,
  16. "seed":0,
  17. "responseLogprobs":false
  18. }
  19. }`)
  20. var req GeminiChatRequest
  21. require.NoError(t, common.Unmarshal(raw, &req))
  22. encoded, err := common.Marshal(req)
  23. require.NoError(t, err)
  24. var out map[string]any
  25. require.NoError(t, common.Unmarshal(encoded, &out))
  26. generationConfig, ok := out["generationConfig"].(map[string]any)
  27. require.True(t, ok)
  28. assert.Contains(t, generationConfig, "topP")
  29. assert.Contains(t, generationConfig, "topK")
  30. assert.Contains(t, generationConfig, "maxOutputTokens")
  31. assert.Contains(t, generationConfig, "candidateCount")
  32. assert.Contains(t, generationConfig, "seed")
  33. assert.Contains(t, generationConfig, "responseLogprobs")
  34. assert.Equal(t, float64(0), generationConfig["topP"])
  35. assert.Equal(t, float64(0), generationConfig["topK"])
  36. assert.Equal(t, float64(0), generationConfig["maxOutputTokens"])
  37. assert.Equal(t, float64(0), generationConfig["candidateCount"])
  38. assert.Equal(t, float64(0), generationConfig["seed"])
  39. assert.Equal(t, false, generationConfig["responseLogprobs"])
  40. }
  41. func TestGeminiChatGenerationConfigPreservesExplicitZeroValuesSnakeCase(t *testing.T) {
  42. raw := []byte(`{
  43. "contents":[{"role":"user","parts":[{"text":"hello"}]}],
  44. "generationConfig":{
  45. "top_p":0,
  46. "top_k":0,
  47. "max_output_tokens":0,
  48. "candidate_count":0,
  49. "seed":0,
  50. "response_logprobs":false
  51. }
  52. }`)
  53. var req GeminiChatRequest
  54. require.NoError(t, common.Unmarshal(raw, &req))
  55. encoded, err := common.Marshal(req)
  56. require.NoError(t, err)
  57. var out map[string]any
  58. require.NoError(t, common.Unmarshal(encoded, &out))
  59. generationConfig, ok := out["generationConfig"].(map[string]any)
  60. require.True(t, ok)
  61. assert.Contains(t, generationConfig, "topP")
  62. assert.Contains(t, generationConfig, "topK")
  63. assert.Contains(t, generationConfig, "maxOutputTokens")
  64. assert.Contains(t, generationConfig, "candidateCount")
  65. assert.Contains(t, generationConfig, "seed")
  66. assert.Contains(t, generationConfig, "responseLogprobs")
  67. assert.Equal(t, float64(0), generationConfig["topP"])
  68. assert.Equal(t, float64(0), generationConfig["topK"])
  69. assert.Equal(t, float64(0), generationConfig["maxOutputTokens"])
  70. assert.Equal(t, float64(0), generationConfig["candidateCount"])
  71. assert.Equal(t, float64(0), generationConfig["seed"])
  72. assert.Equal(t, false, generationConfig["responseLogprobs"])
  73. }