Просмотр исходного кода

fix: add PaymentProvider field to prevent cross-gateway callback attacks

EPay allows users to switch payment methods (e.g. wxpay→alipay) during
checkout, causing callback rejection. Replace fragile blocklist guard
with a PaymentProvider field on TopUp and SubscriptionOrder that
identifies which gateway created the order.
CaIon 1 месяц назад
Родитель
Сommit
a7c38ec851

+ 8 - 7
controller/subscription_payment_creem.go

@@ -83,13 +83,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
 
 	// create pending order first
 	order := &model.SubscriptionOrder{
-		UserId:        userId,
-		PlanId:        plan.Id,
-		Money:         plan.PriceAmount,
-		TradeNo:       referenceId,
-		PaymentMethod: model.PaymentMethodCreem,
-		CreateTime:    time.Now().Unix(),
-		Status:        common.TopUpStatusPending,
+		UserId:          userId,
+		PlanId:          plan.Id,
+		Money:           plan.PriceAmount,
+		TradeNo:         referenceId,
+		PaymentMethod:   model.PaymentMethodCreem,
+		PaymentProvider: model.PaymentProviderCreem,
+		CreateTime:      time.Now().Unix(),
+		Status:          common.TopUpStatusPending,
 	}
 	if err := order.Insert(); err != nil {
 		c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})

+ 11 - 10
controller/subscription_payment_epay.go

@@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) {
 	}
 
 	order := &model.SubscriptionOrder{
-		UserId:        userId,
-		PlanId:        plan.Id,
-		Money:         plan.PriceAmount,
-		TradeNo:       tradeNo,
-		PaymentMethod: req.PaymentMethod,
-		CreateTime:    time.Now().Unix(),
-		Status:        common.TopUpStatusPending,
+		UserId:          userId,
+		PlanId:          plan.Id,
+		Money:           plan.PriceAmount,
+		TradeNo:         tradeNo,
+		PaymentMethod:   req.PaymentMethod,
+		PaymentProvider: model.PaymentProviderEpay,
+		CreateTime:      time.Now().Unix(),
+		Status:          common.TopUpStatusPending,
 	}
 	if err := order.Insert(); err != nil {
 		common.ApiErrorMsg(c, "创建订单失败")
@@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
 		ReturnUrl:      returnUrl,
 	})
 	if err != nil {
-		_ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod)
+		_ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay)
 		common.ApiErrorMsg(c, "拉起支付失败")
 		return
 	}
@@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
 	LockOrder(verifyInfo.ServiceTradeNo)
 	defer UnlockOrder(verifyInfo.ServiceTradeNo)
 
-	if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
+	if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
 		_, _ = c.Writer.Write([]byte("fail"))
 		return
 	}
@@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
 	if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
 		LockOrder(verifyInfo.ServiceTradeNo)
 		defer UnlockOrder(verifyInfo.ServiceTradeNo)
-		if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
+		if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
 			c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
 			return
 		}

+ 8 - 7
controller/subscription_payment_stripe.go

@@ -84,13 +84,14 @@ func SubscriptionRequestStripePay(c *gin.Context) {
 	}
 
 	order := &model.SubscriptionOrder{
-		UserId:        userId,
-		PlanId:        plan.Id,
-		Money:         plan.PriceAmount,
-		TradeNo:       referenceId,
-		PaymentMethod: model.PaymentMethodStripe,
-		CreateTime:    time.Now().Unix(),
-		Status:        common.TopUpStatusPending,
+		UserId:          userId,
+		PlanId:          plan.Id,
+		Money:           plan.PriceAmount,
+		TradeNo:         referenceId,
+		PaymentMethod:   model.PaymentMethodStripe,
+		PaymentProvider: model.PaymentProviderStripe,
+		CreateTime:      time.Now().Unix(),
+		Status:          common.TopUpStatusPending,
 	}
 	if err := order.Insert(); err != nil {
 		c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})

+ 14 - 24
controller/topup.go

