// channel_affinity_usage_cache_test.go (3.4 KB, 105 lines in the original listing)

  1. package service
  2. import (
  3. "fmt"
  4. "net/http/httptest"
  5. "testing"
  6. "time"
  7. "github.com/QuantumNous/new-api/dto"
  8. "github.com/QuantumNous/new-api/types"
  9. "github.com/gin-gonic/gin"
  10. "github.com/stretchr/testify/require"
  11. )
  12. func buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP string) *gin.Context {
  13. rec := httptest.NewRecorder()
  14. ctx, _ := gin.CreateTestContext(rec)
  15. setChannelAffinityContext(ctx, channelAffinityMeta{
  16. CacheKey: fmt.Sprintf("test:%s:%s:%s", ruleName, usingGroup, keyFP),
  17. TTLSeconds: 600,
  18. RuleName: ruleName,
  19. UsingGroup: usingGroup,
  20. KeyFingerprint: keyFP,
  21. })
  22. return ctx
  23. }
  24. func TestObserveChannelAffinityUsageCacheByRelayFormat_ClaudeMode(t *testing.T) {
  25. ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
  26. usingGroup := "default"
  27. keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
  28. ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
  29. usage := &dto.Usage{
  30. PromptTokens: 100,
  31. CompletionTokens: 40,
  32. TotalTokens: 140,
  33. PromptTokensDetails: dto.InputTokenDetails{
  34. CachedTokens: 30,
  35. },
  36. }
  37. ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatClaude)
  38. stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
  39. require.EqualValues(t, 1, stats.Total)
  40. require.EqualValues(t, 1, stats.Hit)
  41. require.EqualValues(t, 100, stats.PromptTokens)
  42. require.EqualValues(t, 40, stats.CompletionTokens)
  43. require.EqualValues(t, 140, stats.TotalTokens)
  44. require.EqualValues(t, 30, stats.CachedTokens)
  45. require.Equal(t, cacheTokenRateModeCachedOverPromptPlusCached, stats.CachedTokenRateMode)
  46. }
  47. func TestObserveChannelAffinityUsageCacheByRelayFormat_MixedMode(t *testing.T) {
  48. ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
  49. usingGroup := "default"
  50. keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
  51. ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
  52. openAIUsage := &dto.Usage{
  53. PromptTokens: 100,
  54. PromptTokensDetails: dto.InputTokenDetails{
  55. CachedTokens: 10,
  56. },
  57. }
  58. claudeUsage := &dto.Usage{
  59. PromptTokens: 80,
  60. PromptTokensDetails: dto.InputTokenDetails{
  61. CachedTokens: 20,
  62. },
  63. }
  64. ObserveChannelAffinityUsageCacheByRelayFormat(ctx, openAIUsage, types.RelayFormatOpenAI)
  65. ObserveChannelAffinityUsageCacheByRelayFormat(ctx, claudeUsage, types.RelayFormatClaude)
  66. stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
  67. require.EqualValues(t, 2, stats.Total)
  68. require.EqualValues(t, 2, stats.Hit)
  69. require.EqualValues(t, 180, stats.PromptTokens)
  70. require.EqualValues(t, 30, stats.CachedTokens)
  71. require.Equal(t, cacheTokenRateModeMixed, stats.CachedTokenRateMode)
  72. }
  73. func TestObserveChannelAffinityUsageCacheByRelayFormat_UnsupportedModeKeepsEmpty(t *testing.T) {
  74. ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
  75. usingGroup := "default"
  76. keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
  77. ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
  78. usage := &dto.Usage{
  79. PromptTokens: 100,
  80. PromptTokensDetails: dto.InputTokenDetails{
  81. CachedTokens: 25,
  82. },
  83. }
  84. ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatGemini)
  85. stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
  86. require.EqualValues(t, 1, stats.Total)
  87. require.EqualValues(t, 1, stats.Hit)
  88. require.EqualValues(t, 25, stats.CachedTokens)
  89. require.Equal(t, "", stats.CachedTokenRateMode)
  90. }