فهرست منبع

feat: Add trusted redirect domains.

mehunk 1 ماه پیش
والد
کامیت
65fd33e3ef
6فایلهای تغییر یافته به همراه205 افزوده شده و 0 حذف شده
  1. 5 0
      .env.example
  2. 13 0
      common/init.go
  3. 39 0
      common/url_validator.go
  4. 134 0
      common/url_validator_test.go
  5. 4 0
      constant/env.go
  6. 10 0
      controller/topup_stripe.go

+ 5 - 0
.env.example

@@ -85,3 +85,8 @@ LINUX_DO_USER_ENDPOINT=https://connect.linux.do/api/user
 # 节点类型
 # 节点类型
 # 如果是主节点则为master
 # 如果是主节点则为master
 # NODE_TYPE=master
 # NODE_TYPE=master
+
+# 可信任重定向域名列表(逗号分隔,支持子域名匹配)
+# 用于验证支付成功/取消回调URL的域名安全性
+# 示例: example.com,myapp.io 将允许 example.com, sub.example.com, myapp.io 等
+# TRUSTED_REDIRECT_DOMAINS=example.com,myapp.io

+ 13 - 0
common/init.go

@@ -159,4 +159,17 @@ func initConstantEnv() {
 		}
 		}
 		constant.TaskPricePatches = taskPricePatches
 		constant.TaskPricePatches = taskPricePatches
 	}
 	}
+
+	// Initialize trusted redirect domains for URL validation
+	trustedDomainsStr := GetEnvOrDefaultString("TRUSTED_REDIRECT_DOMAINS", "")
+	var trustedDomains []string
+	domains := strings.Split(trustedDomainsStr, ",")
+	for _, domain := range domains {
+		trimmedDomain := strings.TrimSpace(domain)
+		if trimmedDomain != "" {
+			// Normalize domain to lowercase
+			trustedDomains = append(trustedDomains, strings.ToLower(trimmedDomain))
+		}
+	}
+	constant.TrustedRedirectDomains = trustedDomains
 }
 }

+ 39 - 0
common/url_validator.go

@@ -0,0 +1,39 @@
+package common
+
+import (
+	"fmt"
+	"net/url"
+	"strings"
+
+	"github.com/QuantumNous/new-api/constant"
+)
+
+// ValidateRedirectURL validates that a redirect URL is safe to use.
+// It checks that:
+//   - The URL is properly formatted
+//   - The scheme is either http or https
+//   - The domain is in the trusted domains list (exact match or subdomain)
+//
+// Returns nil if the URL is valid and trusted, otherwise returns an error
+// describing why the validation failed.
+func ValidateRedirectURL(rawURL string) error {
+	// Parse the URL
+	parsedURL, err := url.Parse(rawURL)
+	if err != nil {
+		return fmt.Errorf("invalid URL format: %s", err.Error())
+	}
+
+	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
+		return fmt.Errorf("invalid URL scheme: only http and https are allowed")
+	}
+
+	domain := strings.ToLower(parsedURL.Hostname())
+
+	for _, trustedDomain := range constant.TrustedRedirectDomains {
+		if domain == trustedDomain || strings.HasSuffix(domain, "."+trustedDomain) {
+			return nil
+		}
+	}
+
+	return fmt.Errorf("domain %s is not in the trusted domains list", domain)
+}

+ 134 - 0
common/url_validator_test.go

