Explorar el Código

refactor: improve error handling and database transactions in 2FA model methods

Seefs hace 7 meses
padre
commit
398ae7156b
Se han modificado 2 ficheros con 31 adiciones y 26 borrados
  1. 2 2
      controller/twofa.go
  2. 29 24
      model/twofa.go

+ 2 - 2
controller/twofa.go

@@ -1,12 +1,12 @@
 package controller
 
 import (
+	"errors"
 	"fmt"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strconv"
-	"strings"
 
 	"github.com/gin-contrib/sessions"
 	"github.com/gin-gonic/gin"
@@ -530,7 +530,7 @@ func AdminDisable2FA(c *gin.Context) {
 
 	// 禁用2FA
 	if err := model.DisableTwoFA(userId); err != nil {
-		if strings.Contains(err.Error(), "未启用2FA") {
+		if errors.Is(err, model.ErrTwoFANotEnabled) {
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,
 				"message": "用户未启用2FA",

+ 29 - 24
model/twofa.go

@@ -100,13 +100,16 @@ func (t *TwoFA) Delete() error {
 		return errors.New("2FA记录ID不能为空")
 	}
 
-	// 同时删除相关的备用码记录(硬删除)
-	if err := DB.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
-		return err
-	}
+	// 使用事务确保原子性
+	return DB.Transaction(func(tx *gorm.DB) error {
+		// 同时删除相关的备用码记录(硬删除)
+		if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
+			return err
+		}
 
-	// 硬删除2FA记录
-	return DB.Unscoped().Delete(t).Error
+		// 硬删除2FA记录
+		return tx.Unscoped().Delete(t).Error
+	})
 }
 
 // ResetFailedAttempts 重置失败尝试次数
@@ -139,30 +142,32 @@ func (t *TwoFA) IsLocked() bool {
 
 // CreateBackupCodes 创建备用码
 func CreateBackupCodes(userId int, codes []string) error {
-	// 先删除现有的备用码
-	if err := DB.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
-		return err
-	}
-
-	// 创建新的备用码记录
-	for _, code := range codes {
-		hashedCode, err := common.HashBackupCode(code)
-		if err != nil {
+	return DB.Transaction(func(tx *gorm.DB) error {
+		// 先删除现有的备用码
+		if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
 			return err
 		}
 
-		backupCode := TwoFABackupCode{
-			UserId:   userId,
-			CodeHash: hashedCode,
-			IsUsed:   false,
-		}
+		// 创建新的备用码记录
+		for _, code := range codes {
+			hashedCode, err := common.HashBackupCode(code)
+			if err != nil {
+				return err
+			}
 
-		if err := DB.Create(&backupCode).Error; err != nil {
-			return err
+			backupCode := TwoFABackupCode{
+				UserId:   userId,
+				CodeHash: hashedCode,
+				IsUsed:   false,
+			}
+
+			if err := tx.Create(&backupCode).Error; err != nil {
+				return err
+			}
 		}
-	}
 
-	return nil
+		return nil
+	})
 }
 
 // ValidateBackupCode 验证并使用备用码