package common import ( "fmt" "net" "net/url" "strconv" "strings" ) // SSRFProtection SSRF防护配置 type SSRFProtection struct { AllowPrivateIp bool DomainFilterMode bool // true: 白名单, false: 黑名单 DomainList []string // domain format, e.g. example.com, *.example.com IpFilterMode bool // true: 白名单, false: 黑名单 IpList []string // CIDR or single IP AllowedPorts []int // 允许的端口范围 ApplyIPFilterForDomain bool // 对域名启用IP过滤 } // DefaultSSRFProtection 默认SSRF防护配置 var DefaultSSRFProtection = &SSRFProtection{ AllowPrivateIp: false, DomainFilterMode: true, DomainList: []string{}, IpFilterMode: true, IpList: []string{}, AllowedPorts: []int{}, } // isPrivateIP 检查IP是否为私有地址 func isPrivateIP(ip net.IP) bool { if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { return true } // 检查私有网段 private := []net.IPNet{ {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8 {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12 {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16 {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8 {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地) {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播) {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留) } for _, privateNet := range private { if privateNet.Contains(ip) { return true } } // 检查IPv6私有地址 if ip.To4() == nil { // IPv6 loopback if ip.Equal(net.IPv6loopback) { return true } // IPv6 link-local if strings.HasPrefix(ip.String(), "fe80:") { return true } // IPv6 unique local if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") { return true } } return false } // parsePortRanges 解析端口范围配置 // 支持格式: "80", "443", "8000-9000" func parsePortRanges(portConfigs []string) ([]int, error) { var ports []int for _, config := range portConfigs { config = strings.TrimSpace(config) if config == "" { continue } if strings.Contains(config, "-") { // 处理端口范围 "8000-9000" parts := strings.Split(config, "-") if len(parts) != 2 { return nil, fmt.Errorf("invalid port range format: %s", config) } startPort, err := strconv.Atoi(strings.TrimSpace(parts[0])) if err != nil { return nil, fmt.Errorf("invalid start port in range %s: %v", config, err) } endPort, err := strconv.Atoi(strings.TrimSpace(parts[1])) if err != nil { return nil, fmt.Errorf("invalid end port in range %s: %v", config, err) } if startPort > endPort { return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config) } if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 { return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config) } // 添加范围内的所有端口 for port := startPort; port <= endPort; port++ { ports = append(ports, port) } } else { // 处理单个端口 "80" port, err := strconv.Atoi(config) if err != nil { return nil, fmt.Errorf("invalid port number: %s", config) } if port < 1 || port > 65535 { return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port) } ports = append(ports, port) } } return ports, nil } // isAllowedPort 检查端口是否被允许 func (p *SSRFProtection) isAllowedPort(port int) bool { if len(p.AllowedPorts) == 0 { return true // 如果没有配置端口限制,则允许所有端口 } for _, allowedPort := range p.AllowedPorts { if port == allowedPort { return true } } return false } // isDomainWhitelisted 检查域名是否在白名单中 func isDomainListed(domain string, list []string) bool { if len(list) == 0 { return false } domain = strings.ToLower(domain) for _, item := range list { item = strings.ToLower(strings.TrimSpace(item)) if item == "" { continue } // 精确匹配 if domain == item { return true } // 通配符匹配 (*.example.com) if strings.HasPrefix(item, "*.") { suffix := strings.TrimPrefix(item, "*.") if strings.HasSuffix(domain, "."+suffix) || domain == suffix { return true } } } return false } func (p *SSRFProtection) isDomainAllowed(domain string) bool { listed := isDomainListed(domain, p.DomainList) if p.DomainFilterMode { // 白名单 return listed } // 黑名单 return !listed } // isIPWhitelisted 检查IP是否在白名单中 func isIPListed(ip net.IP, list []string) bool { if len(list) == 0 { return false } return IsIpInCIDRList(ip, list) } // IsIPAccessAllowed 检查IP是否允许访问 func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool { // 私有IP限制 if isPrivateIP(ip) && !p.AllowPrivateIp { return false } listed := isIPListed(ip, p.IpList) if p.IpFilterMode { // 白名单 return listed } // 黑名单 return !listed } // ValidateURL 验证URL是否安全 func (p *SSRFProtection) ValidateURL(urlStr string) error { // 解析URL u, err := url.Parse(urlStr) if err != nil { return fmt.Errorf("invalid URL format: %v", err) } // 只允许HTTP/HTTPS协议 if u.Scheme != "http" && u.Scheme != "https" { return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme) } // 解析主机和端口 host, portStr, err := net.SplitHostPort(u.Host) if err != nil { // 没有端口,使用默认端口 host = u.Hostname() if u.Scheme == "https" { portStr = "443" } else { portStr = "80" } } // 验证端口 port, err := strconv.Atoi(portStr) if err != nil { return fmt.Errorf("invalid port: %s", portStr) } if !p.isAllowedPort(port) { return fmt.Errorf("port %d is not allowed", port) } // 如果 host 是 IP,则跳过域名检查 if ip := net.ParseIP(host); ip != nil { if !p.IsIPAccessAllowed(ip) { if isPrivateIP(ip) { return fmt.Errorf("private IP address not allowed: %s", ip.String()) } if p.IpFilterMode { return fmt.Errorf("ip not in whitelist: %s", ip.String()) } return fmt.Errorf("ip in blacklist: %s", ip.String()) } return nil } // 先进行域名过滤 if !p.isDomainAllowed(host) { if p.DomainFilterMode { return fmt.Errorf("domain not in whitelist: %s", host) } return fmt.Errorf("domain in blacklist: %s", host) } // 若未启用对域名应用IP过滤,则到此通过 if !p.ApplyIPFilterForDomain { return nil } // 解析域名对应IP并检查 ips, err := net.LookupIP(host) if err != nil { return fmt.Errorf("DNS resolution failed for %s: %v", host, err) } for _, ip := range ips { if !p.IsIPAccessAllowed(ip) { if isPrivateIP(ip) && !p.AllowPrivateIp { return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String()) } if p.IpFilterMode { return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String()) } return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String()) } } return nil } // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error { // 如果SSRF防护被禁用,直接返回成功 if !enableSSRFProtection { return nil } // 解析端口范围配置 allowedPortInts, err := parsePortRanges(allowedPorts) if err != nil { return fmt.Errorf("request reject - invalid port configuration: %v", err) } protection := &SSRFProtection{ AllowPrivateIp: allowPrivateIp, DomainFilterMode: domainFilterMode, DomainList: domainList, IpFilterMode: ipFilterMode, IpList: ipList, AllowedPorts: allowedPortInts, ApplyIPFilterForDomain: applyIPFilterForDomain, } return protection.ValidateURL(urlStr) }