message_delta_usage_patch_test.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package claude
  2. import (
  3. "testing"
  4. "github.com/QuantumNous/new-api/dto"
  5. relaycommon "github.com/QuantumNous/new-api/relay/common"
  6. "github.com/QuantumNous/new-api/setting/model_setting"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/stretchr/testify/require"
  9. "github.com/tidwall/gjson"
  10. )
  11. func TestPatchClaudeMessageDeltaUsageDataPreserveUnknownFields(t *testing.T) {
  12. originalData := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":53},"vendor_meta":{"trace_id":"trace_001"}}`
  13. usage := &dto.ClaudeUsage{
  14. InputTokens: 100,
  15. CacheReadInputTokens: 30,
  16. CacheCreationInputTokens: 50,
  17. }
  18. patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
  19. require.Equal(t, "message_delta", gjson.Get(patchedData, "type").String())
  20. require.Equal(t, "end_turn", gjson.Get(patchedData, "delta.stop_reason").String())
  21. require.Equal(t, "trace_001", gjson.Get(patchedData, "vendor_meta.trace_id").String())
  22. require.EqualValues(t, 53, gjson.Get(patchedData, "usage.output_tokens").Int())
  23. require.EqualValues(t, 100, gjson.Get(patchedData, "usage.input_tokens").Int())
  24. require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
  25. require.EqualValues(t, 50, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Int())
  26. }
  27. func TestPatchClaudeMessageDeltaUsageDataZeroValueChecks(t *testing.T) {
  28. originalData := `{"type":"message_delta","usage":{"output_tokens":53,"input_tokens":9,"cache_read_input_tokens":0}}`
  29. usage := &dto.ClaudeUsage{
  30. InputTokens: 100,
  31. CacheReadInputTokens: 30,
  32. CacheCreationInputTokens: 0,
  33. }
  34. patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
  35. require.EqualValues(t, 9, gjson.Get(patchedData, "usage.input_tokens").Int())
  36. require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
  37. assert.False(t, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Exists())
  38. }
  39. func TestShouldSkipClaudeMessageDeltaUsagePatch(t *testing.T) {
  40. originGlobalPassThrough := model_setting.GetGlobalSettings().PassThroughRequestEnabled
  41. t.Cleanup(func() {
  42. model_setting.GetGlobalSettings().PassThroughRequestEnabled = originGlobalPassThrough
  43. })
  44. model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
  45. assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{}))
  46. model_setting.GetGlobalSettings().PassThroughRequestEnabled = false
  47. assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
  48. ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: true}},
  49. }))
  50. assert.False(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
  51. ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: false}},
  52. }))
  53. }
  54. func TestBuildMessageDeltaPatchUsage(t *testing.T) {
  55. t.Run("merge missing fields from claudeInfo", func(t *testing.T) {
  56. claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{OutputTokens: 53}}
  57. claudeInfo := &ClaudeResponseInfo{
  58. Usage: &dto.Usage{
  59. PromptTokens: 100,
  60. PromptTokensDetails: dto.InputTokenDetails{
  61. CachedTokens: 30,
  62. CachedCreationTokens: 50,
  63. },
  64. ClaudeCacheCreation5mTokens: 10,
  65. ClaudeCacheCreation1hTokens: 20,
  66. },
  67. }
  68. usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
  69. require.NotNil(t, usage)
  70. require.EqualValues(t, 100, usage.InputTokens)
  71. require.EqualValues(t, 30, usage.CacheReadInputTokens)
  72. require.EqualValues(t, 50, usage.CacheCreationInputTokens)
  73. require.EqualValues(t, 53, usage.OutputTokens)
  74. require.NotNil(t, usage.CacheCreation)
  75. require.EqualValues(t, 10, usage.CacheCreation.Ephemeral5mInputTokens)
  76. require.EqualValues(t, 20, usage.CacheCreation.Ephemeral1hInputTokens)
  77. })
  78. t.Run("keep upstream non-zero values", func(t *testing.T) {
  79. claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{
  80. InputTokens: 9,
  81. CacheReadInputTokens: 7,
  82. CacheCreationInputTokens: 6,
  83. }}
  84. claudeInfo := &ClaudeResponseInfo{Usage: &dto.Usage{
  85. PromptTokens: 100,
  86. PromptTokensDetails: dto.InputTokenDetails{
  87. CachedTokens: 30,
  88. CachedCreationTokens: 50,
  89. },
  90. }}
  91. usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
  92. require.EqualValues(t, 9, usage.InputTokens)
  93. require.EqualValues(t, 7, usage.CacheReadInputTokens)
  94. require.EqualValues(t, 6, usage.CacheCreationInputTokens)
  95. })
  96. }