|
@@ -0,0 +1,111 @@
|
|
|
|
|
+package claude
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "testing"
|
|
|
|
|
+
|
|
|
|
|
+ "github.com/QuantumNous/new-api/dto"
|
|
|
|
|
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
|
|
+ "github.com/QuantumNous/new-api/setting/model_setting"
|
|
|
|
|
+ "github.com/stretchr/testify/assert"
|
|
|
|
|
+ "github.com/stretchr/testify/require"
|
|
|
|
|
+ "github.com/tidwall/gjson"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+func TestPatchClaudeMessageDeltaUsageDataPreserveUnknownFields(t *testing.T) {
|
|
|
|
|
+ originalData := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":53},"vendor_meta":{"trace_id":"trace_001"}}`
|
|
|
|
|
+ usage := &dto.ClaudeUsage{
|
|
|
|
|
+ InputTokens: 100,
|
|
|
|
|
+ CacheReadInputTokens: 30,
|
|
|
|
|
+ CacheCreationInputTokens: 50,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
|
|
|
|
|
+
|
|
|
|
|
+ require.Equal(t, "message_delta", gjson.Get(patchedData, "type").String())
|
|
|
|
|
+ require.Equal(t, "end_turn", gjson.Get(patchedData, "delta.stop_reason").String())
|
|
|
|
|
+ require.Equal(t, "trace_001", gjson.Get(patchedData, "vendor_meta.trace_id").String())
|
|
|
|
|
+ require.EqualValues(t, 53, gjson.Get(patchedData, "usage.output_tokens").Int())
|
|
|
|
|
+ require.EqualValues(t, 100, gjson.Get(patchedData, "usage.input_tokens").Int())
|
|
|
|
|
+ require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
|
|
|
|
|
+ require.EqualValues(t, 50, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Int())
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestPatchClaudeMessageDeltaUsageDataZeroValueChecks(t *testing.T) {
|
|
|
|
|
+ originalData := `{"type":"message_delta","usage":{"output_tokens":53,"input_tokens":9,"cache_read_input_tokens":0}}`
|
|
|
|
|
+ usage := &dto.ClaudeUsage{
|
|
|
|
|
+ InputTokens: 100,
|
|
|
|
|
+ CacheReadInputTokens: 30,
|
|
|
|
|
+ CacheCreationInputTokens: 0,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
|
|
|
|
|
+
|
|
|
|
|
+ require.EqualValues(t, 9, gjson.Get(patchedData, "usage.input_tokens").Int())
|
|
|
|
|
+ require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
|
|
|
|
|
+ assert.False(t, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Exists())
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestShouldSkipClaudeMessageDeltaUsagePatch(t *testing.T) {
|
|
|
|
|
+ originGlobalPassThrough := model_setting.GetGlobalSettings().PassThroughRequestEnabled
|
|
|
|
|
+ t.Cleanup(func() {
|
|
|
|
|
+ model_setting.GetGlobalSettings().PassThroughRequestEnabled = originGlobalPassThrough
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
|
|
|
|
|
+ assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{}))
|
|
|
|
|
+
|
|
|
|
|
+ model_setting.GetGlobalSettings().PassThroughRequestEnabled = false
|
|
|
|
|
+ assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
|
|
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: true}},
|
|
|
|
|
+ }))
|
|
|
|
|
+ assert.False(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
|
|
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: false}},
|
|
|
|
|
+ }))
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func TestBuildMessageDeltaPatchUsage(t *testing.T) {
|
|
|
|
|
+ t.Run("merge missing fields from claudeInfo", func(t *testing.T) {
|
|
|
|
|
+ claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{OutputTokens: 53}}
|
|
|
|
|
+ claudeInfo := &ClaudeResponseInfo{
|
|
|
|
|
+ Usage: &dto.Usage{
|
|
|
|
|
+ PromptTokens: 100,
|
|
|
|
|
+ PromptTokensDetails: dto.InputTokenDetails{
|
|
|
|
|
+ CachedTokens: 30,
|
|
|
|
|
+ CachedCreationTokens: 50,
|
|
|
|
|
+ },
|
|
|
|
|
+ ClaudeCacheCreation5mTokens: 10,
|
|
|
|
|
+ ClaudeCacheCreation1hTokens: 20,
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
|
|
|
|
|
+ require.NotNil(t, usage)
|
|
|
|
|
+ require.EqualValues(t, 100, usage.InputTokens)
|
|
|
|
|
+ require.EqualValues(t, 30, usage.CacheReadInputTokens)
|
|
|
|
|
+ require.EqualValues(t, 50, usage.CacheCreationInputTokens)
|
|
|
|
|
+ require.EqualValues(t, 53, usage.OutputTokens)
|
|
|
|
|
+ require.NotNil(t, usage.CacheCreation)
|
|
|
|
|
+ require.EqualValues(t, 10, usage.CacheCreation.Ephemeral5mInputTokens)
|
|
|
|
|
+ require.EqualValues(t, 20, usage.CacheCreation.Ephemeral1hInputTokens)
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ t.Run("keep upstream non-zero values", func(t *testing.T) {
|
|
|
|
|
+ claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{
|
|
|
|
|
+ InputTokens: 9,
|
|
|
|
|
+ CacheReadInputTokens: 7,
|
|
|
|
|
+ CacheCreationInputTokens: 6,
|
|
|
|
|
+ }}
|
|
|
|
|
+ claudeInfo := &ClaudeResponseInfo{Usage: &dto.Usage{
|
|
|
|
|
+ PromptTokens: 100,
|
|
|
|
|
+ PromptTokensDetails: dto.InputTokenDetails{
|
|
|
|
|
+ CachedTokens: 30,
|
|
|
|
|
+ CachedCreationTokens: 50,
|
|
|
|
|
+ },
|
|
|
|
|
+ }}
|
|
|
|
|
+
|
|
|
|
|
+ usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
|
|
|
|
|
+ require.EqualValues(t, 9, usage.InputTokens)
|
|
|
|
|
+ require.EqualValues(t, 7, usage.CacheReadInputTokens)
|
|
|
|
|
+ require.EqualValues(t, 6, usage.CacheCreationInputTokens)
|
|
|
|
|
+ })
|
|
|
|
|
+}
|