|
|
@@ -0,0 +1,88 @@
|
|
|
+package com.tzld.piaoquan.api.aop;
|
|
|
+
|
|
|
+import com.google.common.base.Strings;
|
|
|
+import com.tzld.piaoquan.api.annotation.RateLimit;
|
|
|
+import com.tzld.piaoquan.api.common.enums.ExceptionEnum;
|
|
|
+import com.tzld.piaoquan.api.common.exception.CommonException;
|
|
|
+import com.tzld.piaoquan.growth.common.utils.IpUtil;
|
|
|
+import com.tzld.piaoquan.growth.common.utils.RedisUtils;
|
|
|
+import org.aspectj.lang.ProceedingJoinPoint;
|
|
|
+import org.aspectj.lang.annotation.Around;
|
|
|
+import org.aspectj.lang.annotation.Aspect;
|
|
|
+import org.aspectj.lang.annotation.Pointcut;
|
|
|
+import org.aspectj.lang.reflect.MethodSignature;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+import org.springframework.web.context.request.RequestContextHolder;
|
|
|
+import org.springframework.web.context.request.ServletRequestAttributes;
|
|
|
+
|
|
|
+import javax.servlet.http.HttpServletRequest;
|
|
|
+import java.lang.reflect.Method;
|
|
|
+import java.util.Objects;
|
|
|
+
|
|
|
+/**
|
|
|
+ * IP 限流 AOP 切面
|
|
|
+ */
|
|
|
+@Aspect
|
|
|
+@Component
|
|
|
+public class RateLimitAop {
|
|
|
+
|
|
|
+ private static final Logger log = LoggerFactory.getLogger(RateLimitAop.class);
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private RedisUtils redisUtils;
|
|
|
+
|
|
|
+ @Pointcut("@annotation(com.tzld.piaoquan.api.annotation.RateLimit)")
|
|
|
+ public void rateLimit() {
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 环绕通知,执行限流逻辑
|
|
|
+ */
|
|
|
+ @Around("rateLimit()")
|
|
|
+ public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
|
|
|
+ MethodSignature signature = (MethodSignature) joinPoint.getSignature();
|
|
|
+ Method method = signature.getMethod();
|
|
|
+
|
|
|
+ // 获取注解信息
|
|
|
+ RateLimit rateLimit = method.getAnnotation(RateLimit.class);
|
|
|
+ long timeWindow = rateLimit.timeWindow();
|
|
|
+ long maxRequests = rateLimit.maxRequests();
|
|
|
+ String message = rateLimit.message();
|
|
|
+
|
|
|
+ // 获取请求对象
|
|
|
+ ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
|
|
+ HttpServletRequest request = Objects.requireNonNull(attributes).getRequest();
|
|
|
+
|
|
|
+ // 获取客户端 IP
|
|
|
+ String clientIp = IpUtil.getIpAddr(request);
|
|
|
+
|
|
|
+ if (Strings.isNullOrEmpty(clientIp)) {
|
|
|
+ log.warn("Rate limit - client IP is empty");
|
|
|
+ return joinPoint.proceed();
|
|
|
+ }
|
|
|
+
|
|
|
+ // 构建 Redis key: rate_limit:{method_name}:{ip}
|
|
|
+ String methodName = method.getName();
|
|
|
+ String redisKey = String.format("rate_limit:%s:%s", methodName, clientIp);
|
|
|
+
|
|
|
+ // 获取当前请求次数
|
|
|
+ Long currentRequests = redisUtils.getLong(redisKey);
|
|
|
+
|
|
|
+ if (currentRequests >= maxRequests) {
|
|
|
+ log.warn("Rate limit exceeded - IP: {}, Method: {}, Requests: {}, Max: {}",
|
|
|
+ clientIp, methodName, currentRequests, maxRequests);
|
|
|
+ throw new CommonException(ExceptionEnum.PARAM_ERROR.getCode(), message);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 增加请求计数
|
|
|
+ redisUtils.setIncrementValue(redisKey, 1, timeWindow);
|
|
|
+
|
|
|
+ log.debug("Rate limit - IP: {}, Method: {}, Current Requests: {}, Max: {}",
|
|
|
+ clientIp, methodName, currentRequests + 1, maxRequests);
|
|
|
+
|
|
|
+ return joinPoint.proceed();
|
|
|
+ }
|
|
|
+}
|