Explorar o código

feat: 添加域名启用ip过滤开关

creamlike1024 hai 5 meses
pai
achega
467e584359

+ 20 - 13
common/ssrf_protection.go

@@ -10,12 +10,13 @@ import (
 
 // SSRFProtection SSRF防护配置
 type SSRFProtection struct {
-	AllowPrivateIp   bool
-	DomainFilterMode bool     // true: 白名单, false: 黑名单
-	DomainList       []string // domain format, e.g. example.com, *.example.com
-	IpFilterMode     bool     // true: 白名单, false: 黑名单
-	IpList           []string // CIDR or single IP
-	AllowedPorts     []int    // 允许的端口范围
+	AllowPrivateIp         bool
+	DomainFilterMode       bool     // true: 白名单, false: 黑名单
+	DomainList             []string // domain format, e.g. example.com, *.example.com
+	IpFilterMode           bool     // true: 白名单, false: 黑名单
+	IpList                 []string // CIDR or single IP
+	AllowedPorts           []int    // 允许的端口范围
+	ApplyIPFilterForDomain bool     // 对域名启用IP过滤
 }
 
 // DefaultSSRFProtection 默认SSRF防护配置
@@ -276,6 +277,11 @@ func (p *SSRFProtection) ValidateURL(urlStr string) error {
 		return fmt.Errorf("domain in blacklist: %s", host)
 	}
 
+	// 若未启用对域名应用IP过滤,则到此通过
+	if !p.ApplyIPFilterForDomain {
+		return nil
+	}
+
 	// 解析域名对应IP并检查
 	ips, err := net.LookupIP(host)
 	if err != nil {
@@ -296,7 +302,7 @@ func (p *SSRFProtection) ValidateURL(urlStr string) error {
 }
 
 // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
-func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string) error {
+func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
 	// 如果SSRF防护被禁用,直接返回成功
 	if !enableSSRFProtection {
 		return nil
@@ -309,12 +315,13 @@ func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPriva
 	}
 
 	protection := &SSRFProtection{
-		AllowPrivateIp:   allowPrivateIp,
-		DomainFilterMode: domainFilterMode,
-		DomainList:       domainList,
-		IpFilterMode:     ipFilterMode,
-		IpList:           ipList,
-		AllowedPorts:     allowedPortInts,
+		AllowPrivateIp:         allowPrivateIp,
+		DomainFilterMode:       domainFilterMode,
+		DomainList:             domainList,
+		IpFilterMode:           ipFilterMode,
+		IpList:                 ipList,
+		AllowedPorts:           allowedPortInts,
+		ApplyIPFilterForDomain: applyIPFilterForDomain,
 	}
 	return protection.ValidateURL(urlStr)
 }

+ 2 - 2
service/download.go

@@ -30,7 +30,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
 
 	// SSRF防护:验证请求URL
 	fetchSetting := system_setting.GetFetchSetting()
-	if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
+	if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
 		return nil, fmt.Errorf("request reject: %v", err)
 	}
 
@@ -59,7 +59,7 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response,
 	} else {
 		// SSRF防护:验证请求URL(非Worker模式)
 		fetchSetting := system_setting.GetFetchSetting()
-		if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
+		if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
 			return nil, fmt.Errorf("request reject: %v", err)
 		}
 

+ 1 - 1
service/user_notify.go

@@ -115,7 +115,7 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
 	} else {
 		// SSRF防护:验证Bark URL(非Worker模式)
 		fetchSetting := system_setting.GetFetchSetting()
-		if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
+		if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
 			return fmt.Errorf("request reject: %v", err)
 		}
 

+ 1 - 1
service/webhook.go

@@ -89,7 +89,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
 	} else {
 		// SSRF防护:验证Webhook URL(非Worker模式)
 		fetchSetting := system_setting.GetFetchSetting()
-		if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
+		if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
 			return fmt.Errorf("request reject: %v", err)
 		}
 

+ 16 - 14
setting/system_setting/fetch_setting.go

