rate_limit.go 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. package setting
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "one-api/common"
  6. "sync"
  7. )
  // Global (not per-group) model request rate limit settings.
  var (
  	// ModelRequestRateLimitEnabled is the on/off switch for model request
  	// rate limiting (disabled by default).
  	ModelRequestRateLimitEnabled = false

  	// ModelRequestRateLimitDurationMinutes is the rate limit duration in
  	// minutes. NOTE(review): presumably the length of the counting window —
  	// confirm against the enforcing code.
  	ModelRequestRateLimitDurationMinutes = 1

  	// ModelRequestRateLimitCount is the global total request count limit.
  	// NOTE(review): presumably 0 means "no total limit", mirroring the
  	// per-group totalCount rule — confirm against the enforcing code.
  	ModelRequestRateLimitCount = 0

  	// ModelRequestRateLimitSuccessCount is the global successful request
  	// count limit.
  	ModelRequestRateLimitSuccessCount = 1000
  )
  const (
  	// ModelRequestRateLimitGroupKey is the configuration key for the
  	// per-user-group model request rate limit settings.
  	ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup"
  )
  var (
  	// ModelRequestRateLimitGroupConfig holds the parsed per-user-group rate
  	// limits, keyed by group name. Each value is {totalCount, successCount},
  	// i.e. map[groupName][2]int{totalCount, successCount}. It starts out nil
  	// until a configuration is loaded; read and write it only while holding
  	// ModelRequestRateLimitGroupMutex.
  	ModelRequestRateLimitGroupConfig map[string][2]int

  	// ModelRequestRateLimitGroupMutex guards ModelRequestRateLimitGroupConfig.
  	ModelRequestRateLimitGroupMutex sync.RWMutex
  )
  18. // UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置
  19. func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
  20. ModelRequestRateLimitGroupMutex.Lock()
  21. defer ModelRequestRateLimitGroupMutex.Unlock()
  22. var newConfig map[string][2]int
  23. if jsonStr == "" || jsonStr == "{}" {
  24. // 如果配置为空或空JSON对象,则清空内存配置
  25. ModelRequestRateLimitGroupConfig = make(map[string][2]int)
  26. common.SysLog("Model request rate limit group config cleared")
  27. return nil
  28. }
  29. err := json.Unmarshal([]byte(jsonStr), &newConfig)
  30. if err != nil {
  31. return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
  32. }
  33. // 校验配置值
  34. for group, limits := range newConfig {
  35. if len(limits) != 2 {
  36. return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group)
  37. }
  38. if limits[1] <= 0 { // successCount must be greater than 0
  39. return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group)
  40. }
  41. if limits[0] < 0 { // totalCount can be 0 (no limit) or positive
  42. return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group)
  43. }
  44. if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount
  45. return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group)
  46. }
  47. }
  48. ModelRequestRateLimitGroupConfig = newConfig
  49. common.SysLog("Model request rate limit group config updated")
  50. return nil
  51. }
  52. // GetGroupRateLimit 安全地获取指定用户组的速率限制值
  53. func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
  54. ModelRequestRateLimitGroupMutex.RLock()
  55. defer ModelRequestRateLimitGroupMutex.RUnlock()
  56. if ModelRequestRateLimitGroupConfig == nil {
  57. return 0, 0, false // 配置尚未初始化
  58. }
  59. limits, found := ModelRequestRateLimitGroupConfig[group]
  60. if !found {
  61. return 0, 0, false
  62. }
  63. return limits[0], limits[1], true
  64. }