ssrf_protection.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. package common
  2. import (
  3. "fmt"
  4. "net"
  5. "net/url"
  6. "strconv"
  7. "strings"
  8. )
  // SSRFProtection holds the policy used to decide whether an outbound URL may
// be fetched without risking server-side request forgery (SSRF).
type SSRFProtection struct {
	// AllowPrivateIp permits private, loopback, link-local and other
	// special-purpose addresses.
	AllowPrivateIp bool
	// DomainFilterMode selects how DomainList is interpreted:
	// true = allow-list (whitelist), false = deny-list (blacklist).
	DomainFilterMode bool
	// DomainList holds host patterns such as "example.com" or "*.example.com".
	DomainList []string
	// IpFilterMode selects how IpList is interpreted:
	// true = allow-list (whitelist), false = deny-list (blacklist).
	IpFilterMode bool
	// IpList holds CIDR ranges or single IP addresses.
	IpList []string
	// AllowedPorts is the explicit list of permitted destination ports;
	// an empty slice places no restriction on the port.
	AllowedPorts []int
	// ApplyIPFilterForDomain, when true, also runs the IP checks against the
	// addresses a domain name resolves to.
	ApplyIPFilterForDomain bool
}
  19. // DefaultSSRFProtection 默认SSRF防护配置
  20. var DefaultSSRFProtection = &SSRFProtection{
  21. AllowPrivateIp: false,
  22. DomainFilterMode: true,
  23. DomainList: []string{},
  24. IpFilterMode: true,
  25. IpList: []string{},
  26. AllowedPorts: []int{},
  27. }
  // privateIPv4Nets lists IPv4 private, reserved and special-purpose ranges.
// It is based on the IANA IPv4 Special-Purpose Address Registry
// (https://www.iana.org/assignments/iana-ipv4-special-registry/) and also
// covers multicast, the reserved 240.0.0.0/4 block and limited broadcast.
var privateIPv4Nets = func() []net.IPNet {
	// cidr4 builds a.b.c.d/bits as a net.IPv4 address paired with a
	// 32-bit-wide CIDR mask.
	cidr4 := func(a, b, c, d byte, bits int) net.IPNet {
		return net.IPNet{IP: net.IPv4(a, b, c, d), Mask: net.CIDRMask(bits, 32)}
	}
	return []net.IPNet{
		cidr4(0, 0, 0, 0, 8),         // 0.0.0.0/8: "this network" / unspecified
		cidr4(10, 0, 0, 0, 8),        // 10.0.0.0/8: private
		cidr4(100, 64, 0, 0, 10),     // 100.64.0.0/10: carrier-grade NAT
		cidr4(127, 0, 0, 0, 8),       // 127.0.0.0/8: loopback
		cidr4(169, 254, 0, 0, 16),    // 169.254.0.0/16: link-local
		cidr4(172, 16, 0, 0, 12),     // 172.16.0.0/12: private
		cidr4(192, 0, 0, 0, 24),      // 192.0.0.0/24: IETF protocol assignments
		cidr4(192, 0, 2, 0, 24),      // 192.0.2.0/24: TEST-NET-1
		cidr4(192, 168, 0, 0, 16),    // 192.168.0.0/16: private
		cidr4(198, 18, 0, 0, 15),     // 198.18.0.0/15: benchmarking
		cidr4(198, 51, 100, 0, 24),   // 198.51.100.0/24: TEST-NET-2
		cidr4(203, 0, 113, 0, 24),    // 203.0.113.0/24: TEST-NET-3
		cidr4(224, 0, 0, 0, 4),       // 224.0.0.0/4: multicast
		cidr4(240, 0, 0, 0, 4),       // 240.0.0.0/4: reserved
		cidr4(255, 255, 255, 255, 32), // 255.255.255.255/32: limited broadcast
	}
}()
  // privateIPv6Nets lists IPv6 private, reserved and special-purpose ranges.
