model-rate-limit.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/setting"
  8. "strconv"
  9. "time"
  10. "github.com/gin-gonic/gin"
  11. "github.com/go-redis/redis/v8"
  12. )
  // Markers embedded in rate limit keys so that the two per-user counters kept
  // by the model request rate limiter never share a key.
  const (
  	// ModelRequestRateLimitCountMark tags the counter of all requests.
  	ModelRequestRateLimitCountMark = "MRRL"

  	// ModelRequestRateLimitSuccessCountMark tags the counter of requests
  	// that finished with a status code below 400.
  	ModelRequestRateLimitSuccessCountMark = "MRRLS"
  )
  17. // 检查Redis中的请求限制
  18. func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
  19. // 如果maxCount为0,表示不限制
  20. if maxCount == 0 {
  21. return true, nil
  22. }
  23. // 获取当前计数
  24. length, err := rdb.LLen(ctx, key).Result()
  25. if err != nil {
  26. return false, err
  27. }
  28. // 如果未达到限制,允许请求
  29. if length < int64(maxCount) {
  30. return true, nil
  31. }
  32. // 检查时间窗口
  33. oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
  34. oldTime, err := time.Parse(timeFormat, oldTimeStr)
  35. if err != nil {
  36. return false, err
  37. }
  38. nowTimeStr := time.Now().Format(timeFormat)
  39. nowTime, err := time.Parse(timeFormat, nowTimeStr)
  40. if err != nil {
  41. return false, err
  42. }
  43. // 如果在时间窗口内已达到限制,拒绝请求
  44. subTime := nowTime.Sub(oldTime).Seconds()
  45. if int64(subTime) < duration {
  46. rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
  47. return false, nil
  48. }
  49. return true, nil
  50. }
  51. // 记录Redis请求
  52. func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
  53. // 如果maxCount为0,不记录请求
  54. if maxCount == 0 {
  55. return
  56. }
  57. now := time.Now().Format(timeFormat)
  58. rdb.LPush(ctx, key, now)
  59. rdb.LTrim(ctx, key, 0, int64(maxCount-1))
  60. rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
  61. }
  62. // Redis限流处理器
  63. func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
  64. return func(c *gin.Context) {
  65. userId := strconv.Itoa(c.GetInt("id"))
  66. ctx := context.Background()
  67. rdb := common.RDB
  68. // 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
  69. totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
  70. allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
  71. if err != nil {
  72. fmt.Println("检查总请求数限制失败:", err.Error())
  73. abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
  74. return
  75. }
  76. if !allowed {
  77. abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
  78. }
  79. // 2. 检查成功请求数限制
  80. successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
  81. allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
  82. if err != nil {
  83. fmt.Println("检查成功请求数限制失败:", err.Error())
  84. abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
  85. return
  86. }
  87. if !allowed {
  88. abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
  89. return
  90. }
  91. // 3. 记录总请求(当totalMaxCount为0时会自动跳过)
  92. recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
  93. // 4. 处理请求
  94. c.Next()
  95. // 5. 如果请求成功,记录成功请求
  96. if c.Writer.Status() < 400 {
  97. recordRedisRequest(ctx, rdb, successKey, successMaxCount)
  98. }
  99. }
  100. }
  101. // 内存限流处理器
  102. func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
  103. inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
  104. return func(c *gin.Context) {
  105. userId := strconv.Itoa(c.GetInt("id"))
  106. totalKey := ModelRequestRateLimitCountMark + userId
  107. successKey := ModelRequestRateLimitSuccessCountMark + userId
  108. // 1. 检查总请求数限制(当totalMaxCount为0时跳过)
  109. if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
  110. c.Status(http.StatusTooManyRequests)
  111. c.Abort()
  112. return
  113. }
  114. // 2. 检查成功请求数限制
  115. // 使用一个临时key来检查限制,这样可以避免实际记录
  116. checkKey := successKey + "_check"
  117. if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
  118. c.Status(http.StatusTooManyRequests)
  119. c.Abort()
  120. return
  121. }
  122. // 3. 处理请求
  123. c.Next()
  124. // 4. 如果请求成功,记录到实际的成功请求计数中
  125. if c.Writer.Status() < 400 {
  126. inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
  127. }
  128. }
  129. }
  130. // ModelRequestRateLimit 模型请求限流中间件
  131. func ModelRequestRateLimit() func(c *gin.Context) {
  132. // 如果未启用限流,直接放行
  133. if !setting.ModelRequestRateLimitEnabled {
  134. return defNext
  135. }
  136. // 计算限流参数
  137. duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
  138. totalMaxCount := setting.ModelRequestRateLimitCount
  139. successMaxCount := setting.ModelRequestRateLimitSuccessCount
  140. // 根据存储类型选择限流处理器
  141. if common.RedisEnabled {
  142. return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
  143. } else {
  144. return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
  145. }
  146. }