Просмотр исходного кода

feat: 增加分组速率功能

tbphp 10 месяцев назад
Родитель
Сommit
6c3fb7777e

+ 29 - 6
middleware/model-rate-limit.go

@@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) {
 			return
 			return
 		}
 		}
 
 
-		// 计算限流参数
+		// 计算通用限流参数
 		duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
 		duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
-		totalMaxCount := setting.ModelRequestRateLimitCount
-		successMaxCount := setting.ModelRequestRateLimitSuccessCount
 
 
-		// 根据存储类型选择并执行限流处理器
+		// 获取用户组
+		group := c.GetString("token_group")
+		if group == "" {
+			group = c.GetString("group")
+		}
+		if group == "" {
+			group = "default" // 默认组
+		}
+
+		// 尝试获取用户组特定的限制
+		groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
+
+		// 确定最终的限制值
+		finalTotalCount := setting.ModelRequestRateLimitCount       // 默认使用全局总次数限制
+		finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制
+
+		if found {
+			// 如果找到用户组特定限制,则使用它们
+			finalTotalCount = groupTotalCount
+			finalSuccessCount = groupSuccessCount
+			common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
+		} else {
+			common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
+		}
+
+		// 根据存储类型选择并执行限流处理器,传入最终确定的限制值
 		if common.RedisEnabled {
 		if common.RedisEnabled {
-			redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
+			redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
 		} else {
 		} else {
-			memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
+			memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
 		}
 		}
 	}
 	}
 }
 }

+ 44 - 3
model/option.go

@@ -1,6 +1,8 @@
 package model
 package model
 
 
 import (
 import (
+	"encoding/json"
+	"fmt"
 	"one-api/common"
 	"one-api/common"
 	"one-api/setting"
 	"one-api/setting"
 	"one-api/setting/config"
 	"one-api/setting/config"
@@ -96,6 +98,7 @@ func InitOptionMap() {
 	common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
 	common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
 	common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
 	common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
 	common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
 	common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
+	common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值
 	common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
 	common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
 	common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
 	common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
 	common.OptionMap["TopUpLink"] = common.TopUpLink
 	common.OptionMap["TopUpLink"] = common.TopUpLink
@@ -150,7 +153,32 @@ func SyncOptions(frequency int) {
 }
 }
 
 
 func UpdateOption(key string, value string) error {
 func UpdateOption(key string, value string) error {
-	// Save to database first
+	originalValue := value // 保存原始值以备后用
+
+	// Validate and format specific keys before saving
+	if key == setting.ModelRequestRateLimitGroupKey {
+		var cfg map[string][2]int
+		// Validate the JSON structure first using the original value
+		err := json.Unmarshal([]byte(originalValue), &cfg)
+		if err != nil {
+			// 提供更具体的错误信息
+			return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
+		}
+		// TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。
+		// if !isValidModelRequestRateLimitGroupConfig(cfg) {
+		//     return fmt.Errorf("无效的配置值 for %s", key)
+		// }
+
+		// If valid, format the JSON before saving
+		formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", "  ")
+		if marshalErr != nil {
+			// This should ideally not happen if validation passed, but handle defensively
+			return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
+		}
+		value = string(formattedValueBytes) // Use formatted JSON for saving and memory update
+	}
+
+	// Save to database
 	option := Option{
 	option := Option{
 		Key: key,
 		Key: key,
 	}
 	}
@@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error {
 	// Save is a combination function.
 	// Save is a combination function.
 	// If save value does not contain primary key, it will execute Create,
 	// If save value does not contain primary key, it will execute Create,
 	// otherwise it will execute Update (with all fields).
 	// otherwise it will execute Update (with all fields).
-	DB.Save(&option)
-	// Update OptionMap
+	if err := DB.Save(&option).Error; err != nil {
+		return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文
+	}
+
+	// Update OptionMap in memory using the potentially formatted value
+	// updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新
 	return updateOptionMap(key, value)
 	return updateOptionMap(key, value)
 }
 }
 
 
@@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) {
 		operation_setting.AutomaticDisableKeywordsFromString(value)
 		operation_setting.AutomaticDisableKeywordsFromString(value)
 	case "StreamCacheQueueLength":
 	case "StreamCacheQueueLength":
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
+	case setting.ModelRequestRateLimitGroupKey:
+		// Use the (potentially formatted) value passed from UpdateOption
+		// to update the actual configuration in memory.
+		// This is the single point where the memory state for this specific setting is updated.
+		err = setting.UpdateModelRequestRateLimitGroupConfig(value)
+		if err != nil {
+			// 添加错误上下文
+			err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err)
+		}
 	}
 	}
 	return err
 	return err
 }
 }