@@ -123,17 +123,6 @@ type AmountRequest struct {
 	Amount int64 `json:"amount"`
 }
 
-var nonEpayPaymentMethodsForCallback = []string{
-	model.PaymentMethodStripe,
-	model.PaymentMethodCreem,
-	model.PaymentMethodWaffo,
-	model.PaymentMethodWaffoPancake,
-}
-
-func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool {
-	return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod)
-}
-
 func GetEpayClient() *epay.Client {
 	if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
 		return nil
@@ -248,13 +237,14 @@ func RequestEpay(c *gin.Context) {
 		amount = dAmount.Div(dQuotaPerUnit).IntPart()
 	}
 	topUp := &model.TopUp{
-		UserId:        id,
-		Amount:        amount,
-		Money:         payMoney,
-		TradeNo:       tradeNo,
-		PaymentMethod: req.PaymentMethod,
-		CreateTime:    time.Now().Unix(),
-		Status:        common.TopUpStatusPending,
+		UserId:          id,
+		Amount:          amount,
+		Money:           payMoney,
+		TradeNo:         tradeNo,
+		PaymentMethod:   req.PaymentMethod,
+		PaymentProvider: model.PaymentProviderEpay,
+		CreateTime:      time.Now().Unix(),
+		Status:          common.TopUpStatusPending,
 	}
 	err = topUp.Insert()
 	if err != nil {
@@ -379,15 +369,15 @@ func EpayNotify(c *gin.Context) {
 			logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
 			return
 		}
-		if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) {
-			logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
-			return
-		}
-		if topUp.PaymentMethod != verifyInfo.Type {
-			logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
+		if topUp.PaymentProvider != model.PaymentProviderEpay {
+			logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP()))
 			return
 		}
 		if topUp.Status == common.TopUpStatusPending {
+			if topUp.PaymentMethod != verifyInfo.Type {
+				logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
+				topUp.PaymentMethod = verifyInfo.Type
+			}
 			topUp.Status = common.TopUpStatusSuccess
 			err := topUp.Update()
 			if err != nil {

+ 9 - 8
controller/topup_creem.go

@@ -106,13 +106,14 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
 
 	// 先创建订单记录,使用产品配置的金额和充值额度
 	topUp := &model.TopUp{
-		UserId:        id,
-		Amount:        selectedProduct.Quota, // 充值额度
-		Money:         selectedProduct.Price, // 支付金额
-		TradeNo:       referenceId,
-		PaymentMethod: model.PaymentMethodCreem,
-		CreateTime:    time.Now().Unix(),
-		Status:        common.TopUpStatusPending,
+		UserId:          id,
+		Amount:          selectedProduct.Quota, // 充值额度
+		Money:           selectedProduct.Price, // 支付金额
+		TradeNo:         referenceId,
+		PaymentMethod:   model.PaymentMethodCreem,
+		PaymentProvider: model.PaymentProviderCreem,
+		CreateTime:      time.Now().Unix(),
+		Status:          common.TopUpStatusPending,
 	}
 	err = topUp.Insert()
 	if err != nil {
@@ -301,7 +302,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
 	// Try complete subscription order first
 	LockOrder(referenceId)
 	defer UnlockOrder(referenceId)
-	if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil {
+	if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil {
 		logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
 		c.Status(http.StatusOK)
 		return

+ 0 - 31
controller/topup_epay_guard_test.go

@@ -1,31 +0,0 @@
-package controller
-
-import (
-	"testing"
-
-	"github.com/QuantumNous/new-api/model"
-)
-
-func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) {
-	testCases := []struct {
-		name            string
-		paymentMethod   string
-		expectedBlocked bool
-	}{
-		{name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true},
-		{name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true},
-		{name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true},
-		{name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true},
-		{name: "alipay", paymentMethod: "alipay", expectedBlocked: false},
-		{name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false},
-		{name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false},
-	}
-
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked {
-				t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod)
-			}
-		})
-	}
-}

