package service import ( "fmt" "regexp" "strconv" "strings" "sync" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/pkg/cachex" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/gin-gonic/gin" "github.com/samber/hot" "github.com/tidwall/gjson" ) const ( ginKeyChannelAffinityCacheKey = "channel_affinity_cache_key" ginKeyChannelAffinityTTLSeconds = "channel_affinity_ttl_seconds" ginKeyChannelAffinityMeta = "channel_affinity_meta" ginKeyChannelAffinityLogInfo = "channel_affinity_log_info" channelAffinityCacheNamespace = "new-api:channel_affinity:v1" ) var ( channelAffinityCacheOnce sync.Once channelAffinityCache *cachex.HybridCache[int] channelAffinityRegexCache sync.Map // map[string]*regexp.Regexp ) type channelAffinityMeta struct { CacheKey string TTLSeconds int RuleName string KeySourceType string KeySourceKey string KeySourcePath string KeyFingerprint string UsingGroup string ModelName string RequestPath string } type ChannelAffinityCacheStats struct { Enabled bool `json:"enabled"` Total int `json:"total"` Unknown int `json:"unknown"` ByRuleName map[string]int `json:"by_rule_name"` CacheCapacity int `json:"cache_capacity"` CacheAlgo string `json:"cache_algo"` } func getChannelAffinityCache() *cachex.HybridCache[int] { channelAffinityCacheOnce.Do(func() { setting := operation_setting.GetChannelAffinitySetting() capacity := setting.MaxEntries if capacity <= 0 { capacity = 100_000 } defaultTTLSeconds := setting.DefaultTTLSeconds if defaultTTLSeconds <= 0 { defaultTTLSeconds = 3600 } channelAffinityCache = cachex.NewHybridCache[int](cachex.HybridCacheConfig[int]{ Namespace: cachex.Namespace(channelAffinityCacheNamespace), Redis: common.RDB, RedisEnabled: func() bool { return common.RedisEnabled && common.RDB != nil }, RedisCodec: cachex.IntCodec{}, Memory: func() *hot.HotCache[string, int] { return hot.NewHotCache[string, int](hot.LRU, capacity). WithTTL(time.Duration(defaultTTLSeconds) * time.Second). WithJanitor(). Build() }, }) }) return channelAffinityCache } func GetChannelAffinityCacheStats() ChannelAffinityCacheStats { setting := operation_setting.GetChannelAffinitySetting() if setting == nil { return ChannelAffinityCacheStats{ Enabled: false, Total: 0, Unknown: 0, ByRuleName: map[string]int{}, } } cache := getChannelAffinityCache() mainCap, _ := cache.Capacity() mainAlgo, _ := cache.Algorithm() rules := setting.Rules ruleByName := make(map[string]operation_setting.ChannelAffinityRule, len(rules)) for _, r := range rules { name := strings.TrimSpace(r.Name) if name == "" { continue } if !r.IncludeRuleName { continue } ruleByName[name] = r } byRuleName := make(map[string]int, len(ruleByName)) for name := range ruleByName { byRuleName[name] = 0 } keys, err := cache.Keys() if err != nil { common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err)) keys = nil } total := len(keys) unknown := 0 for _, k := range keys { prefix := channelAffinityCacheNamespace + ":" if !strings.HasPrefix(k, prefix) { unknown++ continue } rest := strings.TrimPrefix(k, prefix) parts := strings.Split(rest, ":") if len(parts) < 2 { unknown++ continue } ruleName := parts[0] rule, ok := ruleByName[ruleName] if !ok { unknown++ continue } if rule.IncludeUsingGroup { if len(parts) < 3 { unknown++ continue } } byRuleName[ruleName]++ } return ChannelAffinityCacheStats{ Enabled: setting.Enabled, Total: total, Unknown: unknown, ByRuleName: byRuleName, CacheCapacity: mainCap, CacheAlgo: mainAlgo, } } func ClearChannelAffinityCacheAll() int { cache := getChannelAffinityCache() keys, err := cache.Keys() if err != nil { common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err)) keys = nil } if len(keys) > 0 { if _, err := cache.DeleteMany(keys); err != nil { common.SysError(fmt.Sprintf("channel affinity cache delete many failed: err=%v", err)) } } return len(keys) } func ClearChannelAffinityCacheByRuleName(ruleName string) (int, error) { ruleName = strings.TrimSpace(ruleName) if ruleName == "" { return 0, fmt.Errorf("rule_name 不能为空") } setting := operation_setting.GetChannelAffinitySetting() if setting == nil { return 0, fmt.Errorf("channel_affinity_setting 未初始化") } var matchedRule *operation_setting.ChannelAffinityRule for i := range setting.Rules { r := &setting.Rules[i] if strings.TrimSpace(r.Name) != ruleName { continue } matchedRule = r break } if matchedRule == nil { return 0, fmt.Errorf("未知规则名称") } if !matchedRule.IncludeRuleName { return 0, fmt.Errorf("该规则未启用 include_rule_name,无法按规则清空缓存") } cache := getChannelAffinityCache() deleted, err := cache.DeleteByPrefix(ruleName) if err != nil { return 0, err } return deleted, nil } func matchAnyRegexCached(patterns []string, s string) bool { if len(patterns) == 0 || s == "" { return false } for _, pattern := range patterns { if pattern == "" { continue } re, ok := channelAffinityRegexCache.Load(pattern) if !ok { compiled, err := regexp.Compile(pattern) if err != nil { continue } re = compiled channelAffinityRegexCache.Store(pattern, re) } if re.(*regexp.Regexp).MatchString(s) { return true } } return false } func matchAnyIncludeFold(patterns []string, s string) bool { if len(patterns) == 0 || s == "" { return false } sLower := strings.ToLower(s) for _, p := range patterns { p = strings.TrimSpace(p) if p == "" { continue } if strings.Contains(sLower, strings.ToLower(p)) { return true } } return false } func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAffinityKeySource) string { switch src.Type { case "context_int": if src.Key == "" { return "" } v := c.GetInt(src.Key) if v <= 0 { return "" } return strconv.Itoa(v) case "context_string": if src.Key == "" { return "" } return strings.TrimSpace(c.GetString(src.Key)) case "gjson": if src.Path == "" { return "" } body, err := common.GetRequestBody(c) if err != nil || len(body) == 0 { return "" } res := gjson.GetBytes(body, src.Path) if !res.Exists() { return "" } switch res.Type { case gjson.String, gjson.Number, gjson.True, gjson.False: return strings.TrimSpace(res.String()) default: return strings.TrimSpace(res.Raw) } default: return "" } } func buildChannelAffinityCacheKeySuffix(rule operation_setting.ChannelAffinityRule, usingGroup string, affinityValue string) string { parts := make([]string, 0, 3) if rule.IncludeRuleName && rule.Name != "" { parts = append(parts, rule.Name) } if rule.IncludeUsingGroup && usingGroup != "" { parts = append(parts, usingGroup) } parts = append(parts, affinityValue) return strings.Join(parts, ":") } func setChannelAffinityContext(c *gin.Context, meta channelAffinityMeta) { c.Set(ginKeyChannelAffinityCacheKey, meta.CacheKey) c.Set(ginKeyChannelAffinityTTLSeconds, meta.TTLSeconds) c.Set(ginKeyChannelAffinityMeta, meta) } func getChannelAffinityContext(c *gin.Context) (string, int, bool) { keyAny, ok := c.Get(ginKeyChannelAffinityCacheKey) if !ok { return "", 0, false } key, ok := keyAny.(string) if !ok || key == "" { return "", 0, false } ttlAny, ok := c.Get(ginKeyChannelAffinityTTLSeconds) if !ok { return key, 0, true } ttlSeconds, _ := ttlAny.(int) return key, ttlSeconds, true } func getChannelAffinityMeta(c *gin.Context) (channelAffinityMeta, bool) { anyMeta, ok := c.Get(ginKeyChannelAffinityMeta) if !ok { return channelAffinityMeta{}, false } meta, ok := anyMeta.(channelAffinityMeta) if !ok { return channelAffinityMeta{}, false } return meta, true } func affinityFingerprint(s string) string { if s == "" { return "" } hex := common.Sha1([]byte(s)) if len(hex) >= 8 { return hex[:8] } return hex } func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) { setting := operation_setting.GetChannelAffinitySetting() if setting == nil || !setting.Enabled { return 0, false } path := "" if c != nil && c.Request != nil && c.Request.URL != nil { path = c.Request.URL.Path } userAgent := "" if c != nil && c.Request != nil { userAgent = c.Request.UserAgent() } for _, rule := range setting.Rules { if !matchAnyRegexCached(rule.ModelRegex, modelName) { continue } if len(rule.PathRegex) > 0 && !matchAnyRegexCached(rule.PathRegex, path) { continue } if len(rule.UserAgentInclude) > 0 && !matchAnyIncludeFold(rule.UserAgentInclude, userAgent) { continue } var affinityValue string var usedSource operation_setting.ChannelAffinityKeySource for _, src := range rule.KeySources { affinityValue = extractChannelAffinityValue(c, src) if affinityValue != "" { usedSource = src break } } if affinityValue == "" { continue } if rule.ValueRegex != "" && !matchAnyRegexCached([]string{rule.ValueRegex}, affinityValue) { continue } ttlSeconds := rule.TTLSeconds if ttlSeconds <= 0 { ttlSeconds = setting.DefaultTTLSeconds } cacheKeySuffix := buildChannelAffinityCacheKeySuffix(rule, usingGroup, affinityValue) cacheKeyFull := channelAffinityCacheNamespace + ":" + cacheKeySuffix setChannelAffinityContext(c, channelAffinityMeta{ CacheKey: cacheKeyFull, TTLSeconds: ttlSeconds, RuleName: rule.Name, KeySourceType: strings.TrimSpace(usedSource.Type), KeySourceKey: strings.TrimSpace(usedSource.Key), KeySourcePath: strings.TrimSpace(usedSource.Path), KeyFingerprint: affinityFingerprint(affinityValue), UsingGroup: usingGroup, ModelName: modelName, RequestPath: path, }) cache := getChannelAffinityCache() channelID, found, err := cache.Get(cacheKeySuffix) if err != nil { common.SysError(fmt.Sprintf("channel affinity cache get failed: key=%s, err=%v", cacheKeyFull, err)) return 0, false } if found { return channelID, true } return 0, false } return 0, false } func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) { if c == nil || channelID <= 0 { return } meta, ok := getChannelAffinityMeta(c) if !ok { return } info := map[string]interface{}{ "reason": meta.RuleName, "rule_name": meta.RuleName, "using_group": meta.UsingGroup, "selected_group": selectedGroup, "model": meta.ModelName, "request_path": meta.RequestPath, "channel_id": channelID, "key_source": meta.KeySourceType, "key_key": meta.KeySourceKey, "key_path": meta.KeySourcePath, "key_fp": meta.KeyFingerprint, } c.Set(ginKeyChannelAffinityLogInfo, info) } func AppendChannelAffinityAdminInfo(c *gin.Context, adminInfo map[string]interface{}) { if c == nil || adminInfo == nil { return } anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo) if !ok || anyInfo == nil { return } adminInfo["channel_affinity"] = anyInfo } func RecordChannelAffinity(c *gin.Context, channelID int) { if channelID <= 0 { return } setting := operation_setting.GetChannelAffinitySetting() if setting == nil || !setting.Enabled { return } if setting.SwitchOnSuccess && c != nil { if successChannelID := c.GetInt("channel_id"); successChannelID > 0 { channelID = successChannelID } } cacheKey, ttlSeconds, ok := getChannelAffinityContext(c) if !ok { return } if ttlSeconds <= 0 { ttlSeconds = setting.DefaultTTLSeconds } if ttlSeconds <= 0 { ttlSeconds = 3600 } cache := getChannelAffinityCache() if err := cache.SetWithTTL(cacheKey, channelID, time.Duration(ttlSeconds)*time.Second); err != nil { common.SysError(fmt.Sprintf("channel affinity cache set failed: key=%s, err=%v", cacheKey, err)) } }