+ 70 - 0
setting/rate_limit.go

@@ -1,6 +1,76 @@
 package setting
 package setting
 
 
+import (
+	"encoding/json"
+	"fmt"
+	"one-api/common"
+	"sync"
+)
+
 var ModelRequestRateLimitEnabled = false
 var ModelRequestRateLimitEnabled = false
 var ModelRequestRateLimitDurationMinutes = 1
 var ModelRequestRateLimitDurationMinutes = 1
 var ModelRequestRateLimitCount = 0
 var ModelRequestRateLimitCount = 0
 var ModelRequestRateLimitSuccessCount = 1000
 var ModelRequestRateLimitSuccessCount = 1000
+
+// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键
+const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup"
+
+// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置
+// map[groupName][2]int{totalCount, successCount}
+var ModelRequestRateLimitGroupConfig map[string][2]int
+var ModelRequestRateLimitGroupMutex sync.RWMutex
+
+// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置
+func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
+	ModelRequestRateLimitGroupMutex.Lock()
+	defer ModelRequestRateLimitGroupMutex.Unlock()
+
+	var newConfig map[string][2]int
+	if jsonStr == "" || jsonStr == "{}" {
+		// 如果配置为空或空JSON对象,则清空内存配置
+		ModelRequestRateLimitGroupConfig = make(map[string][2]int)
+		common.SysLog("Model request rate limit group config cleared")
+		return nil
+	}
+
+	err := json.Unmarshal([]byte(jsonStr), &newConfig)
+	if err != nil {
+		return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
+	}
+
+	// 校验配置值
+	for group, limits := range newConfig {
+		if len(limits) != 2 {
+			return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group)
+		}
+		if limits[1] <= 0 { // successCount must be greater than 0
+			return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group)
+		}
+		if limits[0] < 0 { // totalCount can be 0 (no limit) or positive
+			return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group)
+		}
+		if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount
+			return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group)
+		}
+	}
+
+	ModelRequestRateLimitGroupConfig = newConfig
+	common.SysLog("Model request rate limit group config updated")
+	return nil
+}
+
+// GetGroupRateLimit 安全地获取指定用户组的速率限制值
+func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
+	ModelRequestRateLimitGroupMutex.RLock()
+	defer ModelRequestRateLimitGroupMutex.RUnlock()
+
+	if ModelRequestRateLimitGroupConfig == nil {
+		return 0, 0, false // 配置尚未初始化
+	}
+
+	limits, found := ModelRequestRateLimitGroupConfig[group]
+	if !found {
+		return 0, 0, false
+	}
+	return limits[0], limits[1], true
+}

+ 1 - 0
web/src/components/RateLimitSetting.js

@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
     ModelRequestRateLimitCount: 0,
     ModelRequestRateLimitCount: 0,
     ModelRequestRateLimitSuccessCount: 1000,
     ModelRequestRateLimitSuccessCount: 1000,
     ModelRequestRateLimitDurationMinutes: 1,
     ModelRequestRateLimitDurationMinutes: 1,
+    ModelRequestRateLimitGroup: {},
   });
   });
 
 
   let [loading, setLoading] = useState(false);
   let [loading, setLoading] = useState(false);

+ 69 - 7
web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js

@@ -18,6 +18,7 @@ export default function RequestRateLimit(props) {
     ModelRequestRateLimitCount: -1,
     ModelRequestRateLimitCount: -1,
     ModelRequestRateLimitSuccessCount: 1000,
     ModelRequestRateLimitSuccessCount: 1000,
     ModelRequestRateLimitDurationMinutes: 1,
     ModelRequestRateLimitDurationMinutes: 1,
+    ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值
   });
   });
   const refForm = useRef();
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -32,25 +33,49 @@ export default function RequestRateLimit(props) {
       } else {
       } else {
         value = inputs[item.key];
         value = inputs[item.key];
       }
       }