@@ -3,23 +3,25 @@ package system_setting
 import "one-api/setting/config"
 
 type FetchSetting struct {
-	EnableSSRFProtection bool     `json:"enable_ssrf_protection"` // 是否启用SSRF防护
-	AllowPrivateIp       bool     `json:"allow_private_ip"`
-	DomainFilterMode     bool     `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式
-	IpFilterMode         bool     `json:"ip_filter_mode"`     // IP过滤模式,true: 白名单模式,false: 黑名单模式
-	DomainList           []string `json:"domain_list"`        // domain format, e.g. example.com, *.example.com
-	IpList               []string `json:"ip_list"`            // CIDR format
-	AllowedPorts         []string `json:"allowed_ports"`      // port range format, e.g. 80, 443, 8000-9000
+	EnableSSRFProtection   bool     `json:"enable_ssrf_protection"` // 是否启用SSRF防护
+	AllowPrivateIp         bool     `json:"allow_private_ip"`
+	DomainFilterMode       bool     `json:"domain_filter_mode"`         // 域名过滤模式,true: 白名单模式,false: 黑名单模式
+	IpFilterMode           bool     `json:"ip_filter_mode"`             // IP过滤模式,true: 白名单模式,false: 黑名单模式
+	DomainList             []string `json:"domain_list"`                // domain format, e.g. example.com, *.example.com
+	IpList                 []string `json:"ip_list"`                    // CIDR format
+	AllowedPorts           []string `json:"allowed_ports"`              // port range format, e.g. 80, 443, 8000-9000
+	ApplyIPFilterForDomain bool     `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性)
 }
 
 var defaultFetchSetting = FetchSetting{
-	EnableSSRFProtection: true, // 默认开启SSRF防护
-	AllowPrivateIp:       false,
-	DomainFilterMode:     true,
-	IpFilterMode:         true,
-	DomainList:           []string{},
-	IpList:               []string{},
-	AllowedPorts:         []string{"80", "443", "8080", "8443"},
+	EnableSSRFProtection:   true, // 默认开启SSRF防护
+	AllowPrivateIp:         false,
+	DomainFilterMode:       true,
+	IpFilterMode:           true,
+	DomainList:             []string{},
+	IpList:                 []string{},
+	AllowedPorts:           []string{"80", "443", "8080", "8443"},
+	ApplyIPFilterForDomain: false,
 }
 
 func init() {

+ 17 - 2
web/src/components/settings/SystemSetting.jsx

@@ -97,6 +97,7 @@ const SystemSetting = () => {
     'fetch_setting.domain_list': [],
     'fetch_setting.ip_list': [],
     'fetch_setting.allowed_ports': [],
+    'fetch_setting.apply_ip_filter_for_domain': false,
   });
 
   const [originInputs, setOriginInputs] = useState({});
@@ -132,6 +133,7 @@ const SystemSetting = () => {
           case 'fetch_setting.enable_ssrf_protection':
           case 'fetch_setting.domain_filter_mode':
           case 'fetch_setting.ip_filter_mode':
+          case 'fetch_setting.apply_ip_filter_for_domain':
             item.value = toBoolean(item.value);
             break;
           case 'fetch_setting.domain_list':
@@ -724,6 +726,17 @@ const SystemSetting = () => {
                     style={{ marginTop: 16 }}
                   >
                     <Col xs={24} sm={24} md={24} lg={24} xl={24}>
+                      <Banner type='warning' description={t('此功能为实验性选项,域名可能解析到多个 IPv4/IPv6 地址,若开启,请确保 IP 过滤列表覆盖这些地址,否则可能导致访问失败。')} style={{ marginBottom: 8 }} />
+                      <Form.Checkbox
+                        field='fetch_setting.apply_ip_filter_for_domain'
+                        noLabel
+                        onChange={(e) =>
+                          handleCheckboxChange('fetch_setting.apply_ip_filter_for_domain', e)
+                        }
+                        style={{ marginBottom: 8 }}
+                      >
+                        {t('对域名启用 IP 过滤(实验性)')}
+                      </Form.Checkbox>
                       <Text strong>
                         {t(domainFilterMode ? '域名白名单' : '域名黑名单')}
                       </Text>
@@ -734,7 +747,8 @@ const SystemSetting = () => {
                         type='button'
                         value={domainFilterMode ? 'whitelist' : 'blacklist'}
                         onChange={(val) => {
-                          const isWhitelist = val === 'whitelist';
+                          const selected = val && val.target ? val.target.value : val;
+                          const isWhitelist = selected === 'whitelist';
                           setDomainFilterMode(isWhitelist);
                           setInputs(prev => ({
                             ...prev,
@@ -780,7 +794,8 @@ const SystemSetting = () => {
                         type='button'
                         value={ipFilterMode ? 'whitelist' : 'blacklist'}
                         onChange={(val) => {
-                          const isWhitelist = val === 'whitelist';
+                          const selected = val && val.target ? val.target.value : val;
+                          const isWhitelist = selected === 'whitelist';
                           setIpFilterMode(isWhitelist);
                           setInputs(prev => ({
                             ...prev,