+ 13 - 12
controller/topup_stripe.go

@@ -101,13 +101,14 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
 	}
 
 	topUp := &model.TopUp{
-		UserId:        id,
-		Amount:        req.Amount,
-		Money:         chargedMoney,
-		TradeNo:       referenceId,
-		PaymentMethod: model.PaymentMethodStripe,
-		CreateTime:    time.Now().Unix(),
-		Status:        common.TopUpStatusPending,
+		UserId:          id,
+		Amount:          req.Amount,
+		Money:           chargedMoney,
+		TradeNo:         referenceId,
+		PaymentMethod:   model.PaymentMethodStripe,
+		PaymentProvider: model.PaymentProviderStripe,
+		CreateTime:      time.Now().Unix(),
+		Status:          common.TopUpStatusPending,
 	}
 	err = topUp.Insert()
 	if err != nil {
@@ -237,8 +238,8 @@ func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp
 		return
 	}
 
-	if topUp.PaymentMethod != model.PaymentMethodStripe {
-		logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp))
+	if topUp.PaymentProvider != model.PaymentProviderStripe {
+		logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp))
 		return
 	}
 
@@ -270,7 +271,7 @@ func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, c
 		"currency":     strings.ToUpper(event.GetObjectValue("currency")),
 		"event_type":   string(event.Type),
 	}
-	if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil {
+	if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil {
 		logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
 		return
 	} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
@@ -305,7 +306,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
 	// Subscription order expiration
 	LockOrder(referenceId)
 	defer UnlockOrder(referenceId)
-	if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil {
+	if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil {
 		logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
 		return
 	} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
@@ -313,7 +314,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
 		return
 	}
 
-	err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired)
+	err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired)
 	if errors.Is(err, model.ErrTopUpNotFound) {
 		logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
 		return

+ 9 - 8
controller/topup_waffo.go

@@ -208,13 +208,14 @@ func RequestWaffoPay(c *gin.Context) {
 
 	// 创建本地订单
 	topUp := &model.TopUp{
-		UserId:        id,
-		Amount:        amount,
-		Money:         payMoney,
-		TradeNo:       merchantOrderId,
-		PaymentMethod: model.PaymentMethodWaffo,
-		CreateTime:    time.Now().Unix(),
-		Status:        common.TopUpStatusPending,
+		UserId:          id,
+		Amount:          amount,
+		Money:           payMoney,
+		TradeNo:         merchantOrderId,
+		PaymentMethod:   model.PaymentMethodWaffo,
+		PaymentProvider: model.PaymentProviderWaffo,
+		CreateTime:      time.Now().Unix(),
+		Status:          common.TopUpStatusPending,
 	}
 	if err := topUp.Insert(); err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
@@ -379,7 +380,7 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
 		logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
 		// 终态失败订单标记为 failed,避免永远停在 pending
 		if result.MerchantOrderID != "" {
-			if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil &&
+			if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil &&
 				!errors.Is(err, model.ErrTopUpNotFound) &&
 				!errors.Is(err, model.ErrTopUpStatusInvalid) {
 				logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))

+ 8 - 7
controller/topup_waffo_pancake.go

@@ -159,13 +159,14 @@ func RequestWaffoPancakePay(c *gin.Context) {
 
 	tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
 	topUp := &model.TopUp{
-		UserId:        id,
-		Amount:        normalizeWaffoPancakeTopUpAmount(req.Amount),
-		Money:         payMoney,
-		TradeNo:       tradeNo,
-		PaymentMethod: model.PaymentMethodWaffoPancake,
-		CreateTime:    time.Now().Unix(),
-		Status:        common.TopUpStatusPending,
+		UserId:          id,
+		Amount:          normalizeWaffoPancakeTopUpAmount(req.Amount),
+		Money:           payMoney,
+		TradeNo:         tradeNo,
+		PaymentMethod:   model.PaymentMethodWaffoPancake,
+		PaymentProvider: model.PaymentProviderWaffoPancake,
+		CreateTime:      time.Now().Unix(),
+		Status:          common.TopUpStatusPending,
 	}
 	if err := topUp.Insert(); err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))

+ 43 - 41
model/payment_method_guard_test.go

@@ -36,30 +36,32 @@ func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *Subscripti
 	return plan
 }
 
