| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- package claude
- import (
- "testing"
- "github.com/QuantumNous/new-api/dto"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/setting/model_setting"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/tidwall/gjson"
- )
- func TestPatchClaudeMessageDeltaUsageDataPreserveUnknownFields(t *testing.T) {
- originalData := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":53},"vendor_meta":{"trace_id":"trace_001"}}`
- usage := &dto.ClaudeUsage{
- InputTokens: 100,
- CacheReadInputTokens: 30,
- CacheCreationInputTokens: 50,
- }
- patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
- require.Equal(t, "message_delta", gjson.Get(patchedData, "type").String())
- require.Equal(t, "end_turn", gjson.Get(patchedData, "delta.stop_reason").String())
- require.Equal(t, "trace_001", gjson.Get(patchedData, "vendor_meta.trace_id").String())
- require.EqualValues(t, 53, gjson.Get(patchedData, "usage.output_tokens").Int())
- require.EqualValues(t, 100, gjson.Get(patchedData, "usage.input_tokens").Int())
- require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
- require.EqualValues(t, 50, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Int())
- }
- func TestPatchClaudeMessageDeltaUsageDataZeroValueChecks(t *testing.T) {
- originalData := `{"type":"message_delta","usage":{"output_tokens":53,"input_tokens":9,"cache_read_input_tokens":0}}`
- usage := &dto.ClaudeUsage{
- InputTokens: 100,
- CacheReadInputTokens: 30,
- CacheCreationInputTokens: 0,
- }
- patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
- require.EqualValues(t, 9, gjson.Get(patchedData, "usage.input_tokens").Int())
- require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
- assert.False(t, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Exists())
- }
- func TestShouldSkipClaudeMessageDeltaUsagePatch(t *testing.T) {
- originGlobalPassThrough := model_setting.GetGlobalSettings().PassThroughRequestEnabled
- t.Cleanup(func() {
- model_setting.GetGlobalSettings().PassThroughRequestEnabled = originGlobalPassThrough
- })
- model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
- assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{}))
- model_setting.GetGlobalSettings().PassThroughRequestEnabled = false
- assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
- ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: true}},
- }))
- assert.False(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
- ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: false}},
- }))
- }
- func TestBuildMessageDeltaPatchUsage(t *testing.T) {
- t.Run("merge missing fields from claudeInfo", func(t *testing.T) {
- claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{OutputTokens: 53}}
- claudeInfo := &ClaudeResponseInfo{
- Usage: &dto.Usage{
- PromptTokens: 100,
- PromptTokensDetails: dto.InputTokenDetails{
- CachedTokens: 30,
- CachedCreationTokens: 50,
- },
- ClaudeCacheCreation5mTokens: 10,
- ClaudeCacheCreation1hTokens: 20,
- },
- }
- usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
- require.NotNil(t, usage)
- require.EqualValues(t, 100, usage.InputTokens)
- require.EqualValues(t, 30, usage.CacheReadInputTokens)
- require.EqualValues(t, 50, usage.CacheCreationInputTokens)
- require.EqualValues(t, 53, usage.OutputTokens)
- require.NotNil(t, usage.CacheCreation)
- require.EqualValues(t, 10, usage.CacheCreation.Ephemeral5mInputTokens)
- require.EqualValues(t, 20, usage.CacheCreation.Ephemeral1hInputTokens)
- })
- t.Run("keep upstream non-zero values", func(t *testing.T) {
- claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{
- InputTokens: 9,
- CacheReadInputTokens: 7,
- CacheCreationInputTokens: 6,
- }}
- claudeInfo := &ClaudeResponseInfo{Usage: &dto.Usage{
- PromptTokens: 100,
- PromptTokensDetails: dto.InputTokenDetails{
- CachedTokens: 30,
- CachedCreationTokens: 50,
- },
- }}
- usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
- require.EqualValues(t, 9, usage.InputTokens)
- require.EqualValues(t, 7, usage.CacheReadInputTokens)
- require.EqualValues(t, 6, usage.CacheCreationInputTokens)
- })
- }
|