package service

import (
	"fmt"
	"hash/fnv"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/pkg/cachex"
	"github.com/QuantumNous/new-api/setting/operation_setting"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
	"github.com/samber/hot"
	"github.com/tidwall/gjson"
)

const (
	// Keys under which per-request affinity state is stored on the gin context.
	ginKeyChannelAffinityCacheKey   = "channel_affinity_cache_key"
	ginKeyChannelAffinityTTLSeconds = "channel_affinity_ttl_seconds"
	ginKeyChannelAffinityMeta       = "channel_affinity_meta"
	ginKeyChannelAffinityLogInfo    = "channel_affinity_log_info"
	ginKeyChannelAffinitySkipRetry  = "channel_affinity_skip_retry_on_failure"

	// Namespaces of the two hybrid caches built in this file.
	channelAffinityCacheNamespace           = "new-api:channel_affinity:v1"
	channelAffinityUsageCacheStatsNamespace = "new-api:channel_affinity_usage_cache_stats:v1"
)

var (
	channelAffinityCacheOnce sync.Once
	channelAffinityCache     *cachex.HybridCache[int]

	channelAffinityUsageCacheStatsOnce  sync.Once
	channelAffinityUsageCacheStatsCache *cachex.HybridCache[ChannelAffinityUsageCacheCounters]

	channelAffinityRegexCache sync.Map // map[string]*regexp.Regexp
)

// channelAffinityMeta is the per-request affinity state that
// GetPreferredChannelByAffinity stores on the gin context once a rule has
// produced an affinity value.
type channelAffinityMeta struct {
	CacheKey       string
	TTLSeconds     int
	RuleName       string
	SkipRetry      bool
	KeySourceType  string
	KeySourceKey   string
	KeySourcePath  string
	KeyHint        string
	KeyFingerprint string
	UsingGroup     string
	ModelName      string
	RequestPath    string
}

// ChannelAffinityStatsContext is the subset of channelAffinityMeta needed to
// record usage-cache statistics; see GetChannelAffinityStatsContext.
type ChannelAffinityStatsContext struct {
	RuleName       string
	UsingGroup     string
	KeyFingerprint string
	TTLSeconds     int64
}

// Recognized values of the CachedTokenRateMode counter field.
const (
	cacheTokenRateModeCachedOverPrompt           = "cached_over_prompt"
	cacheTokenRateModeCachedOverPromptPlusCached = "cached_over_prompt_plus_cached"
	cacheTokenRateModeMixed                      = "mixed"
)

// ChannelAffinityCacheStats is the summary returned by
// GetChannelAffinityCacheStats.
type ChannelAffinityCacheStats struct {
	Enabled       bool           `json:"enabled"`
	Total         int            `json:"total"`
	Unknown       int            `json:"unknown"`
	ByRuleName    map[string]int `json:"by_rule_name"`
	CacheCapacity int            `json:"cache_capacity"`
	CacheAlgo     string         `json:"cache_algo"`
}

// getChannelAffinityCache returns the process-wide affinity cache (cache key ->
// channel id), building it on first use.
//
// Capacity and default TTL come from the channel affinity setting and fall back
// to 100_000 entries / 3600 seconds when the setting is nil or the values are
// not positive. The nil check matters: a panic inside sync.Once.Do still marks
// the Once as done, which would leave channelAffinityCache nil permanently.
func getChannelAffinityCache() *cachex.HybridCache[int] {
	channelAffinityCacheOnce.Do(func() {
		capacity := 100_000
		defaultTTLSeconds := 3600
		if setting := operation_setting.GetChannelAffinitySetting(); setting != nil {
			if setting.MaxEntries > 0 {
				capacity = setting.MaxEntries
			}
			if setting.DefaultTTLSeconds > 0 {
				defaultTTLSeconds = setting.DefaultTTLSeconds
			}
		}

		channelAffinityCache = cachex.NewHybridCache[int](cachex.HybridCacheConfig[int]{
			Namespace: cachex.Namespace(channelAffinityCacheNamespace),
			Redis:     common.RDB,
			RedisEnabled: func() bool {
				return common.RedisEnabled && common.RDB != nil
			},
			RedisCodec: cachex.IntCodec{},
			Memory: func() *hot.HotCache[string, int] {
				return hot.NewHotCache[string, int](hot.LRU, capacity).
					WithTTL(time.Duration(defaultTTLSeconds) * time.Second).
					WithJanitor().
					Build()
			},
		})
	})
	return channelAffinityCache
}