-func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) {
+func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentProvider string) {
 	t.Helper()
 	order := &SubscriptionOrder{
-		UserId:        userID,
-		PlanId:        planID,
-		Money:         9.99,
-		TradeNo:       tradeNo,
-		PaymentMethod: paymentMethod,
-		Status:        common.TopUpStatusPending,
-		CreateTime:    time.Now().Unix(),
+		UserId:          userID,
+		PlanId:          planID,
+		Money:           9.99,
+		TradeNo:         tradeNo,
+		PaymentMethod:   paymentProvider,
+		PaymentProvider: paymentProvider,
+		Status:          common.TopUpStatusPending,
+		CreateTime:      time.Now().Unix(),
 	}
 	require.NoError(t, order.Insert())
 }
 
-func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) {
+func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentProvider string) {
 	t.Helper()
 	topUp := &TopUp{
-		UserId:        userID,
-		Amount:        2,
-		Money:         9.99,
-		TradeNo:       tradeNo,
-		PaymentMethod: paymentMethod,
-		Status:        common.TopUpStatusPending,
-		CreateTime:    time.Now().Unix(),
+		UserId:          userID,
+		Amount:          2,
+		Money:           9.99,
+		TradeNo:         tradeNo,
+		PaymentMethod:   paymentProvider,
+		PaymentProvider: paymentProvider,
+		Status:          common.TopUpStatusPending,
+		CreateTime:      time.Now().Unix(),
 	}
 	require.NoError(t, topUp.Insert())
 }
@@ -89,7 +91,7 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
 	truncateTables(t)
 
 	insertUserForPaymentGuardTest(t, 101, 0)
-	insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe)
+	insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentProviderStripe)
 
 	err := RechargeWaffoPancake("waffo-pancake-guard")
 	require.Error(t, err)
@@ -100,27 +102,27 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
 	assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101))
 }
 
-func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
+func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentProvider(t *testing.T) {
 	testCases := []struct {
-		name                  string
-		tradeNo               string
-		storedPaymentMethod   string
-		expectedPaymentMethod string
-		targetStatus          string
+		name                    string
+		tradeNo                 string
+		storedPaymentProvider   string
+		expectedPaymentProvider string
+		targetStatus            string
 	}{
 		{
-			name:                  "stripe expire",
-			tradeNo:               "stripe-expire-guard",
-			storedPaymentMethod:   PaymentMethodCreem,
-			expectedPaymentMethod: PaymentMethodStripe,
-			targetStatus:          common.TopUpStatusExpired,
+			name:                    "stripe expire",
+			tradeNo:                 "stripe-expire-guard",
+			storedPaymentProvider:   PaymentProviderCreem,
+			expectedPaymentProvider: PaymentProviderStripe,
+			targetStatus:            common.TopUpStatusExpired,
 		},
 		{
-			name:                  "waffo failed",
-			tradeNo:               "waffo-failed-guard",
-			storedPaymentMethod:   PaymentMethodStripe,
-			expectedPaymentMethod: PaymentMethodWaffo,
-			targetStatus:          common.TopUpStatusFailed,
+			name:                    "waffo failed",
+			tradeNo:                 "waffo-failed-guard",
+			storedPaymentProvider:   PaymentProviderStripe,
+			expectedPaymentProvider: PaymentProviderWaffo,
+			targetStatus:            common.TopUpStatusFailed,
 		},
 	}
 
@@ -128,23 +130,23 @@ func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
 		t.Run(tc.name, func(t *testing.T) {
 			truncateTables(t)
 			insertUserForPaymentGuardTest(t, 150, 0)
-			insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod)
+			insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentProvider)
 
-			err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus)
+			err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentProvider, tc.targetStatus)
 			require.ErrorIs(t, err, ErrPaymentMethodMismatch)
 			assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo))
 		})
 	}
 }
 
