فهرست منبع

Refactor: Optimize the request rate limiting for ModelRequestRateLimitCount.
Reason: The original steps 1 and 3 in the redisRateLimitHandler method were not atomic, leading to poor precision under high concurrent requests. For example, with a rate limit set to 60, sending 200 concurrent requests would result in none being blocked, whereas theoretically around 140 should be intercepted.
Solution: Rather than merging steps 1 and 3 into a single Lua script — one atomic operation combining read, write, and delete could become a bottleneck under high concurrency — I implemented a token bucket algorithm. This reduces the atomic section to just a read and a write, and significantly decreases the memory footprint.

霍雨佳 10 ماه پیش
والد
کامیت
eb75ff232f
3 فایل‌های تغییر یافته به همراه 160 افزوده شده و 14 حذف شده
  1. 94 0
      common/limiter/limiter.go
  2. 44 0
      common/limiter/lua/rate_limit.lua
  3. 22 14
      middleware/model-rate-limit.go

+ 94 - 0
common/limiter/limiter.go

@@ -0,0 +1,94 @@
+package limiter
+
+import (
+	"context"
+	_ "embed"
+	"fmt"
+	"github.com/go-redis/redis/v8"
+	"sync"
+)
+
+// rateLimitScript holds the token-bucket Lua script, embedded at compile
+// time from lua/rate_limit.lua (the //go:embed directive must stay
+// immediately above the var declaration).
+//go:embed lua/rate_limit.lua
+var rateLimitScript string
+
+// RedisLimiter pairs a Redis client with the SHA1 digest of the preloaded
+// rate-limit script so Allow can invoke it via EVALSHA.
+type RedisLimiter struct {
+	client         *redis.Client
+	limitScriptSHA string
+}
+
+// Process-wide singleton state; New initializes instance exactly once.
+var (
+	instance *RedisLimiter
+	once     sync.Once
+)
+
+// New returns the process-wide RedisLimiter singleton. On the first call it
+// pings Redis to verify connectivity and preloads the token-bucket script so
+// subsequent Allow calls can run it via EVALSHA.
+// NOTE(review): a ScriptLoad failure is only printed, leaving limitScriptSHA
+// empty so every later EvalSha call will fail with an error — confirm whether
+// this should be fatal, consistent with how a Ping failure panics.
+func New(ctx context.Context, r *redis.Client) *RedisLimiter {
+	once.Do(func() {
+		client := r
+		_, err := client.Ping(ctx).Result()
+		if err != nil {
+			panic(err) // or handle the connection error
+		}
+		// Preload the script and cache its SHA for EVALSHA.
+		limitSHA, err := client.ScriptLoad(ctx, rateLimitScript).Result()
+		if err != nil {
+			fmt.Println(err)
+		}
+
+		instance = &RedisLimiter{
+			client:         client,
+			limitScriptSHA: limitSHA,
+		}
+	})
+
+	return instance
+}
+
+// Allow reports whether the request identified by key may proceed under the
+// token-bucket limiter. Defaults (capacity 10, rate 1, 1 token per call) can
+// be overridden with WithCapacity/WithRate/WithRequested. It returns false
+// plus a wrapped error when the script invocation itself fails.
+// NOTE(review): EvalSha returns NOSCRIPT if Redis restarts (or flushes its
+// script cache) after New preloaded the script; there is no Eval fallback
+// here — confirm whether go-redis's Script helper or a retry is needed.
+func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
+	// Default configuration.
+	config := &Config{
+		Capacity:  10,
+		Rate:      1,
+		Requested: 1,
+	}
+
+	// Apply functional options over the defaults.
+	for _, opt := range opts {
+		opt(config)
+	}
+
+	// Run the token-bucket script atomically on the Redis server;
+	// it returns 1 when allowed, 0 when throttled.
+	result, err := rl.client.EvalSha(
+		ctx,
+		rl.limitScriptSHA,
+		[]string{key},
+		config.Requested,
+		config.Rate,
+		config.Capacity,
+	).Int()
+
+	if err != nil {
+		return false, fmt.Errorf("rate limit failed: %w", err)
+	}
+	return result == 1, nil
+}
+
+// Config carries the token-bucket parameters consumed by Allow
+// (functional-options pattern).
+type Config struct {
+	Capacity  int64 // maximum number of tokens the bucket can hold
+	Rate      int64 // tokens refilled per second
+	Requested int64 // tokens consumed by one Allow call
+}
+
+// Option mutates a Config; pass options to Allow to override the defaults.
+type Option func(*Config)
+
+// WithCapacity sets the bucket capacity (maximum burst size).
+func WithCapacity(c int64) Option {
+	return func(cfg *Config) { cfg.Capacity = c }
+}
+
+// WithRate sets the token refill rate (tokens per second).
+func WithRate(r int64) Option {
+	return func(cfg *Config) { cfg.Rate = r }
+}
+
+// WithRequested sets how many tokens a single request consumes.
+func WithRequested(n int64) Option {
+	return func(cfg *Config) { cfg.Requested = n }
+}

