Преглед изворни кода

Refactor: Optimize the token bucket algorithm, specifically the New method in common/imiterlimiter.go.
Solution: Remove Redis ping. When printing exceptions, use SysLog to print and add additional logging information.

霍雨佳 пре 10 месеци
родитељ
комит
e385e347ea
2 измењених фајлова са 6 додато и 11 уклоњено
  1. 4 9
      common/limiter/limiter.go
  2. 2 2
      middleware/model-rate-limit.go

+ 4 - 9
common/limiter/limiter.go

@@ -5,6 +5,7 @@ import (
 	_ "embed"
 	_ "embed"
 	"fmt"
 	"fmt"
 	"github.com/go-redis/redis/v8"
 	"github.com/go-redis/redis/v8"
+	"one-api/common"
 	"sync"
 	"sync"
 )
 )
 
 
@@ -23,19 +24,13 @@ var (
 
 
 func New(ctx context.Context, r *redis.Client) *RedisLimiter {
 func New(ctx context.Context, r *redis.Client) *RedisLimiter {
 	once.Do(func() {
 	once.Do(func() {
-		client := r
-		_, err := client.Ping(ctx).Result()
-		if err != nil {
-			panic(err) // 或者处理连接错误
-		}
 		// 预加载脚本
 		// 预加载脚本
-		limitSHA, err := client.ScriptLoad(ctx, rateLimitScript).Result()
+		limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
 		if err != nil {
 		if err != nil {
-			fmt.Println(err)
+			common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
 		}
 		}
-
 		instance = &RedisLimiter{
 		instance = &RedisLimiter{
-			client:         client,
+			client:         r,
 			limitScriptSHA: limitSHA,
 			limitScriptSHA: limitSHA,
 		}
 		}
 	})
 	})

+ 2 - 2
middleware/model-rate-limit.go

@@ -91,9 +91,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
 			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
 			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
 			return
 			return
 		}
 		}
-		//检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
+
+		//2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
 		totalKey := fmt.Sprintf("rateLimit:%s", userId)
 		totalKey := fmt.Sprintf("rateLimit:%s", userId)
-		//allowed, err = checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
 		// 初始化
 		// 初始化
 		tb := limiter.New(ctx, rdb)
 		tb := limiter.New(ctx, rdb)
 		allowed, err = tb.Allow(
 		allowed, err = tb.Allow(