Просмотр исходного кода

feat: enhance AddAbilities and BatchInsertChannels to support transaction handling

CaIon 7 месяцев назад
Родитель
Сommit
962c40c1a7
2 измененных файлов с 48 добавлено и 26 удалено
  1. 8 3
      model/ability.go
  2. 40 23
      model/channel.go

+ 8 - 3
model/ability.go

@@ -142,7 +142,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
 	return &channel, err
 }
 
-func (channel *Channel) AddAbilities() error {
+func (channel *Channel) AddAbilities(tx *gorm.DB) error {
 	models_ := strings.Split(channel.Models, ",")
 	groups_ := strings.Split(channel.Group, ",")
 	abilitySet := make(map[string]struct{})
@@ -169,8 +169,13 @@ func (channel *Channel) AddAbilities() error {
 	if len(abilities) == 0 {
 		return nil
 	}
+	// choose DB or provided tx
+	useDB := DB
+	if tx != nil {
+		useDB = tx
+	}
 	for _, chunk := range lo.Chunk(abilities, 50) {
-		err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
+		err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
 		if err != nil {
 			return err
 		}
@@ -321,7 +326,7 @@ func FixAbility() (int, int, error) {
 		}
 		// Then add new abilities
 		for _, channel := range chunk {
-			err = channel.AddAbilities()
+			err = channel.AddAbilities(nil)
 			if err != nil {
 				common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
 				failCount++

+ 40 - 23
model/channel.go

@@ -13,6 +13,7 @@ import (
 	"strings"
 	"sync"
 
+	"github.com/samber/lo"
 	"gorm.io/gorm"
 )
 
@@ -337,38 +338,54 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
 }
 
 func BatchInsertChannels(channels []Channel) error {
-	var err error
-	err = DB.Create(&channels).Error
-	if err != nil {
-		return err
+	if len(channels) == 0 {
+		return nil
 	}
-	for _, channel_ := range channels {
-		err = channel_.AddAbilities()
-		if err != nil {
+	tx := DB.Begin()
+	if tx.Error != nil {
+		return tx.Error
+	}
+	defer func() {
+		if r := recover(); r != nil {
+			tx.Rollback()
+		}
+	}()
+
+	for _, chunk := range lo.Chunk(channels, 50) {
+		if err := tx.Create(&chunk).Error; err != nil {
+			tx.Rollback()
 			return err
 		}
+		for _, channel_ := range chunk {
+			if err := channel_.AddAbilities(tx); err != nil {
+				tx.Rollback()
+				return err
+			}
+		}
 	}
-	return nil
+	return tx.Commit().Error
 }
 
 func BatchDeleteChannels(ids []int) error {
-	//使用事务 删除channel表和channel_ability表
+	if len(ids) == 0 {
+		return nil
+	}
+	// 使用事务 分批删除channel表和abilities表
 	tx := DB.Begin()
-	err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
-	if err != nil {
-		// 回滚事务
-		tx.Rollback()
-		return err
+	if tx.Error != nil {
+		return tx.Error
 	}
-	err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
-	if err != nil {
-		// 回滚事务
-		tx.Rollback()
-		return err
+	for _, chunk := range lo.Chunk(ids, 200) {
+		if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
+			tx.Rollback()
+			return err
+		}
+		if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
+			tx.Rollback()
+			return err
+		}
 	}
-	// 提交事务
-	tx.Commit()
-	return err
+	return tx.Commit().Error
 }
 
 func (channel *Channel) GetPriority() int64 {
@@ -412,7 +429,7 @@ func (channel *Channel) Insert() error {
 	if err != nil {
 		return err
 	}
-	err = channel.AddAbilities()
+	err = channel.AddAbilities(nil)
 	return err
 }