+ 44 - 0
common/limiter/lua/rate_limit.lua

@@ -0,0 +1,44 @@
+-- Token-bucket rate limiter.
+-- KEYS[1]: unique limiter key
+-- ARGV[1]: tokens requested (usually 1)
+-- ARGV[2]: token refill rate (per second)
+-- ARGV[3]: bucket capacity
+-- Returns 1 when the request is allowed, 0 when throttled.
+
+local key = KEYS[1]
+local requested = tonumber(ARGV[1])
+local rate = tonumber(ARGV[2])
+local capacity = tonumber(ARGV[3])
+
+-- Use the Redis server clock so all application instances share one time
+-- source. NOTE(review): only the seconds component of TIME is used, so
+-- refills advance in whole-second steps — confirm this granularity is
+-- acceptable for the configured rates.
+local now = redis.call('TIME')
+local nowInSeconds = tonumber(now[1])
+
+-- Read the current bucket state.
+local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
+local tokens = tonumber(bucket[1])
+local last_time = tonumber(bucket[2])
+
+-- Initialize the bucket (first request, or the key expired): start full.
+if not tokens or not last_time then
+    tokens = capacity
+    last_time = nowInSeconds
+else
+    -- Refill tokens accrued since the last request, capped at capacity.
+    local elapsed = nowInSeconds - last_time
+    local add_tokens = elapsed * rate
+    tokens = math.min(capacity, tokens + add_tokens)
+    last_time = nowInSeconds
+end
+
+-- Consume tokens only when enough are available.
+local allowed = false
+if tokens >= requested then
+    tokens = tokens - requested
+    allowed = true
+end
+
+-- Persist the updated bucket state.
+-- NOTE(review): the EXPIRE below is commented out, so bucket keys are never
+-- evicted and one hash per user accumulates indefinitely — confirm whether
+-- the TTL should be re-enabled. HMSET is also deprecated in favor of HSET.
+redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
+--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- TTL comfortably past one full refill window
+
+return allowed and 1 or 0

+ 22 - 14
middleware/model-rate-limit.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
+	"one-api/common/limiter"
 	"one-api/setting"
 	"strconv"
 	"time"
@@ -78,34 +79,41 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
 		ctx := context.Background()
 		rdb := common.RDB
 
-		// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
-		totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
-		allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
+		// 1. 检查成功请求数限制
+		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
+		allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
 		if err != nil {
-			fmt.Println("检查请求数限制失败:", err.Error())
+			fmt.Println("检查成功请求数限制失败:", err.Error())
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
 			return
 		}
 		if !allowed {
-			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
+			return
 		}
+		//检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
+		totalKey := fmt.Sprintf("rateLimit:%s", userId)
+		//allowed, err = checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
+		// 初始化
+		tb := limiter.New(ctx, rdb)
+		allowed, err = tb.Allow(
+			ctx,
+			totalKey,
+			limiter.WithCapacity(int64(totalMaxCount)*duration),
+			limiter.WithRate(int64(totalMaxCount)),
+			limiter.WithRequested(duration),
+		)
 
-		// 2. 检查成功请求数限制
-		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
-		allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
 		if err != nil {
-			fmt.Println("检查成功请求数限制失败:", err.Error())
+			fmt.Println("检查总请求数限制失败:", err.Error())
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
 			return
 		}
+
 		if !allowed {
-			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
-			return
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
 		}
 
-		// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
-		recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
-
 		// 4. 处理请求
 		c.Next()