// GetChannelAffinityCacheStats summarizes the affinity cache: total number of
// keys, a per-rule count for rules that embed their name in the key
// (IncludeRuleName), and the number of keys that could not be attributed to
// such a rule ("unknown").
//
// A nil setting yields a disabled, empty result. A failure to list keys is
// logged and treated as an empty cache.
func GetChannelAffinityCacheStats() ChannelAffinityCacheStats {
	setting := operation_setting.GetChannelAffinitySetting()
	if setting == nil {
		return ChannelAffinityCacheStats{
			Enabled:    false,
			Total:      0,
			Unknown:    0,
			ByRuleName: map[string]int{},
		}
	}

	cache := getChannelAffinityCache()
	mainCap, _ := cache.Capacity()
	mainAlgo, _ := cache.Algorithm()

	// Only rules with a non-blank name that is part of the cache key can be
	// recognized when parsing keys below.
	rules := setting.Rules
	ruleByName := make(map[string]operation_setting.ChannelAffinityRule, len(rules))
	for _, r := range rules {
		name := strings.TrimSpace(r.Name)
		if name == "" {
			continue
		}
		if !r.IncludeRuleName {
			continue
		}
		ruleByName[name] = r
	}

	// Pre-seed every known rule so rules without entries report 0.
	byRuleName := make(map[string]int, len(ruleByName))
	for name := range ruleByName {
		byRuleName[name] = 0
	}

	keys, err := cache.Keys()
	if err != nil {
		common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err))
		keys = nil
	}

	total := len(keys)
	unknown := 0
	prefix := channelAffinityCacheNamespace + ":" // loop-invariant, built once
	for _, k := range keys {
		if !strings.HasPrefix(k, prefix) {
			unknown++
			continue
		}
		rest := strings.TrimPrefix(k, prefix)
		parts := strings.Split(rest, ":")
		// A rule-named key has at least "<rule>:<value>".
		if len(parts) < 2 {
			unknown++
			continue
		}
		ruleName := parts[0]
		rule, ok := ruleByName[ruleName]
		if !ok {
			unknown++
			continue
		}
		// With the group included the key is "<rule>:<group>:<value>".
		if rule.IncludeUsingGroup && len(parts) < 3 {
			unknown++
			continue
		}
		byRuleName[ruleName]++
	}

	return ChannelAffinityCacheStats{
		Enabled:       setting.Enabled,
		Total:         total,
		Unknown:       unknown,
		ByRuleName:    byRuleName,
		CacheCapacity: mainCap,
		CacheAlgo:     mainAlgo,
	}
}

// ClearChannelAffinityCacheAll deletes every key currently listed by the
// affinity cache and returns the number of keys that were listed. Listing and
// deletion errors are logged, not returned.
func ClearChannelAffinityCacheAll() int {
	cache := getChannelAffinityCache()
	keys, err := cache.Keys()
	if err != nil {
		common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err))
		keys = nil
	}
	if len(keys) > 0 {
		if _, err := cache.DeleteMany(keys); err != nil {
			common.SysError(fmt.Sprintf("channel affinity cache delete many failed: err=%v", err))
		}
	}
	return len(keys)
}

