model-rate-limit.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/common/limiter"
  8. "one-api/setting"
  9. "strconv"
  10. "time"
  11. "github.com/gin-gonic/gin"
  12. "github.com/go-redis/redis/v8"
  13. )
  // Key markers used to namespace the per-user counters of the model request
  // rate limiter.
  const (
  	// ModelRequestRateLimitCountMark prefixes the in-memory key that counts
  	// every request of a user.
  	ModelRequestRateLimitCountMark = "MRRL"
  	// ModelRequestRateLimitSuccessCountMark marks the keys (Redis and
  	// in-memory) that count a user's successful requests.
  	ModelRequestRateLimitSuccessCountMark = "MRRLS"
  )
  18. func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
  19. if maxCount == 0 {
  20. return true, nil
  21. }
  22. length, err := rdb.LLen(ctx, key).Result()
  23. if err != nil {
  24. return false, err
  25. }
  26. if length < int64(maxCount) {
  27. return true, nil
  28. }
  29. oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
  30. oldTime, err := time.Parse(timeFormat, oldTimeStr)
  31. if err != nil {
  32. return false, err
  33. }
  34. nowTimeStr := time.Now().Format(timeFormat)
  35. nowTime, err := time.Parse(timeFormat, nowTimeStr)
  36. if err != nil {
  37. return false, err
  38. }
  39. subTime := nowTime.Sub(oldTime).Seconds()
  40. if int64(subTime) < duration {
  41. rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
  42. return false, nil
  43. }
  44. return true, nil
  45. }
  46. func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
  47. if maxCount == 0 {
  48. return
  49. }
  50. now := time.Now().Format(timeFormat)
  51. rdb.LPush(ctx, key, now)
  52. rdb.LTrim(ctx, key, 0, int64(maxCount-1))
  53. rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
  54. }
  55. func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
  56. return func(c *gin.Context) {
  57. userId := strconv.Itoa(c.GetInt("id"))
  58. ctx := context.Background()
  59. rdb := common.RDB
  60. successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
  61. allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
  62. if err != nil {
  63. fmt.Println("检查成功请求数限制失败:", err.Error())
  64. abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
  65. return
  66. }
  67. if !allowed {
  68. abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
  69. return
  70. }
  71. totalKey := fmt.Sprintf("rateLimit:%s", userId)
  72. tb := limiter.New(ctx, rdb)
  73. allowed, err = tb.Allow(
  74. ctx,
  75. totalKey,
  76. limiter.WithCapacity(int64(totalMaxCount)*duration),
  77. limiter.WithRate(int64(totalMaxCount)),
  78. limiter.WithRequested(duration),
  79. )
  80. if err != nil {
  81. fmt.Println("检查总请求数限制失败:", err.Error())
  82. abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
  83. return
  84. }
  85. if !allowed {
  86. abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
  87. }
  88. c.Next()
  89. if c.Writer.Status() < 400 {
  90. recordRedisRequest(ctx, rdb, successKey, successMaxCount)
  91. }
  92. }
  93. }
  94. func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
  95. inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
  96. return func(c *gin.Context) {
  97. userId := strconv.Itoa(c.GetInt("id"))
  98. totalKey := ModelRequestRateLimitCountMark + userId
  99. successKey := ModelRequestRateLimitSuccessCountMark + userId
  100. if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
  101. c.Status(http.StatusTooManyRequests)
  102. c.Abort()
  103. return
  104. }
  105. checkKey := successKey + "_check"
  106. if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
  107. c.Status(http.StatusTooManyRequests)
  108. c.Abort()
  109. return
  110. }
  111. c.Next()
  112. if c.Writer.Status() < 400 {
  113. inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
  114. }
  115. }
  116. }
  117. func ModelRequestRateLimit() func(c *gin.Context) {
  118. return func(c *gin.Context) {
  119. if !setting.ModelRequestRateLimitEnabled {
  120. c.Next()
  121. return
  122. }
  123. duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
  124. group := c.GetString("token_group")
  125. if group == "" {
  126. group = c.GetString("group")
  127. }
  128. if group == "" {
  129. group = "default"
  130. }
  131. finalTotalCount := setting.ModelRequestRateLimitCount
  132. finalSuccessCount := setting.ModelRequestRateLimitSuccessCount
  133. foundGroupLimit := false
  134. groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
  135. if found {
  136. finalTotalCount = groupTotalCount
  137. finalSuccessCount = groupSuccessCount
  138. foundGroupLimit = true
  139. common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
  140. }
  141. if !foundGroupLimit {
  142. common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
  143. }
  144. if common.RedisEnabled {
  145. redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
  146. } else {
  147. memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
  148. }
  149. }
  150. }