Преглед изворни кода

refactor(task): add CAS-guarded updates to prevent concurrent billing conflicts

Replace all bare task.Update() (DB.Save) calls with UpdateWithStatus(),
which adds a WHERE status = ? guard to prevent concurrent processes from
overwriting each other's state transitions.

Key changes:

model/task.go:
- Add taskSnapshot struct with Equal() method for change detection
- Add Snapshot() method to capture pre-update state
- Add UpdateWithStatus(fromStatus) using DB.Where().Save() for CAS
  semantics with full-struct save (no explicit field listing needed)

model/midjourney.go:
- Add UpdateWithStatus(fromStatus string) with same CAS pattern

service/task_polling.go (updateVideoSingleTask):
- Snapshot before processing upstream response; skip DB write if unchanged
- Terminal transitions (SUCCESS/FAILURE) use UpdateWithStatus CAS:
  billing/refund only executes if this process wins the transition
- Non-terminal updates also use UpdateWithStatus to prevent overwriting
  a concurrent terminal transition back to IN_PROGRESS
- Defer settleTaskBillingOnComplete to after CAS check (shouldSettle flag)

relay/relay_task.go (tryRealtimeFetch):
- Add snapshot + change detection; use UpdateWithStatus for CAS safety

controller/midjourney.go (UpdateMidjourneyTaskBulk):
- Capture preStatus before mutations; use UpdateWithStatus CAS
- Gate refund (IncreaseUserQuota) on CAS success (won && shouldReturnQuota)

This prevents the multi-instance race condition where:
1. Instance A reads task (IN_PROGRESS), fetches upstream (still IN_PROGRESS)
2. Instance B reads same task, fetches upstream (now SUCCESS), writes SUCCESS
3. Instance A's bare Save() overwrites SUCCESS back to IN_PROGRESS
CaIon пре 1 недеља
родитељ
комит
5ec4633cb8
5 измењених фајлова са 95 додато и 74 уклоњено
  1. 8 9
      controller/midjourney.go
  2. 11 0
      model/midjourney.go
  3. 41 50
      model/task.go
  4. 6 1
      relay/relay_task.go
  5. 29 14
      service/task_polling.go

+ 8 - 9
controller/midjourney.go

@@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() {
 				if !checkMjTaskNeedUpdate(task, responseItem) {
 					continue
 				}
+				preStatus := task.Status
 				task.Code = 1
 				task.Progress = responseItem.Progress
 				task.PromptEn = responseItem.PromptEn
@@ -172,18 +173,16 @@ func UpdateMidjourneyTaskBulk() {
 						shouldReturnQuota = true
 					}
 				}
-				err = task.Update()
+				won, err := task.UpdateWithStatus(preStatus)
 				if err != nil {
 					logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
-				} else {
-					if shouldReturnQuota {
-						err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
-						if err != nil {
-							logger.LogError(ctx, "fail to increase user quota: "+err.Error())
-						}
-						logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
-						model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+				} else if won && shouldReturnQuota {
+					err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
+					if err != nil {
+						logger.LogError(ctx, "fail to increase user quota: "+err.Error())
 					}
+					logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
+					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 				}
 			}
 		}

+ 11 - 0
model/midjourney.go

@@ -157,6 +157,17 @@ func (midjourney *Midjourney) Update() error {
 	return err
 }
 
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Returns (true, nil) if this caller won the update, (false, nil) if
+// another process already moved the task out of fromStatus.
+func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) {
+	result := DB.Where("status = ?", fromStatus).Save(midjourney)
+	if result.Error != nil {
+		return false, result.Error
+	}
+	return result.RowsAffected > 0, nil
+}
+
 func MjBulkUpdate(mjIds []string, params map[string]any) error {
 	return DB.Model(&Midjourney{}).
 		Where("mj_id in (?)", mjIds).

+ 41 - 50
model/task.go

@@ -1,6 +1,7 @@
 package model
 
 import (
+	"bytes"
 	"database/sql/driver"
 	"encoding/json"
 	"time"
@@ -340,38 +341,59 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
 	return task, nil
 }
 
-func TaskUpdateProgress(id int64, progress string) error {
-	return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
-}
-
 func (Task *Task) Insert() error {
 	var err error
 	err = DB.Create(Task).Error
 	return err
 }
 
+type taskSnapshot struct {
+	Status     TaskStatus
+	Progress   string
+	StartTime  int64
+	FinishTime int64
+	FailReason string
+	ResultURL  string
+	Data       json.RawMessage
+}
+
+func (s taskSnapshot) Equal(other taskSnapshot) bool {
+	return s.Status == other.Status &&
+		s.Progress == other.Progress &&
+		s.StartTime == other.StartTime &&
+		s.FinishTime == other.FinishTime &&
+		s.FailReason == other.FailReason &&
+		s.ResultURL == other.ResultURL &&
+		bytes.Equal(s.Data, other.Data)
+}
+
+func (t *Task) Snapshot() taskSnapshot {
+	return taskSnapshot{
+		Status:     t.Status,
+		Progress:   t.Progress,
+		StartTime:  t.StartTime,
+		FinishTime: t.FinishTime,
+		FailReason: t.FailReason,
+		ResultURL:  t.PrivateData.ResultURL,
+		Data:       t.Data,
+	}
+}
+
 func (Task *Task) Update() error {
 	var err error
 	err = DB.Save(Task).Error
 	return err
 }
 
-func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
-	if len(TaskIds) == 0 {
-		return nil
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Returns (true, nil) if this caller won the update, (false, nil) if
+// another process already moved the task out of fromStatus.
+func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
+	result := DB.Where("status = ?", fromStatus).Save(t)
+	if result.Error != nil {
+		return false, result.Error
 	}
-	return DB.Model(&Task{}).
-		Where("task_id in (?)", TaskIds).
-		Updates(params).Error
-}
-
-func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
-	if len(taskIDs) == 0 {
-		return nil
-	}
-	return DB.Model(&Task{}).
-		Where("id in (?)", taskIDs).
-		Updates(params).Error
+	return result.RowsAffected > 0, nil
 }
 
 func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