+      // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串
+      if (item.key === 'ModelRequestRateLimitGroup') {
+        try {
+          JSON.parse(value);
+        } catch (e) {
+          showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
+          // 阻止请求发送
+          return Promise.reject('Invalid JSON format');
+        }
+      }
       return API.put('/api/option/', {
       return API.put('/api/option/', {
         key: item.key,
         key: item.key,
         value,
         value,
       });
       });
     });
     });
+
+    // 过滤掉无效的请求(例如,无效的 JSON)
+    const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
+
+    if (validRequests.length === 0 && requestQueue.length > 0) {
+      // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行
+      return;
+    }
+
     setLoading(true);
     setLoading(true);
-    Promise.all(requestQueue)
+    Promise.all(validRequests)
       .then((res) => {
       .then((res) => {
-        if (requestQueue.length === 1) {
+        if (validRequests.length === 1) {
           if (res.includes(undefined)) return;
           if (res.includes(undefined)) return;
-        } else if (requestQueue.length > 1) {
+        } else if (validRequests.length > 1) {
           if (res.includes(undefined))
           if (res.includes(undefined))
             return showError(t('部分保存失败,请重试'));
             return showError(t('部分保存失败,请重试'));
         }
         }
         showSuccess(t('保存成功'));
         showSuccess(t('保存成功'));
         props.refresh();
         props.refresh();
+        // 更新 inputsRow 以反映保存后的状态
+        setInputsRow(structuredClone(inputs));
       })
       })
-      .catch(() => {
-        showError(t('保存失败,请重试'));
+      .catch((error) => {
+        // 检查是否是由于无效 JSON 导致的错误
+        if (error !== 'Invalid JSON format') {
+          showError(t('保存失败,请重试'));
+        }
       })
       })
       .finally(() => {
       .finally(() => {
         setLoading(false);
         setLoading(false);
@@ -66,8 +91,11 @@ export default function RequestRateLimit(props) {
     }
     }
     setInputs(currentInputs);
     setInputs(currentInputs);
     setInputsRow(structuredClone(currentInputs));
     setInputsRow(structuredClone(currentInputs));
-    refForm.current.setValues(currentInputs);
-  }, [props.options]);
+    // 检查 refForm.current 是否存在
+    if (refForm.current) {
+      refForm.current.setValues(currentInputs);
+    }
+  }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定
 
 
   return (
   return (
     <>
     <>
@@ -147,7 +175,41 @@ export default function RequestRateLimit(props) {
                 />
                 />
               </Col>
               </Col>
             </Row>
             </Row>
+            {/* 用户组速率限制配置项 */}
             <Row>
             <Row>
+              <Col span={24}>
+                <Form.TextArea
+                  label={t('用户组速率限制 (JSON)')}
+                  field={'ModelRequestRateLimitGroup'}
+                  placeholder={t( // 更新 placeholder
+                    '请输入 JSON 格式的用户组限制,例如:\n{\n  "default": [200, 100],\n  "vip": [1000, 500]\n}',
+                  )}
+                  extraText={ // 更新 extraText
+                    <div>
+                      <p>{t('说明:')}</p>
+                      <ul>
+                        <li>{t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}</li>
+                        <li>{t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}</li>
+                        <li>{t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}</li>
+                        <li>{t('此配置将优先于上方的全局限制设置。')}</li>
+                        <li>{t('未在此处配置的用户组将使用全局限制。')}</li>
+                        <li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
+                        <li>{t('输入无效的 JSON 将无法保存。')}</li>
+                      </ul>
+                    </div>
+                  }
+                  autosize={{ minRows: 5, maxRows: 15 }}
+                  style={{ fontFamily: 'monospace' }}
+                  onChange={(value) => {
+                    setInputs({
+                      ...inputs,
+                      ModelRequestRateLimitGroup: value, // 直接更新字符串值
+                    });
+                  }}
+                />
+              </Col>
+            </Row>
+            <Row style={{ marginTop: 15 }}>
               <Button size='default' onClick={onSubmit}>
               <Button size='default' onClick={onSubmit}>
                 {t('保存模型速率限制')}
                 {t('保存模型速率限制')}
               </Button>
               </Button>