Просмотр исходного кода

refactor: 调整代码,符合项目现有规范

tbphp 10 месяцев назад
Родитель
Commit
1513ed7847
4 измененных файлов с 65 добавлено и 66 удалено
  1. 35 19
      middleware/model-rate-limit.go
  2. 9 25
      model/option.go
  3. 17 20
      setting/rate_limit.go
  4. 4 2
      web/src/components/RateLimitSetting.js

+ 35 - 19
middleware/model-rate-limit.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/common/limiter"
+	"one-api/constant"
 	"one-api/setting"
 	"strconv"
 	"time"
@@ -19,20 +20,25 @@ const (
 	ModelRequestRateLimitSuccessCountMark = "MRRLS"
 )
 
+// 检查Redis中的请求限制
 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
+	// 如果maxCount为0,表示不限制
 	if maxCount == 0 {
 		return true, nil
 	}
 
+	// 获取当前计数
 	length, err := rdb.LLen(ctx, key).Result()
 	if err != nil {
 		return false, err
 	}
 
+	// 如果未达到限制,允许请求
 	if length < int64(maxCount) {
 		return true, nil
 	}
 
+	// 检查时间窗口
 	oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
 	oldTime, err := time.Parse(timeFormat, oldTimeStr)
 	if err != nil {
@@ -44,6 +50,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
 	if err != nil {
 		return false, err
 	}
+	// 如果在时间窗口内已达到限制,拒绝请求
 	subTime := nowTime.Sub(oldTime).Seconds()
 	if int64(subTime) < duration {
 		rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
@@ -53,7 +60,9 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
 	return true, nil
 }
 
+// 记录Redis请求
 func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
+	// 如果maxCount为0,不记录请求
 	if maxCount == 0 {
 		return
 	}
@@ -64,12 +73,14 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC
 	rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
 }
 
+// Redis限流处理器
 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		userId := strconv.Itoa(c.GetInt("id"))
 		ctx := context.Background()
 		rdb := common.RDB
 
+		// 1. 检查成功请求数限制
 		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
 		allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
 		if err != nil {
@@ -82,7 +93,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
 			return
 		}
 
+		//2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过),使用令牌桶限流器
 		totalKey := fmt.Sprintf("rateLimit:%s", userId)
+		// 初始化
 		tb := limiter.New(ctx, rdb)
 		allowed, err = tb.Allow(
 			ctx,
@@ -102,14 +115,17 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
 			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
 		}
 
+		// 4. 处理请求
 		c.Next()
 
+		// 5. 如果请求成功,记录成功请求
 		if c.Writer.Status() < 400 {
 			recordRedisRequest(ctx, rdb, successKey, successMaxCount)
 		}
 	}
 }
 
+// 内存限流处理器
 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
 	inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
 
@@ -118,12 +134,15 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
 		totalKey := ModelRequestRateLimitCountMark + userId
 		successKey := ModelRequestRateLimitSuccessCountMark + userId
 
+		// 1. 检查总请求数限制(当totalMaxCount为0时跳过)
 		if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
 			c.Status(http.StatusTooManyRequests)
 			c.Abort()
 			return
 		}
 
+		// 2. 检查成功请求数限制
+		// 使用一个临时key来检查限制,这样可以避免实际记录
 		checkKey := successKey + "_check"
 		if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
 			c.Status(http.StatusTooManyRequests)
@@ -131,51 +150,48 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
 			return
 		}
 
+		// 3. 处理请求
 		c.Next()
 
+		// 4. 如果请求成功,记录到实际的成功请求计数中
 		if c.Writer.Status() < 400 {
 			inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
 		}
 	}
 }
 
+// ModelRequestRateLimit 模型请求限流中间件
 func ModelRequestRateLimit() func(c *gin.Context) {
 	return func(c *gin.Context) {
+		// 在每个请求时检查是否启用限流
 		if !setting.ModelRequestRateLimitEnabled {
 			c.Next()
 			return
 		}
 
+		// 计算限流参数
 		duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
+		totalMaxCount := setting.ModelRequestRateLimitCount
+		successMaxCount := setting.ModelRequestRateLimitSuccessCount
 
+		// 获取分组
 		group := c.GetString("token_group")
 		if group == "" {
-			group = c.GetString("group")
+			group = c.GetString(constant.ContextKeyUserGroup)
 		}
-		if group == "" {
-			group = "default"
-		}
-
-		finalTotalCount := setting.ModelRequestRateLimitCount
-		finalSuccessCount := setting.ModelRequestRateLimitSuccessCount
-		foundGroupLimit := false
 
+		//获取分组的限流配置
 		groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
 		if found {
-			finalTotalCount = groupTotalCount
-			finalSuccessCount = groupSuccessCount
-			foundGroupLimit = true
-			common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
-		}
-
-		if !foundGroupLimit {
-			common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
+			totalMaxCount = groupTotalCount
+			successMaxCount = groupSuccessCount
 		}
 
+		// 根据存储类型选择并执行限流处理器
 		if common.RedisEnabled {
-			redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
+			redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
 		} else {
-			memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
+			memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
 		}
 	}
-}
+}

+ 9 - 25
model/option.go