@@ -0,0 +1,134 @@
+package common
+
+import (
+	"testing"
+
+	"github.com/QuantumNous/new-api/constant"
+)
+
+func TestValidateRedirectURL(t *testing.T) {
+	// Save original trusted domains and restore after test
+	originalDomains := constant.TrustedRedirectDomains
+	defer func() {
+		constant.TrustedRedirectDomains = originalDomains
+	}()
+
+	tests := []struct {
+		name           string
+		url            string
+		trustedDomains []string
+		wantErr        bool
+		errContains    string
+	}{
+		// Valid cases
+		{
+			name:           "exact domain match with https",
+			url:            "https://example.com/success",
+			trustedDomains: []string{"example.com"},
+			wantErr:        false,
+		},
+		{
+			name:           "exact domain match with http",
+			url:            "http://example.com/callback",
+			trustedDomains: []string{"example.com"},
+			wantErr:        false,
+		},
+		{
+			name:           "subdomain match",
+			url:            "https://sub.example.com/success",
+			trustedDomains: []string{"example.com"},
+			wantErr:        false,
+		},
+		{
+			name:           "case insensitive domain",
+			url:            "https://EXAMPLE.COM/success",
+			trustedDomains: []string{"example.com"},
+			wantErr:        false,
+		},
+
+		// Invalid cases - untrusted domain
+		{
+			name:           "untrusted domain",
+			url:            "https://evil.com/phishing",
+			trustedDomains: []string{"example.com"},
+			wantErr:        true,
+			errContains:    "not in the trusted domains list",
+		},
+		{
+			name:           "suffix attack - fakeexample.com",
+			url:            "https://fakeexample.com/success",
+			trustedDomains: []string{"example.com"},
+			wantErr:        true,
+			errContains:    "not in the trusted domains list",
+		},
+		{
+			name:           "empty trusted domains list",
+			url:            "https://example.com/success",
+			trustedDomains: []string{},
+			wantErr:        true,
+			errContains:    "not in the trusted domains list",
+		},
+
+		// Invalid cases - scheme
+		{
+			name:           "javascript scheme",
+			url:            "javascript:alert('xss')",
+			trustedDomains: []string{"example.com"},
+			wantErr:        true,
+			errContains:    "invalid URL scheme",
+		},
+		{
+			name:           "data scheme",
+			url:            "data:text/html,<script>alert('xss')</script>",
+			trustedDomains: []string{"example.com"},
+			wantErr:        true,
+			errContains:    "invalid URL scheme",
+		},
+
+		// Edge cases
+		{
+			name:           "empty URL",
+			url:            "",
+			trustedDomains: []string{"example.com"},
+			wantErr:        true,
+			errContains:    "invalid URL scheme",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Set up trusted domains for this test case
+			constant.TrustedRedirectDomains = tt.trustedDomains
+
+			err := ValidateRedirectURL(tt.url)
+
+			if tt.wantErr {
+				if err == nil {
+					t.Errorf("ValidateRedirectURL(%q) expected error containing %q, got nil", tt.url, tt.errContains)
+					return
+				}
+				if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
+					t.Errorf("ValidateRedirectURL(%q) error = %q, want error containing %q", tt.url, err.Error(), tt.errContains)
+				}
+			} else {
+				if err != nil {
+					t.Errorf("ValidateRedirectURL(%q) unexpected error: %v", tt.url, err)
+				}
+			}
+		})
+	}
+}
+
+func contains(s, substr string) bool {
+	return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
+		(len(s) > 0 && len(substr) > 0 && findSubstring(s, substr)))
+}
+
+func findSubstring(s, substr string) bool {
+	for i := 0; i <= len(s)-len(substr); i++ {
+		if s[i:i+len(substr)] == substr {
+			return true
+		}
+	}
+	return false
+}

+ 4 - 0
constant/env.go

@@ -20,3 +20,7 @@ var TaskQueryLimit int
 
 
 // temporary variable for sora patch, will be removed in future
 // temporary variable for sora patch, will be removed in future
 var TaskPricePatches []string
 var TaskPricePatches []string
+
+// TrustedRedirectDomains is a list of trusted domains for redirect URL validation.
+// Domains support subdomain matching (e.g., "example.com" matches "sub.example.com").
+var TrustedRedirectDomains []string

+ 10 - 0
controller/topup_stripe.go

@@ -78,6 +78,16 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
 		return
 		return
 	}
 	}
 
 
+	if req.SuccessURL != "" && common.ValidateRedirectURL(req.SuccessURL) != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"message": "支付成功重定向URL不在可信任域名列表中", "data": ""})
+		return
+	}
+
+	if req.CancelURL != "" && common.ValidateRedirectURL(req.CancelURL) != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"message": "支付取消重定向URL不在可信任域名列表中", "data": ""})
+		return
+	}
+
 	id := c.GetInt("id")
 	id := c.GetInt("id")
 	user, _ := model.GetUserById(id, false)
 	user, _ := model.GetUserById(id, false)
 	chargedMoney := GetChargedAmount(float64(req.Amount), *user)
 	chargedMoney := GetChargedAmount(float64(req.Amount), *user)