// ClearChannelAffinityCacheByRuleName deletes the cache entries of one rule by
// key prefix and returns how many were deleted.
//
// It fails when ruleName is blank, the setting is not initialized, no rule has
// that (trimmed) name, or the rule does not embed its name in cache keys
// (IncludeRuleName is false), since such entries cannot be found by prefix.
func ClearChannelAffinityCacheByRuleName(ruleName string) (int, error) {
	ruleName = strings.TrimSpace(ruleName)
	if ruleName == "" {
		return 0, fmt.Errorf("rule_name 不能为空")
	}

	setting := operation_setting.GetChannelAffinitySetting()
	if setting == nil {
		return 0, fmt.Errorf("channel_affinity_setting 未初始化")
	}

	var matchedRule *operation_setting.ChannelAffinityRule
	for i := range setting.Rules {
		r := &setting.Rules[i]
		if strings.TrimSpace(r.Name) != ruleName {
			continue
		}
		matchedRule = r
		break
	}
	if matchedRule == nil {
		return 0, fmt.Errorf("未知规则名称")
	}
	if !matchedRule.IncludeRuleName {
		return 0, fmt.Errorf("该规则未启用 include_rule_name,无法按规则清空缓存")
	}

	cache := getChannelAffinityCache()
	deleted, err := cache.DeleteByPrefix(ruleName)
	if err != nil {
		return 0, err
	}
	return deleted, nil
}

// matchAnyRegexCached reports whether s matches at least one of patterns.
// Compiled expressions are memoized in channelAffinityRegexCache. Empty
// patterns and patterns that fail to compile are skipped. An empty pattern
// list or an empty s never matches.
func matchAnyRegexCached(patterns []string, s string) bool {
	if len(patterns) == 0 || s == "" {
		return false
	}
	for _, pattern := range patterns {
		if pattern == "" {
			continue
		}
		re, ok := channelAffinityRegexCache.Load(pattern)
		if !ok {
			compiled, err := regexp.Compile(pattern)
			if err != nil {
				continue
			}
			re = compiled
			channelAffinityRegexCache.Store(pattern, re)
		}
		if re.(*regexp.Regexp).MatchString(s) {
			return true
		}
	}
	return false
}

// matchAnyIncludeFold reports whether s contains at least one of patterns,
// comparing case-insensitively. Patterns are trimmed and blank ones skipped.
// An empty pattern list or an empty s never matches.
func matchAnyIncludeFold(patterns []string, s string) bool {
	if len(patterns) == 0 || s == "" {
		return false
	}
	sLower := strings.ToLower(s)
	for _, p := range patterns {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		if strings.Contains(sLower, strings.ToLower(p)) {
			return true
		}
	}
	return false
}

// extractChannelAffinityValue resolves one key source against the request and
// returns the trimmed value, or "" when the source yields nothing.
//
// Supported source types:
//   - "context_int":    positive int stored on the gin context under src.Key
//   - "context_string": string stored on the gin context under src.Key
//   - "gjson":          gjson path src.Path evaluated against the request body;
//     scalars use their string form, other JSON values their raw text
func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAffinityKeySource) string {
	switch src.Type {
	case "context_int":
		if src.Key == "" {
			return ""
		}
		v := c.GetInt(src.Key)
		if v <= 0 {
			return ""
		}
		return strconv.Itoa(v)
	case "context_string":
		if src.Key == "" {
			return ""
		}
		return strings.TrimSpace(c.GetString(src.Key))
	case "gjson":
		if src.Path == "" {
			return ""
		}
		storage, err := common.GetBodyStorage(c)
		if err != nil {
			return ""
		}
		body, err := storage.Bytes()
		if err != nil || len(body) == 0 {
			return ""
		}
		res := gjson.GetBytes(body, src.Path)
		if !res.Exists() {
			return ""
		}
		switch res.Type {
		case gjson.String, gjson.Number, gjson.True, gjson.False:
			return strings.TrimSpace(res.String())
		default:
			return strings.TrimSpace(res.Raw)
		}
	default:
		return ""
	}
}

// buildChannelAffinityCacheKeySuffix joins, with ":", the rule name (when
// IncludeRuleName is set and the name is non-empty), the using group (when
// IncludeUsingGroup is set and the group is non-empty) and the affinity value.
func buildChannelAffinityCacheKeySuffix(rule operation_setting.ChannelAffinityRule, usingGroup string, affinityValue string) string {
	parts := make([]string, 0, 3)
	if rule.IncludeRuleName && rule.Name != "" {
		parts = append(parts, rule.Name)
	}
	if rule.IncludeUsingGroup && usingGroup != "" {
		parts = append(parts, usingGroup)
	}
	parts = append(parts, affinityValue)
	return strings.Join(parts, ":")
}