-func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
+func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
 	truncateTables(t)
 
 	insertUserForPaymentGuardTest(t, 202, 0)
 	plan := insertSubscriptionPlanForPaymentGuardTest(t, 301)
-	insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe)
+	insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentProviderStripe)
 
-	err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay")
+	err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, PaymentProviderEpay, "alipay")
 	require.ErrorIs(t, err, ErrPaymentMethodMismatch)
 
 	order := GetSubscriptionOrderByTradeNo("sub-guard-order")
@@ -156,14 +158,14 @@ func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T)
 	assert.Nil(t, topUp)
 }
 
-func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
+func TestExpireSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
 	truncateTables(t)
 
 	insertUserForPaymentGuardTest(t, 303, 0)
 	plan := insertSubscriptionPlanForPaymentGuardTest(t, 401)
-	insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe)
+	insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentProviderStripe)
 
-	err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem)
+	err := ExpireSubscriptionOrder("sub-expire-guard", PaymentProviderCreem)
 	require.ErrorIs(t, err, ErrPaymentMethodMismatch)
 
 	order := GetSubscriptionOrderByTradeNo("sub-expire-guard")

+ 15 - 9
model/subscription.go

@@ -198,11 +198,12 @@ type SubscriptionOrder struct {
 	PlanId int     `json:"plan_id" gorm:"index"`
 	Money  float64 `json:"money"`
 
-	TradeNo       string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
-	PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
-	Status        string `json:"status"`
-	CreateTime    int64  `json:"create_time"`
-	CompleteTime  int64  `json:"complete_time"`
+	TradeNo         string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
+	PaymentMethod   string `json:"payment_method" gorm:"type:varchar(50)"`
+	PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
+	Status          string `json:"status"`
+	CreateTime      int64  `json:"create_time"`
+	CompleteTime    int64  `json:"complete_time"`
 
 	ProviderPayload string `json:"provider_payload" gorm:"type:text"`
 }
@@ -505,7 +506,9 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio
 }
 
 // Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
-func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error {
+// expectedPaymentProvider guards against cross-gateway callback attacks (empty skips the check).
+// actualPaymentMethod updates the order's PaymentMethod to reflect the real payment type used (empty skips update).
+func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentProvider string, actualPaymentMethod string) error {
 	if tradeNo == "" {
 		return errors.New("tradeNo is empty")
 	}
@@ -523,7 +526,7 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
 		if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
 			return ErrSubscriptionOrderNotFound
 		}
-		if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
+		if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
 			return ErrPaymentMethodMismatch
 		}
 		if order.Status == common.TopUpStatusSuccess {
@@ -552,6 +555,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
 		if providerPayload != "" {
 			order.ProviderPayload = providerPayload
 		}
+		if actualPaymentMethod != "" && order.PaymentMethod != actualPaymentMethod {
+			order.PaymentMethod = actualPaymentMethod
+		}
 		if err := tx.Save(&order).Error; err != nil {
 			return err
 		}
@@ -610,7 +616,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
 	return tx.Save(&topup).Error
 }
 
