Pārlūkot izejas kodu

feat(task): introduce task timeout configuration and cleanup unfinished tasks

- Added TaskTimeoutMinutes constant to configure the timeout duration for asynchronous tasks.
- Implemented sweepTimedOutTasks function to identify and handle unfinished tasks that exceed the timeout limit, marking them as failed and processing refunds if applicable.
- Enhanced task polling loop to include the new timeout handling logic, ensuring timely cleanup of stale tasks.
CaIon 1 nedēļu atpakaļ
vecāks
revīzija
13ada6484a
4 mainītis faili ar 75 papildinājumiem un 0 dzēšanām
  1. 2 0
      common/init.go
  2. 1 0
      constant/env.go
  3. 19 0
      model/task.go
  4. 53 0
      service/task_polling.go

+ 2 - 0
common/init.go

@@ -145,6 +145,8 @@ func initConstantEnv() {
 	constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
 	// 任务轮询时查询的最大数量
 	constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
+	// 异步任务超时时间(分钟),超过此时间未完成的任务将被标记为失败并退款。0 表示禁用。
+	constant.TaskTimeoutMinutes = GetEnvOrDefault("TASK_TIMEOUT_MINUTES", 1440)
 
 	soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
 	if soraPatchStr != "" {

+ 1 - 0
constant/env.go

@@ -16,6 +16,7 @@ var NotificationLimitDurationMinute int
 var GenerateDefaultToken bool
 var ErrorLogEnabled bool
 var TaskQueryLimit int
+var TaskTimeoutMinutes int
 
 // temporary variable for sora patch, will be removed in future
 var TaskPricePatches []string

+ 19 - 0
model/task.go

@@ -288,6 +288,20 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*
 	return tasks
 }
 
+func GetTimedOutUnfinishedTasks(cutoffUnix int64, limit int) []*Task {
+	var tasks []*Task
+	err := DB.Where("progress != ?", "100%").
+		Where("status NOT IN ?", []string{TaskStatusFailure, TaskStatusSuccess}).
+		Where("submit_time < ?", cutoffUnix).
+		Order("submit_time").
+		Limit(limit).
+		Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+	return tasks
+}
+
 func GetAllUnFinishSyncTasks(limit int) []*Task {
 	var tasks []*Task
 	var err error
@@ -401,6 +415,11 @@ func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
 	return result.RowsAffected > 0, nil
 }
 
+// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs.
+// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite
+// any concurrent status changes. DO NOT use in billing/quota lifecycle flows
+// (e.g., timeout, success, failure transitions that trigger refunds or settlements).
+// For status transitions that involve billing, use Task.UpdateWithStatus() instead.
 func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
 	if len(ids) == 0 {
 		return nil

+ 53 - 0
service/task_polling.go

@@ -35,12 +35,65 @@ type TaskPollingAdaptor interface {
 // 打破 service -> relay -> relay/channel -> service 的循环依赖。
 var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
 
+// sweepTimedOutTasks 在主轮询之前独立清理超时任务。
+// 每次最多处理 100 条,剩余的下个周期继续处理。
+// 使用 per-task CAS (UpdateWithStatus) 防止覆盖被正常轮询已推进的任务。
+func sweepTimedOutTasks(ctx context.Context) {
+	if constant.TaskTimeoutMinutes <= 0 {
+		return
+	}
+	cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60
+	tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100)
+	if len(tasks) == 0 {
+		return
+	}
+
+	const legacyTaskCutoff int64 = 1740182400 // 2026-02-22 00:00:00 UTC
+	reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes)
+	legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)"
+	now := time.Now().Unix()
+	timedOutCount := 0
+
+	for _, task := range tasks {
+		isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff
+
+		oldStatus := task.Status
+		task.Status = model.TaskStatusFailure
+		task.Progress = "100%"
+		task.FinishTime = now
+		if isLegacy {
+			task.FailReason = legacyReason
+		} else {
+			task.FailReason = reason
+		}
+
+		won, err := task.UpdateWithStatus(oldStatus)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err))
+			continue
+		}
+		if !won {
+			logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID))
+			continue
+		}
+		timedOutCount++
+		if !isLegacy && task.Quota != 0 {
+			RefundTaskQuota(ctx, task, reason)
+		}
+	}
+
+	if timedOutCount > 0 {
+		logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount))
+	}
+}
+
 // TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
 func TaskPollingLoop() {
 	for {
 		time.Sleep(time.Duration(15) * time.Second)
 		common.SysLog("任务进度轮询开始")
 		ctx := context.TODO()
+		sweepTimedOutTasks(ctx)
 		allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
 		platformTask := make(map[constant.TaskPlatform][]*model.Task)
 		for _, t := range allTasks {