// setChannelAffinityContext stores the cache key, the TTL and the full meta
// value on the gin context.
func setChannelAffinityContext(c *gin.Context, meta channelAffinityMeta) {
	c.Set(ginKeyChannelAffinityCacheKey, meta.CacheKey)
	c.Set(ginKeyChannelAffinityTTLSeconds, meta.TTLSeconds)
	c.Set(ginKeyChannelAffinityMeta, meta)
}

// getChannelAffinityContext reads back the cache key and TTL stored by
// setChannelAffinityContext. ok is false when no non-empty string key is
// present; a missing or non-int TTL is reported as 0.
func getChannelAffinityContext(c *gin.Context) (string, int, bool) {
	keyAny, ok := c.Get(ginKeyChannelAffinityCacheKey)
	if !ok {
		return "", 0, false
	}
	key, ok := keyAny.(string)
	if !ok || key == "" {
		return "", 0, false
	}
	ttlAny, ok := c.Get(ginKeyChannelAffinityTTLSeconds)
	if !ok {
		return key, 0, true
	}
	ttlSeconds, _ := ttlAny.(int)
	return key, ttlSeconds, true
}

// getChannelAffinityMeta returns the channelAffinityMeta stored on the gin
// context, or false when it is absent or of another type.
func getChannelAffinityMeta(c *gin.Context) (channelAffinityMeta, bool) {
	anyMeta, ok := c.Get(ginKeyChannelAffinityMeta)
	if !ok {
		return channelAffinityMeta{}, false
	}
	meta, ok := anyMeta.(channelAffinityMeta)
	if !ok {
		return channelAffinityMeta{}, false
	}
	return meta, true
}

// GetChannelAffinityStatsContext derives the statistics context from the meta
// stored on c. It returns false when c is nil, no meta is stored, the rule name
// or key fingerprint is blank, or the TTL is not positive.
func GetChannelAffinityStatsContext(c *gin.Context) (ChannelAffinityStatsContext, bool) {
	if c == nil {
		return ChannelAffinityStatsContext{}, false
	}
	meta, ok := getChannelAffinityMeta(c)
	if !ok {
		return ChannelAffinityStatsContext{}, false
	}
	ruleName := strings.TrimSpace(meta.RuleName)
	keyFp := strings.TrimSpace(meta.KeyFingerprint)
	usingGroup := strings.TrimSpace(meta.UsingGroup)
	if ruleName == "" || keyFp == "" {
		return ChannelAffinityStatsContext{}, false
	}
	ttlSeconds := int64(meta.TTLSeconds)
	if ttlSeconds <= 0 {
		return ChannelAffinityStatsContext{}, false
	}
	return ChannelAffinityStatsContext{
		RuleName:       ruleName,
		UsingGroup:     usingGroup,
		KeyFingerprint: keyFp,
		TTLSeconds:     ttlSeconds,
	}, true
}

// affinityFingerprint returns at most the first 8 characters of
// common.Sha1(s), or "" for an empty s.
func affinityFingerprint(s string) string {
	if s == "" {
		return ""
	}
	hex := common.Sha1([]byte(s))
	if len(hex) >= 8 {
		return hex[:8]
	}
	return hex
}

// buildChannelAffinityKeyHint returns a short, single-line hint for an affinity
// value: the trimmed value with CR/LF replaced by spaces, abbreviated to
// "<first 4>...<last 4>" when it is longer than 12 characters.
//
// Lengths and cut points are measured in runes rather than bytes, so multi-byte
// UTF-8 characters are never split and the hint stays valid UTF-8. Output for
// ASCII input is unchanged.
func buildChannelAffinityKeyHint(s string) string {
	s = strings.TrimSpace(s)
	if s == "" {
		return ""
	}
	s = strings.ReplaceAll(s, "\n", " ")
	s = strings.ReplaceAll(s, "\r", " ")
	// At most 12 bytes implies at most 12 runes; skip the conversion.
	if len(s) <= 12 {
		return s
	}
	runes := []rune(s)
	if len(runes) <= 12 {
		return s
	}
	return string(runes[:4]) + "..." + string(runes[len(runes)-4:])
}