-func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error {
+func ExpireSubscriptionOrder(tradeNo string, expectedPaymentProvider string) error {
 	if tradeNo == "" {
 		return errors.New("tradeNo is empty")
 	}
@@ -623,7 +629,7 @@ func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error
 		if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
 			return ErrSubscriptionOrderNotFound
 		}
-		if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
+		if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
 			return ErrPaymentMethodMismatch
 		}
 		if order.Status != common.TopUpStatusPending {

+ 25 - 16
model/topup.go

@@ -12,15 +12,16 @@ import (
 )
 
 type TopUp struct {
-	Id            int     `json:"id"`
-	UserId        int     `json:"user_id" gorm:"index"`
-	Amount        int64   `json:"amount"`
-	Money         float64 `json:"money"`
-	TradeNo       string  `json:"trade_no" gorm:"unique;type:varchar(255);index"`
-	PaymentMethod string  `json:"payment_method" gorm:"type:varchar(50)"`
-	CreateTime    int64   `json:"create_time"`
-	CompleteTime  int64   `json:"complete_time"`
-	Status        string  `json:"status"`
+	Id              int     `json:"id"`
+	UserId          int     `json:"user_id" gorm:"index"`
+	Amount          int64   `json:"amount"`
+	Money           float64 `json:"money"`
+	TradeNo         string  `json:"trade_no" gorm:"unique;type:varchar(255);index"`
+	PaymentMethod   string  `json:"payment_method" gorm:"type:varchar(50)"`
+	PaymentProvider string  `json:"payment_provider" gorm:"type:varchar(50);default:''"`
+	CreateTime      int64   `json:"create_time"`
+	CompleteTime    int64   `json:"complete_time"`
+	Status          string  `json:"status"`
 }
 
 const (
@@ -30,6 +31,14 @@ const (
 	PaymentMethodWaffoPancake = "waffo_pancake"
 )
 
+const (
+	PaymentProviderEpay         = "epay"
+	PaymentProviderStripe       = "stripe"
+	PaymentProviderCreem        = "creem"
+	PaymentProviderWaffo        = "waffo"
+	PaymentProviderWaffoPancake = "waffo_pancake"
+)
+
 var (
 	ErrPaymentMethodMismatch = errors.New("payment method mismatch")
 	ErrTopUpNotFound         = errors.New("topup not found")
@@ -68,7 +77,7 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
 	return topUp
 }
 
-func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error {
+func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentProvider string, targetStatus string) error {
 	if tradeNo == "" {
 		return errors.New("未提供支付单号")
 	}
@@ -83,7 +92,7 @@ func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targ
 		if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil {
 			return ErrTopUpNotFound
 		}
-		if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod {
+		if expectedPaymentProvider != "" && topUp.PaymentProvider != expectedPaymentProvider {
 			return ErrPaymentMethodMismatch
 		}
 		if topUp.Status != common.TopUpStatusPending {
@@ -114,7 +123,7 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error
 			return errors.New("充值订单不存在")
 		}
 
-		if topUp.PaymentMethod != PaymentMethodStripe {
+		if topUp.PaymentProvider != PaymentProviderStripe {
 			return ErrPaymentMethodMismatch
 		}
 
@@ -340,7 +349,7 @@ func ManualCompleteTopUp(tradeNo string, callerIp string) error {
 		// 计算应充值额度:
 		// - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit
 		// - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit
-		if topUp.PaymentMethod == PaymentMethodStripe {
+		if topUp.PaymentProvider == PaymentProviderStripe {
 			dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
 			quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart())
 		} else {
@@ -397,7 +406,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
 			return errors.New("充值订单不存在")
 		}
 
-		if topUp.PaymentMethod != PaymentMethodCreem {
+		if topUp.PaymentProvider != PaymentProviderCreem {
 			return ErrPaymentMethodMismatch
 		}
 
@@ -472,7 +481,7 @@ func RechargeWaffo(tradeNo string, callerIp string) (err error) {
 			return errors.New("充值订单不存在")
 		}
 
-		if topUp.PaymentMethod != PaymentMethodWaffo {
+		if topUp.PaymentProvider != PaymentProviderWaffo {
 			return ErrPaymentMethodMismatch
 		}
 
@@ -535,7 +544,7 @@ func RechargeWaffoPancake(tradeNo string) (err error) {
 			return errors.New("充值订单不存在")
 		}
 
-		if topUp.PaymentMethod != PaymentMethodWaffoPancake {
+		if topUp.PaymentProvider != PaymentProviderWaffoPancake {
 			return ErrPaymentMethodMismatch
 		}