notify-limit.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package service
  2. import (
  3. "fmt"
  4. "one-api/common"
  5. "one-api/constant"
  6. "strconv"
  7. "sync"
  8. "time"
  9. )
  // notifyLimitStore holds the in-memory notification counters (limitCount
// values) used for rate limiting when Redis is disabled.
var notifyLimitStore sync.Map

// cleanupOnce ensures the background sweeper (startCleanupTask) is launched
// at most once, no matter how many goroutines hit the in-memory limiter.
var cleanupOnce sync.Once
  // limitCount is a single in-memory counter entry kept in notifyLimitStore.
type limitCount struct {
	// Count is the number of notifications counted in the current window.
	Count int
	// Timestamp is when the current window started; the entry is considered
	// expired once getDuration() has elapsed since this instant.
	Timestamp time.Time
}
  19. func getDuration() time.Duration {
  20. minute := constant.NotificationLimitDurationMinute
  21. return time.Duration(minute) * time.Minute
  22. }
  23. // startCleanupTask starts a background task to clean up expired entries
  24. func startCleanupTask() {
  25. go func() {
  26. for {
  27. time.Sleep(time.Hour)
  28. now := time.Now()
  29. notifyLimitStore.Range(func(key, value interface{}) bool {
  30. if limit, ok := value.(limitCount); ok {
  31. if now.Sub(limit.Timestamp) >= getDuration() {
  32. notifyLimitStore.Delete(key)
  33. }
  34. }
  35. return true
  36. })
  37. }
  38. }()
  39. }
  40. // CheckNotificationLimit checks if the user has exceeded their notification limit
  41. // Returns true if the user can send notification, false if limit exceeded
  42. func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
  43. if common.RedisEnabled {
  44. return checkRedisLimit(userId, notifyType)
  45. }
  46. return checkMemoryLimit(userId, notifyType)
  47. }
  48. func checkRedisLimit(userId int, notifyType string) (bool, error) {
  49. key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
  50. // Get current count
  51. count, err := common.RedisGet(key)
  52. if err != nil && err.Error() != "redis: nil" {
  53. return false, fmt.Errorf("failed to get notification count: %w", err)
  54. }
  55. // If key doesn't exist, initialize it
  56. if count == "" {
  57. err = common.RedisSet(key, "1", getDuration())
  58. return true, err
  59. }
  60. currentCount, _ := strconv.Atoi(count)
  61. limit := constant.DefaultNotifyHourlyLimit
  62. // Check if limit is already reached
  63. if currentCount >= limit {
  64. return false, nil
  65. }
  66. // Only increment if under limit
  67. err = common.RedisIncr(key, 1)
  68. if err != nil {
  69. return false, fmt.Errorf("failed to increment notification count: %w", err)
  70. }
  71. return true, nil
  72. }
  73. func checkMemoryLimit(userId int, notifyType string) (bool, error) {
  74. // Ensure cleanup task is started
  75. cleanupOnce.Do(startCleanupTask)
  76. key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
  77. now := time.Now()
  78. // Get current limit count or initialize new one
  79. var currentLimit limitCount
  80. if value, ok := notifyLimitStore.Load(key); ok {
  81. currentLimit = value.(limitCount)
  82. // Check if the entry has expired
  83. if now.Sub(currentLimit.Timestamp) >= getDuration() {
  84. currentLimit = limitCount{Count: 0, Timestamp: now}
  85. }
  86. } else {
  87. currentLimit = limitCount{Count: 0, Timestamp: now}
  88. }
  89. // Increment count
  90. currentLimit.Count++
  91. // Check against limits
  92. limit := constant.DefaultNotifyHourlyLimit
  93. // Store updated count
  94. notifyLimitStore.Store(key, currentLimit)
  95. return currentLimit.Count <= limit, nil
  96. }