// GetPreferredChannelByAffinity looks up the channel previously recorded for
// this request's affinity value.
//
// Rules are evaluated in order. A rule applies when the model matches
// ModelRegex, the path matches PathRegex (if any), the User-Agent contains one
// of UserAgentInclude (if any), one of its key sources yields a non-empty value
// and that value matches ValueRegex (if set). The first applicable rule
// decides: its meta is stored on c (for RecordChannelAffinity and the stats
// helpers) and the cache is queried once. Later rules are not consulted, even
// on a cache miss.
//
// It returns (channelID, true) on a cache hit and (0, false) otherwise,
// including when affinity is disabled or the cache lookup fails (logged).
func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) {
	setting := operation_setting.GetChannelAffinitySetting()
	if setting == nil || !setting.Enabled {
		return 0, false
	}

	path := ""
	if c != nil && c.Request != nil && c.Request.URL != nil {
		path = c.Request.URL.Path
	}
	userAgent := ""
	if c != nil && c.Request != nil {
		userAgent = c.Request.UserAgent()
	}

	for _, rule := range setting.Rules {
		if !matchAnyRegexCached(rule.ModelRegex, modelName) {
			continue
		}
		if len(rule.PathRegex) > 0 && !matchAnyRegexCached(rule.PathRegex, path) {
			continue
		}
		if len(rule.UserAgentInclude) > 0 && !matchAnyIncludeFold(rule.UserAgentInclude, userAgent) {
			continue
		}

		// The first key source that yields a value wins.
		var affinityValue string
		var usedSource operation_setting.ChannelAffinityKeySource
		for _, src := range rule.KeySources {
			affinityValue = extractChannelAffinityValue(c, src)
			if affinityValue != "" {
				usedSource = src
				break
			}
		}
		if affinityValue == "" {
			continue
		}
		if rule.ValueRegex != "" && !matchAnyRegexCached([]string{rule.ValueRegex}, affinityValue) {
			continue
		}

		ttlSeconds := rule.TTLSeconds
		if ttlSeconds <= 0 {
			ttlSeconds = setting.DefaultTTLSeconds
		}

		cacheKeySuffix := buildChannelAffinityCacheKeySuffix(rule, usingGroup, affinityValue)
		cacheKeyFull := channelAffinityCacheNamespace + ":" + cacheKeySuffix
		setChannelAffinityContext(c, channelAffinityMeta{
			CacheKey:       cacheKeyFull,
			TTLSeconds:     ttlSeconds,
			RuleName:       rule.Name,
			SkipRetry:      rule.SkipRetryOnFailure,
			KeySourceType:  strings.TrimSpace(usedSource.Type),
			KeySourceKey:   strings.TrimSpace(usedSource.Key),
			KeySourcePath:  strings.TrimSpace(usedSource.Path),
			KeyHint:        buildChannelAffinityKeyHint(affinityValue),
			KeyFingerprint: affinityFingerprint(affinityValue),
			UsingGroup:     usingGroup,
			ModelName:      modelName,
			RequestPath:    path,
		})

		cache := getChannelAffinityCache()
		channelID, found, err := cache.Get(cacheKeySuffix)
		if err != nil {
			common.SysError(fmt.Sprintf("channel affinity cache get failed: key=%s, err=%v", cacheKeyFull, err))
			return 0, false
		}
		if found {
			return channelID, true
		}
		return 0, false
	}
	return 0, false
}

// ShouldSkipRetryAfterChannelAffinityFailure reports the skip-retry flag that
// MarkChannelAffinityUsed stored on c. It is false when c is nil or the flag is
// missing or not a bool.
func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
	if c == nil {
		return false
	}
	v, ok := c.Get(ginKeyChannelAffinitySkipRetry)
	if !ok {
		return false
	}
	b, ok := v.(bool)
	if !ok {
		return false
	}
	return b
}