// It is based on the IANA IPv6 Special-Purpose Address Registry:
// https://www.iana.org/assignments/iana-ipv6-special-registry/
//
// The table is built from literal CIDR strings at package initialisation. A
// string that fails to parse is a programming error and panics, rather than
// being skipped, so a typo can never silently drop a blocked range.
var privateIPv6Nets = func() []net.IPNet {
	cidrs := []string{
		"::/128",         // unspecified address
		"::1/128",        // loopback
		"::ffff:0:0/96",  // IPv4-mapped
		"64:ff9b::/96",   // IPv4/IPv6 translation (well-known prefix)
		"64:ff9b:1::/48", // IPv4/IPv6 translation, local use (RFC 8215)
		"100::/64",       // discard-only
		"2001::/23",      // IETF protocol assignments
		"2001:db8::/32",  // documentation
		"3fff::/20",      // documentation (RFC 9637)
		"fc00::/7",       // unique local addresses (ULA)
		"fe80::/10",      // link-local
		"ff00::/8",       // multicast
	}
	nets := make([]net.IPNet, 0, len(cidrs))
	for _, c := range cidrs {
		_, n, err := net.ParseCIDR(c)
		if err != nil {
			// Constant input: failing here means the table itself is wrong.
			panic(fmt.Sprintf("ssrf: invalid built-in IPv6 CIDR %q: %v", c, err))
		}
		nets = append(nets, *n)
	}
	return nets
}()
  72. // isPrivateIP 检查IP是否为私有/保留/特殊用途地址
  73. func isPrivateIP(ip net.IP) bool {
  74. if ip == nil {
  75. return true
  76. }
  77. // 未指定地址 (0.0.0.0, ::)
  78. if ip.IsUnspecified() {
  79. return true
  80. }
  81. // 回环、链路本地 (unicast/multicast)
  82. if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
  83. return true
  84. }
  85. // 接口本地组播 (IPv6 ff01::/16 等)
  86. if ip.IsInterfaceLocalMulticast() {
  87. return true
  88. }
  89. if v4 := ip.To4(); v4 != nil {
  90. for _, privateNet := range privateIPv4Nets {
  91. if privateNet.Contains(v4) {
  92. return true
  93. }
  94. }
  95. return false
  96. }
  97. // IPv6 检查
  98. for _, privateNet := range privateIPv6Nets {
  99. if privateNet.Contains(ip) {
  100. return true
  101. }
  102. }
  103. // 兜底: Go 标准库识别的其他私有地址
  104. if ip.IsPrivate() {
  105. return true
  106. }
  107. return false
  108. }
  // parsePortRanges expands textual port specifications into a flat list of
// port numbers.
//
// Each entry is trimmed, and empty entries are skipped. An entry is either a
// single port ("80", "443") or an inclusive range ("8000-9000"); every port
// must lie within 1-65535. The first malformed entry aborts parsing with an
// error. Ports are returned in input order, duplicates included, and the
// result is nil when no ports were given.
func parsePortRanges(portConfigs []string) ([]int, error) {
	var expanded []int
	for _, raw := range portConfigs {
		spec := strings.TrimSpace(raw)
		if spec == "" {
			continue
		}
		switch dashes := strings.Count(spec, "-"); {
		case dashes > 1:
			// More than one dash cannot be a "start-end" pair.
			return nil, fmt.Errorf("invalid port range format: %s", spec)
		case dashes == 1:
			sep := strings.Index(spec, "-")
			first, err := strconv.Atoi(strings.TrimSpace(spec[:sep]))
			if err != nil {
				return nil, fmt.Errorf("invalid start port in range %s: %v", spec, err)
			}
			last, err := strconv.Atoi(strings.TrimSpace(spec[sep+1:]))
			if err != nil {
				return nil, fmt.Errorf("invalid end port in range %s: %v", spec, err)
			}
			if first > last {
				return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", spec)
			}
			// With first <= last, these two bounds cover all four endpoint checks.
			if first < 1 || last > 65535 {
				return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", spec)
			}
			for port := first; port <= last; port++ {
				expanded = append(expanded, port)
			}
		default:
			port, err := strconv.Atoi(spec)
			if err != nil {
				return nil, fmt.Errorf("invalid port number: %s", spec)
			}
			if port < 1 || port > 65535 {
				return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
			}
			expanded = append(expanded, port)
		}
	}
	return expanded, nil
}
  156. // isAllowedPort 检查端口是否被允许
  157. func (p *SSRFProtection) isAllowedPort(port int) bool {
  158. if len(p.AllowedPorts) == 0 {
  159. return true // 如果没有配置端口限制,则允许所有端口
  160. }
  161. for _, allowedPort := range p.AllowedPorts {
  162. if port == allowedPort {
  163. return true
  164. }
  165. }
  166. return false
  167. }
  // isDomainListed reports whether domain matches any entry of list.