@@ -388,37 +410,6 @@ type TaskQuotaUsage struct {
 	Count float64 `json:"count"`
 }
 
-func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
-	query := DB.Model(Task{})
-	// 添加过滤条件
-	if queryParams.ChannelID != "" {
-		query = query.Where("channel_id = ?", queryParams.ChannelID)
-	}
-	if queryParams.UserID != "" {
-		query = query.Where("user_id = ?", queryParams.UserID)
-	}
-	if len(queryParams.UserIDs) != 0 {
-		query = query.Where("user_id in (?)", queryParams.UserIDs)
-	}
-	if queryParams.TaskID != "" {
-		query = query.Where("task_id = ?", queryParams.TaskID)
-	}
-	if queryParams.Action != "" {
-		query = query.Where("action = ?", queryParams.Action)
-	}
-	if queryParams.Status != "" {
-		query = query.Where("status = ?", queryParams.Status)
-	}
-	if queryParams.StartTimestamp != 0 {
-		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
-	}
-	if queryParams.EndTimestamp != 0 {
-		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
-	}
-	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
-	return stat, err
-}
-
 // TaskCountAllTasks returns total tasks that match the given query params (admin usage)
 func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
 	var total int64

+ 6 - 1
relay/relay_task.go

@@ -444,6 +444,8 @@ func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
 		return nil
 	}
 
+	snap := task.Snapshot()
+
 	// 将上游最新状态更新到 task
 	if ti.Status != "" {
 		task.Status = model.TaskStatus(ti.Status)
@@ -459,7 +461,10 @@ func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
 		// No URL from adaptor — construct proxy URL using public task ID
 		task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
 	}
-	_ = task.Update()
+
+	if !snap.Equal(task.Snapshot()) {
+		_, _ = task.UpdateWithStatus(snap.Status)
+	}
 
 	// OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
 	if isOpenAIVideoAPI {

+ 29 - 14
service/task_polling.go

@@ -319,6 +319,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *
 
 	logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
 
+	snap := task.Snapshot()
+
 	taskResult := &relaycommon.TaskInfo{}
 	// try parse as New API response format
 	var responseItems dto.TaskResponse[model.Task]
@@ -344,10 +346,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *
 		taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
 	}
 
-	// 记录原本的状态,防止重复退款
 	shouldRefund := false
+	shouldSettle := false
 	quota := task.Quota
-	preStatus := task.Status
 
 	task.Status = model.TaskStatus(taskResult.Status)
 	switch taskResult.Status {
@@ -374,9 +375,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *
 			// No URL from adaptor — construct proxy URL using public task ID
 			task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
 		}
-
-		// 完成时计费调整:优先由 adaptor 计算,回退到 token 重算
-		settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+		shouldSettle = true
 	case model.TaskStatusFailure:
 		logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
 		task.Status = model.TaskStatusFailure
@@ -388,23 +387,39 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *
 		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
 		taskResult.Progress = taskcommon.ProgressComplete
 		if quota != 0 {
-			if preStatus != model.TaskStatusFailure {
-				shouldRefund = true
-			} else {
-				logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
-			}
+			shouldRefund = true
 		}
 	default:
-		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
+		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID)
 	}
 	if taskResult.Progress != "" {
 		task.Progress = taskResult.Progress
 	}
-	if err := task.Update(); err != nil {
-		common.SysLog("UpdateVideoTask task error: " + err.Error())
-		shouldRefund = false
+
+	isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure
+	if isDone && snap.Status != task.Status {
+		won, err := task.UpdateWithStatus(snap.Status)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error()))
+			shouldRefund = false
+			shouldSettle = false
+		} else if !won {
+			logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID))
+			shouldRefund = false
+			shouldSettle = false
+		}
+	} else if !snap.Equal(task.Snapshot()) {
+		if _, err := task.UpdateWithStatus(snap.Status); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error()))
+		}
+	} else {
+		// No changes, skip update
+		logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID))
 	}
 
+	if shouldSettle {
+		settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+	}
 	if shouldRefund {
 		RefundTaskQuota(ctx, task, task.FailReason)
 	}