@@ -1,8 +1,6 @@
 package model
 
 import (
-	"encoding/json"
-	"fmt"
 	"one-api/common"
 	"one-api/setting"
 	"one-api/setting/config"
@@ -94,8 +92,7 @@ func InitOptionMap() {
 	common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
 	common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
 	common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
-	jsonBytes, _ := json.Marshal(map[string][2]int{})
-	common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes)
+	common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
 	common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
 	common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
 	common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
@@ -154,31 +151,18 @@ func SyncOptions(frequency int) {
 }
 
 func UpdateOption(key string, value string) error {
-	originalValue := value
-
-	if key == "ModelRequestRateLimitGroup" {
-		var cfg map[string][2]int
-		err := json.Unmarshal([]byte(originalValue), &cfg)
-		if err != nil {
-			return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
-		}
-
-		formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", "  ")
-		if marshalErr != nil {
-			return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
-		}
-		value = string(formattedValueBytes)
-	}
-
+	// Save to database first
 	option := Option{
 		Key: key,
 	}
+	// https://gorm.io/docs/update.html#Save-All-Fields
 	DB.FirstOrCreate(&option, Option{Key: key})
 	option.Value = value
-	if err := DB.Save(&option).Error; err != nil {
-		return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err)
-	}
-
+	// Save is a combination function.
+	// If save value does not contain primary key, it will execute Create,
+	// otherwise it will execute Update (with all fields).
+	DB.Save(&option)
+	// Update OptionMap
 	return updateOptionMap(key, value)
 }
 
@@ -356,7 +340,7 @@ func updateOptionMap(key string, value string) (err error) {
 	case "ModelRequestRateLimitSuccessCount":
 		setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
 	case "ModelRequestRateLimitGroup":
-		err = setting.UpdateModelRequestRateLimitGroup(value)
+		err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
 	case "RetryTimes":
 		common.RetryTimes, _ = strconv.Atoi(value)
 	case "DataExportInterval":

+ 17 - 20
setting/rate_limit.go

@@ -2,7 +2,6 @@ package setting
 
 import (
 	"encoding/json"
-	"fmt"
 	"one-api/common"
 	"sync"
 )
@@ -11,33 +10,31 @@ var ModelRequestRateLimitEnabled = false
 var ModelRequestRateLimitDurationMinutes = 1
 var ModelRequestRateLimitCount = 0
 var ModelRequestRateLimitSuccessCount = 1000
-var ModelRequestRateLimitGroup map[string][2]int
+var ModelRequestRateLimitGroup = map[string][2]int{}
+var ModelRequestRateLimitMutex sync.RWMutex
 
-var ModelRequestRateLimitGroupMutex sync.RWMutex
+func ModelRequestRateLimitGroup2JSONString() string {
+	ModelRequestRateLimitMutex.RLock()
+	defer ModelRequestRateLimitMutex.RUnlock()
 
-func UpdateModelRequestRateLimitGroup(jsonStr string) error {
-	ModelRequestRateLimitGroupMutex.Lock()
-	defer ModelRequestRateLimitGroupMutex.Unlock()
-
-	var newConfig map[string][2]int
-	if jsonStr == "" || jsonStr == "{}" {
-		ModelRequestRateLimitGroup = make(map[string][2]int)
-		common.SysLog("Model request rate limit group config cleared")
-		return nil
-	}
-
-	err := json.Unmarshal([]byte(jsonStr), &newConfig)
+	jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
 	if err != nil {
-		return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
+		common.SysError("error marshalling model ratio: " + err.Error())
 	}
+	return string(jsonBytes)
+}
+
+func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error {
+	ModelRequestRateLimitMutex.RLock()
+	defer ModelRequestRateLimitMutex.RUnlock()
 
-	ModelRequestRateLimitGroup = newConfig
-	return nil
+	ModelRequestRateLimitGroup = make(map[string][2]int)
+	return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup)
 }
 
 func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
-	ModelRequestRateLimitGroupMutex.RLock()
-	defer ModelRequestRateLimitGroupMutex.RUnlock()
+	ModelRequestRateLimitMutex.RLock()
+	defer ModelRequestRateLimitMutex.RUnlock()
 
 	if ModelRequestRateLimitGroup == nil {
 		return 0, 0, false

+ 4 - 2
web/src/components/RateLimitSetting.js

@@ -24,7 +24,6 @@ const RateLimitSetting = () => {
   	if (success) {
   		let newInputs = {};
   		data.forEach((item) => {
-  			// 检查 key 是否在初始 inputs 中定义
   			if (Object.prototype.hasOwnProperty.call(inputs, item.key)) {
   				if (item.key.endsWith('Enabled')) {
   					newInputs[item.key] = item.value === 'true';
@@ -33,6 +32,7 @@ const RateLimitSetting = () => {
   				}
   			}
   		});
+		
   		setInputs(newInputs);
   	} else {
   		showError(message);
@@ -42,6 +42,7 @@ const RateLimitSetting = () => {
   	try {
   		setLoading(true);
   		await getOptions();
+		// showSuccess('刷新成功');
   	} catch (error) {
   		showError('刷新失败');
   	} finally {
@@ -56,6 +57,7 @@ const RateLimitSetting = () => {
   return (
   	<>
   		<Spin spinning={loading} size='large'>
+			 {/* AI请求速率限制 */}
   			<Card style={{ marginTop: '10px' }}>
   				<RequestRateLimit options={inputs} refresh={onRefresh} />
   			</Card>
@@ -64,4 +66,4 @@ const RateLimitSetting = () => {
   );
  };
  
- export default RateLimitSetting;
+ export default RateLimitSetting;