//
// Matching is case-insensitive. Each entry is trimmed of surrounding
// whitespace, and blank entries are skipped. Trailing dots are removed from
// both the domain and the entries, because "example.com." is the
// fully-qualified spelling of the same DNS name as "example.com". Without this
// step, a deny-list entry could be bypassed by appending a dot to the URL host.
//
// An entry is either an exact name ("example.com") or a wildcard
// ("*.example.com"). A wildcard matches the bare domain itself and any
// subdomain at any depth ("a.example.com", "a.b.example.com"). It does not
// match a name that merely ends with the same characters ("badexample.com").
func isDomainListed(domain string, list []string) bool {
	if len(list) == 0 {
		return false
	}
	domain = strings.TrimRight(strings.ToLower(domain), ".")
	for _, item := range list {
		entry := strings.TrimRight(strings.ToLower(strings.TrimSpace(item)), ".")
		if entry == "" {
			continue
		}
		if domain == entry {
			return true
		}
		if strings.HasPrefix(entry, "*.") {
			suffix := strings.TrimPrefix(entry, "*.")
			// Require a label boundary ("."+suffix) so "badexample.com" does
			// not match "*.example.com".
			if domain == suffix || strings.HasSuffix(domain, "."+suffix) {
				return true
			}
		}
	}
	return false
}
  193. func (p *SSRFProtection) isDomainAllowed(domain string) bool {
  194. listed := isDomainListed(domain, p.DomainList)
  195. if p.DomainFilterMode { // 白名单
  196. return listed
  197. }
  198. // 黑名单
  199. return !listed
  200. }
  201. // isIPWhitelisted 检查IP是否在白名单中
  202. func isIPListed(ip net.IP, list []string) bool {
  203. if len(list) == 0 {
  204. return false
  205. }
  206. return IsIpInCIDRList(ip, list)
  207. }
  208. // IsIPAccessAllowed 检查IP是否允许访问
  209. func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
  210. // 私有IP限制
  211. if isPrivateIP(ip) && !p.AllowPrivateIp {
  212. return false
  213. }
  214. listed := isIPListed(ip, p.IpList)
  215. if p.IpFilterMode { // 白名单
  216. return listed
  217. }
  218. // 黑名单
  219. return !listed
  220. }
  221. // ValidateURL 验证URL是否安全
  222. func (p *SSRFProtection) ValidateURL(urlStr string) error {
  223. // 解析URL
  224. u, err := url.Parse(urlStr)
  225. if err != nil {
  226. return fmt.Errorf("invalid URL format: %v", err)
  227. }
  228. // 只允许HTTP/HTTPS协议
  229. if u.Scheme != "http" && u.Scheme != "https" {
  230. return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
  231. }
  232. // 解析主机和端口
  233. host, portStr, err := net.SplitHostPort(u.Host)
  234. if err != nil {
  235. // 没有端口,使用默认端口
  236. host = u.Hostname()
  237. if u.Scheme == "https" {
  238. portStr = "443"
  239. } else {
  240. portStr = "80"
  241. }
  242. }
  243. // 验证端口
  244. port, err := strconv.Atoi(portStr)
  245. if err != nil {
  246. return fmt.Errorf("invalid port: %s", portStr)
  247. }
  248. if !p.isAllowedPort(port) {
  249. return fmt.Errorf("port %d is not allowed", port)
  250. }
  251. // 如果 host 是 IP,则跳过域名检查
  252. if ip := net.ParseIP(host); ip != nil {
  253. if !p.IsIPAccessAllowed(ip) {
  254. if isPrivateIP(ip) {
  255. return fmt.Errorf("private IP address not allowed: %s", ip.String())
  256. }
  257. if p.IpFilterMode {
  258. return fmt.Errorf("ip not in whitelist: %s", ip.String())
  259. }
  260. return fmt.Errorf("ip in blacklist: %s", ip.String())
  261. }
  262. return nil
  263. }
  264. // 先进行域名过滤
  265. if !p.isDomainAllowed(host) {
  266. if p.DomainFilterMode {
  267. return fmt.Errorf("domain not in whitelist: %s", host)
  268. }
  269. return fmt.Errorf("domain in blacklist: %s", host)
  270. }
  271. // 若未启用对域名应用IP过滤,则到此通过
  272. if !p.ApplyIPFilterForDomain {
  273. return nil
  274. }
  275. // 解析域名对应IP并检查
  276. ips, err := net.LookupIP(host)
  277. if err != nil {
  278. return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
  279. }
  280. for _, ip := range ips {
  281. if !p.IsIPAccessAllowed(ip) {
  282. if isPrivateIP(ip) && !p.AllowPrivateIp {
  283. return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
  284. }
  285. if p.IpFilterMode {
  286. return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
  287. }
  288. return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
  289. }
  290. }
  291. return nil
  292. }
  293. // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
  294. func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
  295. // 如果SSRF防护被禁用,直接返回成功
  296. if !enableSSRFProtection {
  297. return nil
  298. }
  299. // 解析端口范围配置
  300. allowedPortInts, err := parsePortRanges(allowedPorts)
  301. if err != nil {
  302. return fmt.Errorf("request reject - invalid port configuration: %v", err)
  303. }
  304. protection := &SSRFProtection{
  305. AllowPrivateIp: allowPrivateIp,
  306. DomainFilterMode: domainFilterMode,
  307. DomainList: domainList,
  308. IpFilterMode: ipFilterMode,
  309. IpList: ipList,
  310. AllowedPorts: allowedPortInts,
  311. ApplyIPFilterForDomain: applyIPFilterForDomain,
  312. }
  313. return protection.ValidateURL(urlStr)
  314. }