// MarkChannelAffinityUsed records on c that the affinity channel was selected:
// it stores the rule's skip-retry flag and a log-info map describing the match.
// It does nothing when c is nil, channelID is not positive, or no affinity meta
// is stored on c.
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
	if c == nil || channelID <= 0 {
		return
	}
	meta, ok := getChannelAffinityMeta(c)
	if !ok {
		return
	}
	c.Set(ginKeyChannelAffinitySkipRetry, meta.SkipRetry)

	info := map[string]interface{}{
		"reason":         meta.RuleName,
		"rule_name":      meta.RuleName,
		"using_group":    meta.UsingGroup,
		"selected_group": selectedGroup,
		"model":          meta.ModelName,
		"request_path":   meta.RequestPath,
		"channel_id":     channelID,
		"key_source":     meta.KeySourceType,
		"key_key":        meta.KeySourceKey,
		"key_path":       meta.KeySourcePath,
		"key_hint":       meta.KeyHint,
		"key_fp":         meta.KeyFingerprint,
	}
	c.Set(ginKeyChannelAffinityLogInfo, info)
}

// AppendChannelAffinityAdminInfo copies the log info stored by
// MarkChannelAffinityUsed into adminInfo["channel_affinity"], if present.
func AppendChannelAffinityAdminInfo(c *gin.Context, adminInfo map[string]interface{}) {
	if c == nil || adminInfo == nil {
		return
	}
	anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo)
	if !ok || anyInfo == nil {
		return
	}
	adminInfo["channel_affinity"] = anyInfo
}

// RecordChannelAffinity stores channelID in the affinity cache under the key
// that GetPreferredChannelByAffinity put on c.
//
// When SwitchOnSuccess is set and the context carries a positive "channel_id",
// that id is recorded instead of channelID. The TTL is the one stored on c,
// falling back to the setting's default and then to 3600 seconds. It does
// nothing when channelID is not positive, affinity is disabled, c is nil, or no
// cache key is stored on c. A cache write failure is logged.
func RecordChannelAffinity(c *gin.Context, channelID int) {
	if channelID <= 0 {
		return
	}
	setting := operation_setting.GetChannelAffinitySetting()
	if setting == nil || !setting.Enabled {
		return
	}
	// Without a context there is no cache key to record under, and the
	// lookups below would dereference the nil context.
	if c == nil {
		return
	}
	if setting.SwitchOnSuccess {
		if successChannelID := c.GetInt("channel_id"); successChannelID > 0 {
			channelID = successChannelID
		}
	}
	cacheKey, ttlSeconds, ok := getChannelAffinityContext(c)
	if !ok {
		return
	}
	if ttlSeconds <= 0 {
		ttlSeconds = setting.DefaultTTLSeconds
	}
	if ttlSeconds <= 0 {
		ttlSeconds = 3600
	}
	cache := getChannelAffinityCache()
	if err := cache.SetWithTTL(cacheKey, channelID, time.Duration(ttlSeconds)*time.Second); err != nil {
		common.SysError(fmt.Sprintf("channel affinity cache set failed: key=%s, err=%v", cacheKey, err))
	}
}

// ChannelAffinityUsageCacheStats is the usage-cache counter set for one
// (rule, using group, key fingerprint) triple, as returned by
// GetChannelAffinityUsageCacheStats.
type ChannelAffinityUsageCacheStats struct {
	RuleName             string `json:"rule_name"`
	UsingGroup           string `json:"using_group"`
	KeyFingerprint       string `json:"key_fp"`
	CachedTokenRateMode  string `json:"cached_token_rate_mode"`
	Hit                  int64  `json:"hit"`
	Total                int64  `json:"total"`
	WindowSeconds        int64  `json:"window_seconds"`
	PromptTokens         int64  `json:"prompt_tokens"`
	CompletionTokens     int64  `json:"completion_tokens"`
	TotalTokens          int64  `json:"total_tokens"`
	CachedTokens         int64  `json:"cached_tokens"`
	PromptCacheHitTokens int64  `json:"prompt_cache_hit_tokens"`
	LastSeenAt           int64  `json:"last_seen_at"`
}

