// limiter.go (1.6 KB)
  1. package limiter
import (
	"context"
	_ "embed"
	"fmt"
	"strings"
	"sync"

	"github.com/go-redis/redis/v8"
)
// rateLimitScript holds the Lua source embedded at build time from
// lua/rate_limit.lua. New uploads it to Redis with SCRIPT LOAD.
//
//go:embed lua/rate_limit.lua
var rateLimitScript string
// RedisLimiter runs the embedded rate-limit Lua script against a Redis server.
type RedisLimiter struct {
	client         *redis.Client // Redis client used for every script invocation
	limitScriptSHA string        // value returned by SCRIPT LOAD in New; Allow passes it to EVALSHA
}
// Package-level singleton state used by New.
var (
	instance *RedisLimiter // limiter built by the first New call; nil until then
	once     sync.Once     // guards the one-time initialization performed in New
)
  19. func New(ctx context.Context, r *redis.Client) *RedisLimiter {
  20. once.Do(func() {
  21. client := r
  22. _, err := client.Ping(ctx).Result()
  23. if err != nil {
  24. panic(err) // 或者处理连接错误
  25. }
  26. // 预加载脚本
  27. limitSHA, err := client.ScriptLoad(ctx, rateLimitScript).Result()
  28. if err != nil {
  29. fmt.Println(err)
  30. }
  31. instance = &RedisLimiter{
  32. client: client,
  33. limitScriptSHA: limitSHA,
  34. }
  35. })
  36. return instance
  37. }
  38. func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
  39. // 默认配置
  40. config := &Config{
  41. Capacity: 10,
  42. Rate: 1,
  43. Requested: 1,
  44. }
  45. // 应用选项模式
  46. for _, opt := range opts {
  47. opt(config)
  48. }
  49. // 执行限流
  50. result, err := rl.client.EvalSha(
  51. ctx,
  52. rl.limitScriptSHA,
  53. []string{key},
  54. config.Requested,
  55. config.Rate,
  56. config.Capacity,
  57. ).Int()
  58. if err != nil {
  59. return false, fmt.Errorf("rate limit failed: %w", err)
  60. }
  61. return result == 1, nil
  62. }
// Config holds the parameters of a single Allow call. It is filled in using
// the functional-options pattern: Allow starts from defaults and lets each
// Option modify the value.
type Config struct {
	Capacity  int64 // sent to the script as its third argument; presumably the bucket capacity — confirm against lua/rate_limit.lua
	Rate      int64 // sent to the script as its second argument; presumably the refill rate — confirm against lua/rate_limit.lua
	Requested int64 // sent to the script as its first argument; presumably the number of permits requested — confirm against lua/rate_limit.lua
}
// Option modifies a Config in place. Allow applies the Options it receives to
// its default Config before invoking the script.
type Option func(*Config)
  70. func WithCapacity(c int64) Option {
  71. return func(cfg *Config) { cfg.Capacity = c }
  72. }
  73. func WithRate(r int64) Option {
  74. return func(cfg *Config) { cfg.Rate = r }
  75. }
  76. func WithRequested(n int64) Option {
  77. return func(cfg *Config) { cfg.Requested = n }
  78. }