Просмотр исходного кода

feat: add batch update support (close #414)

JustSong 2 лет назад
Родитель
Сommit
c3dc315e75
7 измененных файлов с 136 добавлено и 1 удалено
  1. 4 0
      README.md
  2. 3 0
      common/constants.go
  3. 5 0
      main.go
  4. 8 0
      model/channel.go
  5. 16 0
      model/token.go
  6. 25 1
      model/user.go
  7. 75 0
      model/utils.go

+ 4 - 0
README.md

@@ -306,6 +306,10 @@ graph LR
    + 例子:`CHANNEL_TEST_FREQUENCY=1440`
 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
    + 例子:`POLLING_INTERVAL=5`
+10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+    + 例子:`BATCH_UPDATE_ENABLED=true`
+11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+    + 例子:`BATCH_UPDATE_INTERVAL=5`
 
 ### 命令行参数
 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

+ 3 - 0
common/constants.go

@@ -94,6 +94,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
 
 var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 
+var BatchUpdateEnabled = false
+var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
+
 const (
 	RoleGuestUser  = 0
 	RoleCommonUser = 1

+ 5 - 0
main.go

@@ -77,6 +77,11 @@ func main() {
 		}
 		go controller.AutomaticallyTestChannels(frequency)
 	}
+	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
+		common.BatchUpdateEnabled = true
+		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+		model.InitBatchUpdater()
+	}
 	controller.InitTokenEncoders()
 
 	// Initialize HTTP server

+ 8 - 0
model/channel.go

@@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) {
 }
 
 func UpdateChannelUsedQuota(id int, quota int) {
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
+		return
+	}
+	updateChannelUsedQuota(id, quota)
+}
+
+func updateChannelUsedQuota(id int, quota int) {
 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 	if err != nil {
 		common.SysError("failed to update channel used quota: " + err.Error())

+ 16 - 0
model/token.go

@@ -131,6 +131,14 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
+		return nil
+	}
+	return increaseTokenQuota(id, quota)
+}
+
+func increaseTokenQuota(id int, quota int) (err error) {
 	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 		map[string]interface{}{
 			"remain_quota": gorm.Expr("remain_quota + ?", quota),
@@ -144,6 +152,14 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
+		return nil
+	}
+	return decreaseTokenQuota(id, quota)
+}
+
+func decreaseTokenQuota(id int, quota int) (err error) {
 	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 		map[string]interface{}{
 			"remain_quota": gorm.Expr("remain_quota - ?", quota),

+ 25 - 1
model/user.go

@@ -275,6 +275,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
+		return nil
+	}
+	return increaseUserQuota(id, quota)
+}
+
+func increaseUserQuota(id int, quota int) (err error) {
 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 	return err
 }
@@ -283,6 +291,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
+		return nil
+	}
+	return decreaseUserQuota(id, quota)
+}
+
+func decreaseUserQuota(id int, quota int) (err error) {
 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 	return err
 }
@@ -293,10 +309,18 @@ func GetRootUserEmail() (email string) {
 }
 
 func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
+		return
+	}
+	updateUserUsedQuotaAndRequestCount(id, quota, 1)
+}
+
+func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 		map[string]interface{}{
 			"used_quota":    gorm.Expr("used_quota + ?", quota),
-			"request_count": gorm.Expr("request_count + ?", 1),
+			"request_count": gorm.Expr("request_count + ?", count),
 		},
 	).Error
 	if err != nil {

+ 75 - 0
model/utils.go

@@ -0,0 +1,75 @@
+package model
+
+import (
+	"one-api/common"
+	"sync"
+	"time"
+)
+
+const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
+
+const (
+	BatchUpdateTypeUserQuota = iota
+	BatchUpdateTypeTokenQuota
+	BatchUpdateTypeUsedQuotaAndRequestCount
+	BatchUpdateTypeChannelUsedQuota
+)
+
+var batchUpdateStores []map[int]int
+var batchUpdateLocks []sync.Mutex
+
+func init() {
+	for i := 0; i < BatchUpdateTypeCount; i++ {
+		batchUpdateStores = append(batchUpdateStores, make(map[int]int))
+		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
+	}
+}
+
+func InitBatchUpdater() {
+	go func() {
+		for {
+			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
+			batchUpdate()
+		}
+	}()
+}
+
+func addNewRecord(type_ int, id int, value int) {
+	batchUpdateLocks[type_].Lock()
+	defer batchUpdateLocks[type_].Unlock()
+	if _, ok := batchUpdateStores[type_][id]; !ok {
+		batchUpdateStores[type_][id] = value
+	} else {
+		batchUpdateStores[type_][id] += value
+	}
+}
+
+func batchUpdate() {
+	common.SysLog("batch update started")
+	for i := 0; i < BatchUpdateTypeCount; i++ {
+		batchUpdateLocks[i].Lock()
+		store := batchUpdateStores[i]
+		batchUpdateStores[i] = make(map[int]int)
+		batchUpdateLocks[i].Unlock()
+
+		for key, value := range store {
+			switch i {
+			case BatchUpdateTypeUserQuota:
+				err := increaseUserQuota(key, value)
+				if err != nil {
+					common.SysError("failed to batch update user quota: " + err.Error())
+				}
+			case BatchUpdateTypeTokenQuota:
+				err := increaseTokenQuota(key, value)
+				if err != nil {
+					common.SysError("failed to batch update token quota: " + err.Error())
+				}
+			case BatchUpdateTypeUsedQuotaAndRequestCount:
+				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
+			case BatchUpdateTypeChannelUsedQuota:
+				updateChannelUsedQuota(key, value)
+			}
+		}
+	}
+	common.SysLog("batch update finished")
+}