// ChannelAffinityUsageCacheCounters is the value type stored in the usage
// stats cache; it carries the same counters as ChannelAffinityUsageCacheStats
// without the identifying fields.
type ChannelAffinityUsageCacheCounters struct {
	CachedTokenRateMode  string `json:"cached_token_rate_mode"`
	Hit                  int64  `json:"hit"`
	Total                int64  `json:"total"`
	WindowSeconds        int64  `json:"window_seconds"`
	PromptTokens         int64  `json:"prompt_tokens"`
	CompletionTokens     int64  `json:"completion_tokens"`
	TotalTokens          int64  `json:"total_tokens"`
	CachedTokens         int64  `json:"cached_tokens"`
	PromptCacheHitTokens int64  `json:"prompt_cache_hit_tokens"`
	LastSeenAt           int64  `json:"last_seen_at"`
}

// channelAffinityUsageCacheStatsLocks is a fixed set of mutexes; an entry key
// is mapped onto one of them by hash to serialize counter updates.
var channelAffinityUsageCacheStatsLocks [64]sync.Mutex

// ObserveChannelAffinityUsageCacheByRelayFormat records usage cache stats with a stable rate mode derived from relay format.
func ObserveChannelAffinityUsageCacheByRelayFormat(c *gin.Context, usage *dto.Usage, relayFormat types.RelayFormat) { ObserveChannelAffinityUsageCacheFromContext(c, usage, cachedTokenRateModeByRelayFormat(relayFormat)) } func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage, cachedTokenRateMode string) { statsCtx, ok := GetChannelAffinityStatsContext(c) if !ok { return } observeChannelAffinityUsageCache(statsCtx, usage, cachedTokenRateMode) } func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) ChannelAffinityUsageCacheStats { ruleName = strings.TrimSpace(ruleName) usingGroup = strings.TrimSpace(usingGroup) keyFp = strings.TrimSpace(keyFp) entryKey := channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp) if entryKey == "" { return ChannelAffinityUsageCacheStats{ RuleName: ruleName, UsingGroup: usingGroup, KeyFingerprint: keyFp, } } cache := getChannelAffinityUsageCacheStatsCache() v, found, err := cache.Get(entryKey) if err != nil || !found { return ChannelAffinityUsageCacheStats{ RuleName: ruleName, UsingGroup: usingGroup, KeyFingerprint: keyFp, } } return ChannelAffinityUsageCacheStats{ CachedTokenRateMode: v.CachedTokenRateMode, RuleName: ruleName, UsingGroup: usingGroup, KeyFingerprint: keyFp, Hit: v.Hit, Total: v.Total, WindowSeconds: v.WindowSeconds, PromptTokens: v.PromptTokens, CompletionTokens: v.CompletionTokens, TotalTokens: v.TotalTokens, CachedTokens: v.CachedTokens, PromptCacheHitTokens: v.PromptCacheHitTokens, LastSeenAt: v.LastSeenAt, } } func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage, cachedTokenRateMode string) { entryKey := channelAffinityUsageCacheEntryKey(statsCtx.RuleName, statsCtx.UsingGroup, statsCtx.KeyFingerprint) if entryKey == "" { return } windowSeconds := statsCtx.TTLSeconds if windowSeconds <= 0 { return } cache := getChannelAffinityUsageCacheStatsCache() ttl := time.Duration(windowSeconds) * time.Second lock := 
channelAffinityUsageCacheStatsLock(entryKey) lock.Lock() defer lock.Unlock() prev, found, err := cache.Get(entryKey) if err != nil { return } next := prev if !found { next = ChannelAffinityUsageCacheCounters{} } currentMode := normalizeCachedTokenRateMode(cachedTokenRateMode) if currentMode != "" { if next.CachedTokenRateMode == "" { next.CachedTokenRateMode = currentMode } else if next.CachedTokenRateMode != currentMode && next.CachedTokenRateMode != cacheTokenRateModeMixed { next.CachedTokenRateMode = cacheTokenRateModeMixed } } next.Total++ hit, cachedTokens, promptCacheHitTokens := usageCacheSignals(usage) if hit { next.Hit++ } next.WindowSeconds = windowSeconds next.LastSeenAt = time.Now().Unix() next.CachedTokens += cachedTokens next.PromptCacheHitTokens += promptCacheHitTokens next.PromptTokens += int64(usagePromptTokens(usage)) next.CompletionTokens += int64(usageCompletionTokens(usage)) next.TotalTokens += int64(usageTotalTokens(usage)) _ = cache.SetWithTTL(entryKey, next, ttl) } func normalizeCachedTokenRateMode(mode string) string { switch mode { case cacheTokenRateModeCachedOverPrompt: return cacheTokenRateModeCachedOverPrompt case cacheTokenRateModeCachedOverPromptPlusCached: return cacheTokenRateModeCachedOverPromptPlusCached case cacheTokenRateModeMixed: return cacheTokenRateModeMixed default: return "" } } func cachedTokenRateModeByRelayFormat(relayFormat types.RelayFormat) string { switch relayFormat { case types.RelayFormatOpenAI, types.RelayFormatOpenAIResponses, types.RelayFormatOpenAIResponsesCompaction: return cacheTokenRateModeCachedOverPrompt case types.RelayFormatClaude: return cacheTokenRateModeCachedOverPromptPlusCached default: return "" } } func channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp string) string { ruleName = strings.TrimSpace(ruleName) usingGroup = strings.TrimSpace(usingGroup) keyFp = strings.TrimSpace(keyFp) if ruleName == "" || keyFp == "" { return "" } return ruleName + "\n" + usingGroup + "\n" + keyFp } 
func usageCacheSignals(usage *dto.Usage) (hit bool, cachedTokens int64, promptCacheHitTokens int64) { if usage == nil { return false, 0, 0 } cached := int64(0) if usage.PromptTokensDetails.CachedTokens > 0 { cached = int64(usage.PromptTokensDetails.CachedTokens) } else if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 { cached = int64(usage.InputTokensDetails.CachedTokens) } pcht := int64(0) if usage.PromptCacheHitTokens > 0 { pcht = int64(usage.PromptCacheHitTokens) } return cached > 0 || pcht > 0, cached, pcht } func usagePromptTokens(usage *dto.Usage) int { if usage == nil { return 0 } if usage.PromptTokens > 0 { return usage.PromptTokens } return usage.InputTokens } func usageCompletionTokens(usage *dto.Usage) int { if usage == nil { return 0 } if usage.CompletionTokens > 0 { return usage.CompletionTokens } return usage.OutputTokens } func usageTotalTokens(usage *dto.Usage) int { if usage == nil { return 0 } if usage.TotalTokens > 0 { return usage.TotalTokens } pt := usagePromptTokens(usage) ct := usageCompletionTokens(usage) if pt > 0 || ct > 0 { return pt + ct } return 0 } func getChannelAffinityUsageCacheStatsCache() *cachex.HybridCache[ChannelAffinityUsageCacheCounters] { channelAffinityUsageCacheStatsOnce.Do(func() { setting := operation_setting.GetChannelAffinitySetting() capacity := 100_000 defaultTTLSeconds := 3600 if setting != nil { if setting.MaxEntries > 0 { capacity = setting.MaxEntries } if setting.DefaultTTLSeconds > 0 { defaultTTLSeconds = setting.DefaultTTLSeconds } } channelAffinityUsageCacheStatsCache = cachex.NewHybridCache[ChannelAffinityUsageCacheCounters](cachex.HybridCacheConfig[ChannelAffinityUsageCacheCounters]{ Namespace: cachex.Namespace(channelAffinityUsageCacheStatsNamespace), Redis: common.RDB, RedisEnabled: func() bool { return common.RedisEnabled && common.RDB != nil }, RedisCodec: cachex.JSONCodec[ChannelAffinityUsageCacheCounters]{}, Memory: func() *hot.HotCache[string, 
ChannelAffinityUsageCacheCounters] { return hot.NewHotCache[string, ChannelAffinityUsageCacheCounters](hot.LRU, capacity). WithTTL(time.Duration(defaultTTLSeconds) * time.Second). WithJanitor(). Build() }, }) }) return channelAffinityUsageCacheStatsCache } func channelAffinityUsageCacheStatsLock(key string) *sync.Mutex { h := fnv.New32a() _, _ = h.Write([]byte(key)) idx := h.Sum32() % uint32(len(channelAffinityUsageCacheStatsLocks)) return &channelAffinityUsageCacheStatsLocks[idx] }