Quellcode durchsuchen

feat: Implement notification rate limiting mechanism

- Add in-memory and Redis-based notification rate limiting
- Create configurable hourly notification limits
- Implement notification limit checking for user notifications
- Add environment variables for customizing notification limits
1808837298@qq.com vor 1 Jahr
Ursprung
Commit
56f6b2ab56
4 geänderte Dateien mit 138 neuen und 1 gelöschten Zeilen
  1. 7 0
      common/model-ratio.go
  2. 3 0
      constant/env.go
  3. 116 0
      service/notify-limit.go
  4. 12 1
      service/user_notify.go

+ 7 - 0
common/model-ratio.go

@@ -356,6 +356,13 @@ func CompletionRatio2JSONString() string {
 	return string(jsonBytes)
 }
 
+func UpdateCompletionRatioByJSONString(jsonStr string) error {
+	CompletionRatioMutex.Lock()
+	defer CompletionRatioMutex.Unlock()
+	CompletionRatio = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
+}
+
 func GetCompletionRatio(name string) float64 {
 	GetCompletionRatioMap()
 

+ 3 - 0
constant/env.go

@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
 
 var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
 
+var DefaultNotifyHourlyLimit = common.GetEnvOrDefault("NOTIFY_HOURLY_LIMIT", 2)
+var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
+
 func InitEnv() {
 	modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
 	if modelVersionMapStr == "" {

+ 116 - 0
service/notify-limit.go

@@ -0,0 +1,116 @@
+package service
+
+import (
+	"fmt"
+	"one-api/common"
+	"one-api/constant"
+	"strconv"
+	"sync"
+	"time"
+)
+
+// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
+var (
+	notifyLimitStore sync.Map
+	cleanupOnce      sync.Once
+)
+
+type limitCount struct {
+	Count     int
+	Timestamp time.Time
+}
+
+func getDuration() time.Duration {
+	minute := constant.NotificationLimitDurationMinute
+	return time.Duration(minute) * time.Minute
+}
+
+// startCleanupTask starts a background task to clean up expired entries
+func startCleanupTask() {
+	go func() {
+		for {
+			time.Sleep(time.Hour)
+			now := time.Now()
+			notifyLimitStore.Range(func(key, value interface{}) bool {
+				if limit, ok := value.(limitCount); ok {
+					if now.Sub(limit.Timestamp) >= getDuration() {
+						notifyLimitStore.Delete(key)
+					}
+				}
+				return true
+			})
+		}
+	}()
+}
+
+// CheckNotificationLimit checks if the user has exceeded their notification limit
+// Returns true if the user can send notification, false if limit exceeded
+func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
+	if common.RedisEnabled {
+		return checkRedisLimit(userId, notifyType)
+	}
+	return checkMemoryLimit(userId, notifyType)
+}
+
+func checkRedisLimit(userId int, notifyType string) (bool, error) {
+	key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+
+	// Get current count
+	count, err := common.RedisGet(key)
+	if err != nil && err.Error() != "redis: nil" {
+		return false, fmt.Errorf("failed to get notification count: %w", err)
+	}
+
+	// If key doesn't exist, initialize it
+	if count == "" {
+		err = common.RedisSet(key, "1", getDuration())
+		return true, err
+	}
+
+	currentCount, _ := strconv.Atoi(count)
+	limit := constant.DefaultNotifyHourlyLimit
+
+	// Check if limit is already reached
+	if currentCount >= limit {
+		return false, nil
+	}
+
+	// Only increment if under limit
+	err = common.RedisIncr(key, 1)
+	if err != nil {
+		return false, fmt.Errorf("failed to increment notification count: %w", err)
+	}
+
+	return true, nil
+}
+
+func checkMemoryLimit(userId int, notifyType string) (bool, error) {
+	// Ensure cleanup task is started
+	cleanupOnce.Do(startCleanupTask)
+
+	key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+	now := time.Now()
+
+	// Get current limit count or initialize new one
+	var currentLimit limitCount
+	if value, ok := notifyLimitStore.Load(key); ok {
+		currentLimit = value.(limitCount)
+		// Check if the entry has expired
+		if now.Sub(currentLimit.Timestamp) >= getDuration() {
+			currentLimit = limitCount{Count: 0, Timestamp: now}
+		}
+	} else {
+		currentLimit = limitCount{Count: 0, Timestamp: now}
+	}
+
+	// Increment count
+	currentLimit.Count++
+
+	// Check against limits
+	limit := constant.DefaultNotifyHourlyLimit
+
+	// Store updated count
+	notifyLimitStore.Store(key, currentLimit)
+
+	return currentLimit.Count <= limit, nil
+}

+ 12 - 1
service/user_notify.go

@@ -25,6 +25,17 @@ func NotifyUser(user *model.UserCache, data dto.Notify) error {
 	if !ok {
 		notifyType = constant.NotifyTypeEmail
 	}
+
+	// Check notification limit
+	canSend, err := CheckNotificationLimit(user.Id, data.Type)
+	if err != nil {
+		common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
+		return err
+	}
+	if !canSend {
+		return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
+	}
+
 	switch notifyType {
 	case constant.NotifyTypeEmail:
 		userEmail := user.Email
@@ -46,7 +57,7 @@ func NotifyUser(user *model.UserCache, data dto.Notify) error {
 		// TODO: 实现webhook通知
 		_ = webhookURL // 临时处理未使用警告,等待webhook实现
 	}
-	return nil // 添加缺失的return
+	return nil
 }
 
 func sendEmailNotify(userEmail string, data dto.Notify) error {