Pārlūkot izejas kodu

refactor(task): extract billing and polling logic from controller to service layer

Restructure the task relay system for better separation of concerns:
- Extract task billing into service/task_billing.go with unified settlement flow
- Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms)
- Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry)
- Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go
- Add taskcommon/helpers.go for shared task adaptor utilities
- Remove controller/task_video.go (logic consolidated into service layer)
- Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu)
- Simplify frontend task logs to use new TaskDto response format
CaIon 3 nedēļas atpakaļ
vecāks
revīzija
9e3954428d

+ 98 - 22
controller/relay.go

@@ -451,17 +451,102 @@ func RelayNotFound(c *gin.Context) {
 }
 
 func RelayTask(c *gin.Context) {
-	retryTimes := common.RetryTimes
 	channelId := c.GetInt("channel_id")
 	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
 	if err != nil {
+		c.JSON(http.StatusInternalServerError, &dto.TaskError{
+			Code:       "gen_relay_info_failed",
+			Message:    err.Error(),
+			StatusCode: http.StatusInternalServerError,
+		})
+		return
+	}
+
+	// Fetch 操作是纯 DB 查询(或 task 自带 channelId 的上游查询),不依赖上下文 channel,无需重试
+	// TODO: 在video-route层面优化,避免无谓的 channel 选择和上下文设置,也没必要吧代码放到这里来写这么多屎山
+	switch relayInfo.RelayMode {
+	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
+		if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil {
+			respondTaskError(c, taskErr)
+		}
+		return
+	}
+
+	// ── Submit 路径 ─────────────────────────────────────────────────
+
+	// 1. 解析原始任务(remix / continuation),一次性,可能锁定渠道并禁止重试
+	if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil {
+		respondTaskError(c, taskErr)
 		return
 	}
-	taskErr := taskRelayHandler(c, relayInfo)
+
+	// 2. defer Refund(全部失败时回滚预扣费)
+	var result *relay.TaskSubmitResult
+	var taskErr *dto.TaskError
+	defer func() {
+		if taskErr != nil && relayInfo.Billing != nil {
+			relayInfo.Billing.Refund(c)
+		}
+	}()
+
+	// 3. 执行 + 重试(RelayTaskSubmit 内部在首次调用时自动预扣费)
+	taskErr = taskSubmitWithRetry(c, relayInfo, channelId, common.RetryTimes, func() *dto.TaskError {
+		var te *dto.TaskError
+		result, te = relay.RelayTaskSubmit(c, relayInfo)
+		return te
+	})
+
+	// 4. 成功:结算 + 日志 + 插入任务
+	if taskErr == nil {
+		if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
+			common.SysError("settle task billing error: " + settleErr.Error())
+		}
+		service.LogTaskConsumption(c, relayInfo, result.ModelName)
+
+		task := model.InitTask(result.Platform, relayInfo)
+		task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
+		task.PrivateData.BillingSource = relayInfo.BillingSource
+		task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
+		task.PrivateData.TokenId = relayInfo.TokenId
+		task.Quota = result.Quota
+		task.Data = result.TaskData
+		task.Action = relayInfo.Action
+		if insertErr := task.Insert(); insertErr != nil {
+			//taskErr = service.TaskErrorWrapper(insertErr, "insert_task_failed", http.StatusInternalServerError)
+			common.SysError("insert task error: " + insertErr.Error())
+		}
+	}
+
+	if taskErr != nil {
+		respondTaskError(c, taskErr)
+	}
+}
+
+// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写)
+func respondTaskError(c *gin.Context, taskErr *dto.TaskError) {
+	if taskErr.StatusCode == http.StatusTooManyRequests {
+		taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
+	}
+	c.JSON(taskErr.StatusCode, taskErr)
+}
+
+// taskSubmitWithRetry 执行首次尝试并在失败时切换渠道重试,返回最终的 taskErr。
+// attempt 闭包负责实际的上游请求,不涉及计费。
+func taskSubmitWithRetry(c *gin.Context, relayInfo *relaycommon.RelayInfo,
+	channelId int, retryTimes int, attempt func() *dto.TaskError) *dto.TaskError {
+
+	taskErr := attempt()
 	if taskErr == nil {
-		retryTimes = 0
+		return nil
 	}
+	if !taskErr.LocalError {
+		processChannelError(c,
+			*types.NewChannelError(channelId, c.GetInt("channel_type"), c.GetString("channel_name"), common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
+				common.GetContextKeyString(c, constant.ContextKeyChannelKey), common.GetContextKeyBool(c, constant.ContextKeyChannelAutoBan)),
+			types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
+	}
+
 	retryParam := &service.RetryParam{
 		Ctx:        c,
 		TokenGroup: relayInfo.TokenGroup,
@@ -480,7 +565,7 @@ func RelayTask(c *gin.Context) {
 		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 		c.Set("use_channel", useChannel)
 		logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
-		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+		middleware.SetupContextForSelectedChannel(c, channel, c.GetString("original_model"))
 
 		bodyStorage, err := common.GetBodyStorage(c)
 		if err != nil {
@@ -492,30 +577,21 @@ func RelayTask(c *gin.Context) {
 			break
 		}
 		c.Request.Body = io.NopCloser(bodyStorage)
-		taskErr = taskRelayHandler(c, relayInfo)
+		taskErr = attempt()
+		if taskErr != nil && !taskErr.LocalError {
+			processChannelError(c,
+				*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
+					common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
+				types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
+		}
 	}
+
 	useChannel := c.GetStringSlice("use_channel")
 	if len(useChannel) > 1 {
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
 		logger.LogInfo(c, retryLogStr)
 	}
-	if taskErr != nil {
-		if taskErr.StatusCode == http.StatusTooManyRequests {
-			taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
-		}
-		c.JSON(taskErr.StatusCode, taskErr)
-	}
-}
-
-func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
-	var err *dto.TaskError
-	switch relayInfo.RelayMode {
-	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
-		err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
-	default:
-		err = relay.RelayTaskSubmit(c, relayInfo)
-	}
-	return err
+	return taskErr
 }
 
 func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {

+ 13 - 215
controller/task.go

@@ -1,231 +1,21 @@
 package controller
 
 import (
-	"context"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"io"
-	"net/http"
-	"sort"
 	"strconv"
-	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
-	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay"
+	"github.com/QuantumNous/new-api/service"
 
 	"github.com/gin-gonic/gin"
-	"github.com/samber/lo"
 )
 
+// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层
 func UpdateTaskBulk() {
-	//revocer
-	//imageModel := "midjourney"
-	for {
-		time.Sleep(time.Duration(15) * time.Second)
-		common.SysLog("任务进度轮询开始")
-		ctx := context.TODO()
-		allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
-		platformTask := make(map[constant.TaskPlatform][]*model.Task)
-		for _, t := range allTasks {
-			platformTask[t.Platform] = append(platformTask[t.Platform], t)
-		}
-		for platform, tasks := range platformTask {
-			if len(tasks) == 0 {
-				continue
-			}
-			taskChannelM := make(map[int][]string)
-			taskM := make(map[string]*model.Task)
-			nullTaskIds := make([]int64, 0)
-			for _, task := range tasks {
-				if task.TaskID == "" {
-					// 统计失败的未完成任务
-					nullTaskIds = append(nullTaskIds, task.ID)
-					continue
-				}
-				taskM[task.TaskID] = task
-				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
-			}
-			if len(nullTaskIds) > 0 {
-				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
-					"status":   "FAILURE",
-					"progress": "100%",
-				})
-				if err != nil {
-					logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
-				} else {
-					logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
-				}
-			}
-			if len(taskChannelM) == 0 {
-				continue
-			}
-
-			UpdateTaskByPlatform(platform, taskChannelM, taskM)
-		}
-		common.SysLog("任务进度轮询完成")
-	}
-}
-
-func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
-	switch platform {
-	case constant.TaskPlatformMidjourney:
-		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
-	case constant.TaskPlatformSuno:
-		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
-	default:
-		if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
-			common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
-		}
-	}
-}
-
-func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
-	for channelId, taskIds := range taskChannelM {
-		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
-		if err != nil {
-			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
-	logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
-	if len(taskIds) == 0 {
-		return nil
-	}
-	channel, err := model.CacheGetChannel(channelId)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
-		err = model.TaskBulkUpdate(taskIds, map[string]any{
-			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
-			"status":      "FAILURE",
-			"progress":    "100%",
-		})
-		if err != nil {
-			common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
-		}
-		return err
-	}
-	adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
-	if adaptor == nil {
-		return errors.New("adaptor not found")
-	}
-	proxy := channel.GetSetting().Proxy
-	resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
-		"ids": taskIds,
-	}, proxy)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
-		return err
-	}
-	if resp.StatusCode != http.StatusOK {
-		logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
-		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
-	}
-	defer resp.Body.Close()
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
-		return err
-	}
-	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
-	err = json.Unmarshal(responseBody, &responseItems)
-	if err != nil {
-		logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
-		return err
-	}
-	if !responseItems.IsSuccess() {
-		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
-		return err
-	}
-
-	for _, responseItem := range responseItems.Data {
-		task := taskM[responseItem.TaskID]
-		if !checkTaskNeedUpdate(task, responseItem) {
-			continue
-		}
-
-		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
-		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
-		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
-		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
-		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
-		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
-			logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
-			task.Progress = "100%"
-			//err = model.CacheUpdateUserQuota(task.UserId) ?
-			if err != nil {
-				logger.LogError(ctx, "error update user quota cache: "+err.Error())
-			} else {
-				quota := task.Quota
-				if quota != 0 {
-					err = model.IncreaseUserQuota(task.UserId, quota, false)
-					if err != nil {
-						logger.LogError(ctx, "fail to increase user quota: "+err.Error())
-					}
-					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
-					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-				}
-			}
-		}
-		if responseItem.Status == model.TaskStatusSuccess {
-			task.Progress = "100%"
-		}
-		task.Data = responseItem.Data
-
-		err = task.Update()
-		if err != nil {
-			common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
-		}
-	}
-	return nil
-}
-
-func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
-
-	if oldTask.SubmitTime != newTask.SubmitTime {
-		return true
-	}
-	if oldTask.StartTime != newTask.StartTime {
-		return true
-	}
-	if oldTask.FinishTime != newTask.FinishTime {
-		return true
-	}
-	if string(oldTask.Status) != newTask.Status {
-		return true
-	}
-	if oldTask.FailReason != newTask.FailReason {
-		return true
-	}
-	if oldTask.FinishTime != newTask.FinishTime {
-		return true
-	}
-
-	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
-		return true
-	}
-
-	oldData, _ := json.Marshal(oldTask.Data)
-	newData, _ := json.Marshal(newTask.Data)
-
-	sort.Slice(oldData, func(i, j int) bool {
-		return oldData[i] < oldData[j]
-	})
-	sort.Slice(newData, func(i, j int) bool {
-		return newData[i] < newData[j]
-	})
-
-	if string(oldData) != string(newData) {
-		return true
-	}
-	return false
+	service.TaskPollingLoop()
 }
 
 func GetAllTask(c *gin.Context) {
@@ -247,7 +37,7 @@ func GetAllTask(c *gin.Context) {
 	items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
 	total := model.TaskCountAllTasks(queryParams)
 	pageInfo.SetTotal(int(total))
-	pageInfo.SetItems(items)
+	pageInfo.SetItems(tasksToDto(items))
 	common.ApiSuccess(c, pageInfo)
 }
 
@@ -271,6 +61,14 @@ func GetUserTask(c *gin.Context) {
 	items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
 	total := model.TaskCountAllUserTask(userId, queryParams)
 	pageInfo.SetTotal(int(total))
-	pageInfo.SetItems(items)
+	pageInfo.SetItems(tasksToDto(items))
 	common.ApiSuccess(c, pageInfo)
 }
+
+func tasksToDto(tasks []*model.Task) []*dto.TaskDto {
+	result := make([]*dto.TaskDto, len(tasks))
+	for i, task := range tasks {
+		result[i] = relay.TaskModel2Dto(task)
+	}
+	return result
+}

+ 0 - 313
controller/task_video.go

@@ -1,313 +0,0 @@
-package controller
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"io"
-	"time"
-
-	"github.com/QuantumNous/new-api/common"
-	"github.com/QuantumNous/new-api/constant"
-	"github.com/QuantumNous/new-api/dto"
-	"github.com/QuantumNous/new-api/logger"
-	"github.com/QuantumNous/new-api/model"
-	"github.com/QuantumNous/new-api/relay"
-	"github.com/QuantumNous/new-api/relay/channel"
-	relaycommon "github.com/QuantumNous/new-api/relay/common"
-	"github.com/QuantumNous/new-api/setting/ratio_setting"
-)
-
-func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
-	for channelId, taskIds := range taskChannelM {
-		if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
-			logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
-	logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
-	if len(taskIds) == 0 {
-		return nil
-	}
-	cacheGetChannel, err := model.CacheGetChannel(channelId)
-	if err != nil {
-		errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
-			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
-			"status":      "FAILURE",
-			"progress":    "100%",
-		})
-		if errUpdate != nil {
-			common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
-		}
-		return fmt.Errorf("CacheGetChannel failed: %w", err)
-	}
-	adaptor := relay.GetTaskAdaptor(platform)
-	if adaptor == nil {
-		return fmt.Errorf("video adaptor not found")
-	}
-	info := &relaycommon.RelayInfo{}
-	info.ChannelMeta = &relaycommon.ChannelMeta{
-		ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
-	}
-	info.ApiKey = cacheGetChannel.Key
-	adaptor.Init(info)
-	for _, taskId := range taskIds {
-		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
-			logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
-	baseURL := constant.ChannelBaseURLs[channel.Type]
-	if channel.GetBaseURL() != "" {
-		baseURL = channel.GetBaseURL()
-	}
-	proxy := channel.GetSetting().Proxy
-
-	task := taskM[taskId]
-	if task == nil {
-		logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
-		return fmt.Errorf("task %s not found", taskId)
-	}
-	key := channel.Key
-
-	privateData := task.PrivateData
-	if privateData.Key != "" {
-		key = privateData.Key
-	}
-	resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
-		"task_id": taskId,
-		"action":  task.Action,
-	}, proxy)
-	if err != nil {
-		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
-	}
-	//if resp.StatusCode != http.StatusOK {
-	//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
-	//}
-	defer resp.Body.Close()
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
-	}
-
-	logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
-
-	taskResult := &relaycommon.TaskInfo{}
-	// try parse as New API response format
-	var responseItems dto.TaskResponse[model.Task]
-	if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
-		logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
-		t := responseItems.Data
-		taskResult.TaskID = t.TaskID
-		taskResult.Status = string(t.Status)
-		taskResult.Url = t.FailReason
-		taskResult.Progress = t.Progress
-		taskResult.Reason = t.FailReason
-		task.Data = t.Data
-	} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
-		return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
-	} else {
-		task.Data = redactVideoResponseBody(responseBody)
-	}
-
-	logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
-
-	now := time.Now().Unix()
-	if taskResult.Status == "" {
-		//return fmt.Errorf("task %s status is empty", taskId)
-		taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
-	}
-
-	// 记录原本的状态,防止重复退款
-	shouldRefund := false
-	quota := task.Quota
-	preStatus := task.Status
-
-	task.Status = model.TaskStatus(taskResult.Status)
-	switch taskResult.Status {
-	case model.TaskStatusSubmitted:
-		task.Progress = "10%"
-	case model.TaskStatusQueued:
-		task.Progress = "20%"
-	case model.TaskStatusInProgress:
-		task.Progress = "30%"
-		if task.StartTime == 0 {
-			task.StartTime = now
-		}
-	case model.TaskStatusSuccess:
-		task.Progress = "100%"
-		if task.FinishTime == 0 {
-			task.FinishTime = now
-		}
-		if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
-			task.FailReason = taskResult.Url
-		}
-
-		// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
-		if taskResult.TotalTokens > 0 {
-			// 获取模型名称
-			var taskData map[string]interface{}
-			if err := json.Unmarshal(task.Data, &taskData); err == nil {
-				if modelName, ok := taskData["model"].(string); ok && modelName != "" {
-					// 获取模型价格和倍率
-					modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
-					// 只有配置了倍率(非固定价格)时才按 token 重新计费
-					if hasRatioSetting && modelRatio > 0 {
-						// 获取用户和组的倍率信息
-						group := task.Group
-						if group == "" {
-							user, err := model.GetUserById(task.UserId, false)
-							if err == nil {
-								group = user.Group
-							}
-						}
-						if group != "" {
-							groupRatio := ratio_setting.GetGroupRatio(group)
-							userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
-
-							var finalGroupRatio float64
-							if hasUserGroupRatio {
-								finalGroupRatio = userGroupRatio
-							} else {
-								finalGroupRatio = groupRatio
-							}
-
-							// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
-							actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
-
-							// 计算差额
-							preConsumedQuota := task.Quota
-							quotaDelta := actualQuota - preConsumedQuota
-
-							if quotaDelta > 0 {
-								// 需要补扣费
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
-									task.TaskID,
-									logger.LogQuota(quotaDelta),
-									logger.LogQuota(actualQuota),
-									logger.LogQuota(preConsumedQuota),
-									taskResult.TotalTokens,
-								))
-								if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
-									logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
-								} else {
-									model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
-									model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
-									task.Quota = actualQuota // 更新任务记录的实际扣费额度
-
-									// 记录消费日志
-									logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
-										modelRatio, finalGroupRatio, taskResult.TotalTokens,
-										logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
-									model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-								}
-							} else if quotaDelta < 0 {
-								// 需要退还多扣的费用
-								refundQuota := -quotaDelta
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
-									task.TaskID,
-									logger.LogQuota(refundQuota),
-									logger.LogQuota(actualQuota),
-									logger.LogQuota(preConsumedQuota),
-									taskResult.TotalTokens,
-								))
-								if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
-									logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
-								} else {
-									task.Quota = actualQuota // 更新任务记录的实际扣费额度
-
-									// 记录退款日志
-									logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
-										modelRatio, finalGroupRatio, taskResult.TotalTokens,
-										logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
-									model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-								}
-							} else {
-								// quotaDelta == 0, 预扣费刚好准确
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
-									task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
-							}
-						}
-					}
-				}
-			}
-		}
-	case model.TaskStatusFailure:
-		logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
-		task.Status = model.TaskStatusFailure
-		task.Progress = "100%"
-		if task.FinishTime == 0 {
-			task.FinishTime = now
-		}
-		task.FailReason = taskResult.Reason
-		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
-		taskResult.Progress = "100%"
-		if quota != 0 {
-			if preStatus != model.TaskStatusFailure {
-				shouldRefund = true
-			} else {
-				logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
-			}
-		}
-	default:
-		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
-	}
-	if taskResult.Progress != "" {
-		task.Progress = taskResult.Progress
-	}
-	if err := task.Update(); err != nil {
-		common.SysLog("UpdateVideoTask task error: " + err.Error())
-		shouldRefund = false
-	}
-
-	if shouldRefund {
-		// 任务失败且之前状态不是失败才退还额度,防止重复退还
-		if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
-			logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
-		}
-		logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
-		model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-	}
-
-	return nil
-}
-
-func redactVideoResponseBody(body []byte) []byte {
-	var m map[string]any
-	if err := json.Unmarshal(body, &m); err != nil {
-		return body
-	}
-	resp, _ := m["response"].(map[string]any)
-	if resp != nil {
-		delete(resp, "bytesBase64Encoded")
-		if v, ok := resp["video"].(string); ok {
-			resp["video"] = truncateBase64(v)
-		}
-		if vs, ok := resp["videos"].([]any); ok {
-			for i := range vs {
-				if vm, ok := vs[i].(map[string]any); ok {
-					delete(vm, "bytesBase64Encoded")
-				}
-			}
-		}
-	}
-	b, err := json.Marshal(m)
-	if err != nil {
-		return body
-	}
-	return b
-}
-
-func truncateBase64(s string) string {
-	const maxKeep = 256
-	if len(s) <= maxKeep {
-		return s
-	}
-	return s[:maxKeep] + "..."
-}

+ 30 - 81
controller/video_proxy.go

@@ -16,59 +16,44 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+// videoProxyError returns a standardized OpenAI-style error response.
+func videoProxyError(c *gin.Context, status int, errType, message string) {
+	c.JSON(status, gin.H{
+		"error": gin.H{
+			"message": message,
+			"type":    errType,
+		},
+	})
+}
+
 func VideoProxy(c *gin.Context) {
 	taskID := c.Param("task_id")
 	if taskID == "" {
-		c.JSON(http.StatusBadRequest, gin.H{
-			"error": gin.H{
-				"message": "task_id is required",
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required")
 		return
 	}
 
 	task, exists, err := model.GetByOnlyTaskId(taskID)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to query task",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
 		return
 	}
 	if !exists || task == nil {
-		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
-		c.JSON(http.StatusNotFound, gin.H{
-			"error": gin.H{
-				"message": "Task not found",
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found")
 		return
 	}
 
 	if task.Status != model.TaskStatusSuccess {
-		c.JSON(http.StatusBadRequest, gin.H{
-			"error": gin.H{
-				"message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadRequest, "invalid_request_error",
+			fmt.Sprintf("Task is not completed yet, current status: %s", task.Status))
 		return
 	}
 
 	channel, err := model.CacheGetChannel(task.ChannelId)
 	if err != nil {
-		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to retrieve channel information",
-				"type":    "server_error",
-			},
-		})
+		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error()))
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information")
 		return
 	}
 	baseURL := channel.GetBaseURL()
@@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) {
 	client, err := service.GetHttpClientWithProxy(proxy)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy client",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client")
 		return
 	}
 
@@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) {
 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy request",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
 		return
 	}
 
@@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) {
 		apiKey := task.PrivateData.Key
 		if apiKey == "" {
 			logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
-			c.JSON(http.StatusInternalServerError, gin.H{
-				"error": gin.H{
-					"message": "API key not stored for task",
-					"type":    "server_error",
-				},
-			})
+			videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task")
 			return
 		}
-
 		videoURL, err = getGeminiVideoURL(channel, task, apiKey)
 		if err != nil {
 			logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
-			c.JSON(http.StatusBadGateway, gin.H{
-				"error": gin.H{
-					"message": "Failed to resolve Gemini video URL",
-					"type":    "server_error",
-				},
-			})
+			videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL")
 			return
 		}
 		req.Header.Set("x-goog-api-key", apiKey)
 	case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
-		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
+		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
 		req.Header.Set("Authorization", "Bearer "+channel.Key)
 	default:
-		// Video URL is directly in task.FailReason
-		videoURL = task.FailReason
+		// Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data)
+		videoURL = task.GetResultURL()
 	}
 
 	req.URL, err = url.Parse(videoURL)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy request",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
 		return
 	}
 
 	resp, err := client.Do(req)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
-		c.JSON(http.StatusBadGateway, gin.H{
-			"error": gin.H{
-				"message": "Failed to fetch video content",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
 		return
 	}
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
-		c.JSON(http.StatusBadGateway, gin.H{
-			"error": gin.H{
-				"message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadGateway, "server_error",
+			fmt.Sprintf("Upstream service returned status %d", resp.StatusCode))
 		return
 	}
 
@@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) {
 		}
 	}
 
-	c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
+	c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
+	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
 	}
 }

+ 4 - 4
controller/video_proxy_gemini.go

@@ -1,12 +1,12 @@
 package controller
 
 import (
-	"encoding/json"
 	"fmt"
 	"io"
 	"strconv"
 	"strings"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay"
@@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
 
 	proxy := channel.GetSetting().Proxy
 	resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
-		"task_id": task.TaskID,
+		"task_id": task.GetUpstreamTaskID(),
 		"action":  task.Action,
 	}, proxy)
 	if err != nil {
@@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
 		return ""
 	}
 	var payload map[string]any
-	if err := json.Unmarshal(task.Data, &payload); err != nil {
+	if err := common.Unmarshal(task.Data, &payload); err != nil {
 		return ""
 	}
 	return extractGeminiVideoURLFromMap(payload)
@@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
 
 func extractGeminiVideoURLFromPayload(body []byte) string {
 	var payload map[string]any
-	if err := json.Unmarshal(body, &payload); err != nil {
+	if err := common.Unmarshal(body, &payload); err != nil {
 		return ""
 	}
 	return extractGeminiVideoURLFromMap(payload)

+ 0 - 32
dto/suno.go

@@ -4,10 +4,6 @@ import (
 	"encoding/json"
 )
 
-type TaskData interface {
-	SunoDataResponse | []SunoDataResponse | string | any
-}
-
 type SunoSubmitReq struct {
 	GptDescriptionPrompt string  `json:"gpt_description_prompt,omitempty"`
 	Prompt               string  `json:"prompt,omitempty"`
@@ -20,10 +16,6 @@ type SunoSubmitReq struct {
 	MakeInstrumental     bool    `json:"make_instrumental"`
 }
 
-type FetchReq struct {
-	IDs []string `json:"ids"`
-}
-
 type SunoDataResponse struct {
 	TaskID     string          `json:"task_id" gorm:"type:varchar(50);index"`
 	Action     string          `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
@@ -66,30 +58,6 @@ type SunoLyrics struct {
 	Text   string `json:"text"`
 }
 
-const TaskSuccessCode = "success"
-
-type TaskResponse[T TaskData] struct {
-	Code    string `json:"code"`
-	Message string `json:"message"`
-	Data    T      `json:"data"`
-}
-
-func (t *TaskResponse[T]) IsSuccess() bool {
-	return t.Code == TaskSuccessCode
-}
-
-type TaskDto struct {
-	TaskID     string          `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
-	Action     string          `json:"action"`  // 任务类型, song, lyrics, description-mode
-	Status     string          `json:"status"`  // 任务状态, submitted, queueing, processing, success, failed
-	FailReason string          `json:"fail_reason"`
-	SubmitTime int64           `json:"submit_time"`
-	StartTime  int64           `json:"start_time"`
-	FinishTime int64           `json:"finish_time"`
-	Progress   string          `json:"progress"`
-	Data       json.RawMessage `json:"data"`
-}
-
 type SunoGoAPISubmitReq struct {
 	CustomMode bool `json:"custom_mode"`
 

+ 47 - 0
dto/task.go

@@ -1,5 +1,9 @@
 package dto
 
+import (
+	"encoding/json"
+)
+
 type TaskError struct {
 	Code       string `json:"code"`
 	Message    string `json:"message"`
@@ -8,3 +12,46 @@ type TaskError struct {
 	LocalError bool   `json:"-"`
 	Error      error  `json:"-"`
 }
+
+type TaskData interface {
+	SunoDataResponse | []SunoDataResponse | string | any
+}
+
+const TaskSuccessCode = "success"
+
+type TaskResponse[T TaskData] struct {
+	Code    string `json:"code"`
+	Message string `json:"message"`
+	Data    T      `json:"data"`
+}
+
+func (t *TaskResponse[T]) IsSuccess() bool {
+	return t.Code == TaskSuccessCode
+}
+
+type TaskDto struct {
+	ID         int64           `json:"id"`
+	CreatedAt  int64           `json:"created_at"`
+	UpdatedAt  int64           `json:"updated_at"`
+	TaskID     string          `json:"task_id"`
+	Platform   string          `json:"platform"`
+	UserId     int             `json:"user_id"`
+	Group      string          `json:"group"`
+	ChannelId  int             `json:"channel_id"`
+	Quota      int             `json:"quota"`
+	Action     string          `json:"action"`
+	Status     string          `json:"status"`
+	FailReason string          `json:"fail_reason"`
+	ResultURL  string          `json:"result_url,omitempty"` // 任务结果 URL(视频地址等)
+	SubmitTime int64           `json:"submit_time"`
+	StartTime  int64           `json:"start_time"`
+	FinishTime int64           `json:"finish_time"`
+	Progress   string          `json:"progress"`
+	Properties any             `json:"properties"`
+	Username   string          `json:"username,omitempty"`
+	Data       json.RawMessage `json:"data"`
+}
+
+type FetchReq struct {
+	IDs []string `json:"ids"`
+}

+ 10 - 0
main.go

@@ -19,6 +19,7 @@ import (
 	"github.com/QuantumNous/new-api/middleware"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/oauth"
+	"github.com/QuantumNous/new-api/relay"
 	"github.com/QuantumNous/new-api/router"
 	"github.com/QuantumNous/new-api/service"
 	_ "github.com/QuantumNous/new-api/setting/performance_setting"
@@ -111,6 +112,15 @@ func main() {
 	// Subscription quota reset task (daily/weekly/monthly/custom)
 	service.StartSubscriptionQuotaResetTask()
 
+	// Wire task polling adaptor factory (breaks service -> relay import cycle)
+	service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor {
+		a := relay.GetTaskAdaptor(platform)
+		if a == nil {
+			return nil
+		}
+		return a
+	}
+
 	if common.IsMasterNode && constant.UpdateTask {
 		gopool.Go(func() {
 			controller.UpdateMidjourneyTaskBulk()

+ 18 - 0
middleware/auth.go

@@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) {
 
 }
 
+// TokenOrUserAuth allows either session-based user auth or API token auth.
+// Used for endpoints that need to be accessible from both the dashboard and API clients.
+func TokenOrUserAuth() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		// Try session auth first (dashboard users)
+		session := sessions.Default(c)
+		if id := session.Get("id"); id != nil {
+			if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled {
+				c.Set("id", id)
+				c.Next()
+				return
+			}
+		}
+		// Fall back to token auth (API clients)
+		TokenAuth()(c)
+	}
+}
+
 // TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
 // 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
 // 即使令牌已过期、已耗尽或已禁用,也允许访问。

+ 48 - 9
model/task.go

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"time"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	commonRelay "github.com/QuantumNous/new-api/relay/common"
@@ -64,13 +65,12 @@ type Task struct {
 }
 
 func (t *Task) SetData(data any) {
-	b, _ := json.Marshal(data)
+	b, _ := common.Marshal(data)
 	t.Data = json.RawMessage(b)
 }
 
 func (t *Task) GetData(v any) error {
-	err := json.Unmarshal(t.Data, &v)
-	return err
+	return common.Unmarshal(t.Data, &v)
 }
 
 type Properties struct {
@@ -85,18 +85,48 @@ func (m *Properties) Scan(val interface{}) error {
 		*m = Properties{}
 		return nil
 	}
-	return json.Unmarshal(bytesValue, m)
+	return common.Unmarshal(bytesValue, m)
 }
 
 func (m Properties) Value() (driver.Value, error) {
 	if m == (Properties{}) {
 		return nil, nil
 	}
-	return json.Marshal(m)
+	return common.Marshal(m)
 }
 
 type TaskPrivateData struct {
-	Key string `json:"key,omitempty"`
+	Key            string `json:"key,omitempty"`
+	UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID
+	ResultURL      string `json:"result_url,omitempty"`       // 任务成功后的结果 URL(视频地址等)
+	// 计费上下文:用于异步退款/差额结算(轮询阶段读取)
+	BillingSource  string `json:"billing_source,omitempty"`  // "wallet" 或 "subscription"
+	SubscriptionId int    `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
+	TokenId        int    `json:"token_id,omitempty"`        // 令牌 ID,用于令牌额度退款
+}
+
+// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
+// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID
+func (t *Task) GetUpstreamTaskID() string {
+	if t.PrivateData.UpstreamTaskID != "" {
+		return t.PrivateData.UpstreamTaskID
+	}
+	return t.TaskID
+}
+
+// GetResultURL 获取任务结果 URL(视频地址等)
+// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容)
+func (t *Task) GetResultURL() string {
+	if t.PrivateData.ResultURL != "" {
+		return t.PrivateData.ResultURL
+	}
+	return t.FailReason
+}
+
+// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID
+func GenerateTaskID() string {
+	key, _ := common.GenerateRandomCharsKey(32)
+	return "task_" + key
 }
 
 func (p *TaskPrivateData) Scan(val interface{}) error {
@@ -104,14 +134,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error {
 	if len(bytesValue) == 0 {
 		return nil
 	}
-	return json.Unmarshal(bytesValue, p)
+	return common.Unmarshal(bytesValue, p)
 }
 
 func (p TaskPrivateData) Value() (driver.Value, error) {
 	if (p == TaskPrivateData{}) {
 		return nil, nil
 	}
-	return json.Marshal(p)
+	return common.Marshal(p)
 }
 
 // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
@@ -142,7 +172,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo)
 		}
 	}
 
+	// 使用预生成的公开 ID(如果有),否则新生成
+	taskID := ""
+	if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" {
+		taskID = relayInfo.TaskRelayInfo.PublicTaskID
+	} else {
+		taskID = GenerateTaskID()
+	}
+
 	t := &Task{
+		TaskID:      taskID,
 		UserId:      relayInfo.UserId,
 		Group:       relayInfo.UsingGroup,
 		SubmitTime:  time.Now().Unix(),
@@ -438,6 +477,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo {
 	openAIVideo.SetProgressStr(t.Progress)
 	openAIVideo.CreatedAt = t.CreatedAt
 	openAIVideo.CompletedAt = t.UpdatedAt
-	openAIVideo.SetMetadata("url", t.FailReason)
+	openAIVideo.SetMetadata("url", t.GetResultURL())
 	return openAIVideo
 }

+ 3 - 3
model/token.go

@@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) {
 	return token.Delete()
 }
 
-func IncreaseTokenQuota(id int, key string, quota int) (err error) {
+func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
@@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
 		})
 	}
 	if common.BatchUpdateEnabled {
-		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
+		addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota)
 		return nil
 	}
-	return increaseTokenQuota(id, quota)
+	return increaseTokenQuota(tokenId, quota)
 }
 
 func increaseTokenQuota(id int, quota int) (err error) {

+ 2 - 1
relay/channel/task/ali/adaptor.go

@@ -384,7 +384,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// 转换为 OpenAI 格式响应
 	openAIResp := dto.NewOpenAIVideo()
-	openAIResp.ID = aliResp.Output.TaskID
+	openAIResp.ID = info.PublicTaskID
+	openAIResp.TaskID = info.PublicTaskID
 	openAIResp.Model = c.GetString("model")
 	if openAIResp.Model == "" && info != nil {
 		openAIResp.Model = info.OriginModelName

+ 9 - 15
relay/channel/task/doubao/adaptor.go

@@ -2,7 +2,6 @@ package doubao
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -14,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -131,7 +131,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
 	info.UpstreamModelName = body.Model
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -154,7 +154,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// Parse Doubao response
 	var dResp responsePayload
-	if err := json.Unmarshal(responseBody, &dResp); err != nil {
+	if err := common.Unmarshal(responseBody, &dResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -165,8 +165,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = dResp.ID
-	ov.TaskID = dResp.ID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 
@@ -234,12 +234,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 
 	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
@@ -248,7 +243,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := responseTask{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 
@@ -286,7 +281,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var dResp responseTask
-	if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal doubao task data failed")
 	}
 
@@ -307,6 +302,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 15 - 32
relay/channel/task/gemini/adaptor.go

@@ -2,8 +2,6 @@ package gemini
 
 import (
 	"bytes"
-	"encoding/base64"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -16,10 +14,10 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/QuantumNous/new-api/setting/model_setting"
-	"github.com/QuantumNous/new-api/setting/system_setting"
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
 )
@@ -145,16 +143,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 
 	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &body.Parameters)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -175,16 +168,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var s submitResponse
-	if err := json.Unmarshal(responseBody, &s); err != nil {
+	if err := common.Unmarshal(responseBody, &s); err != nil {
 		return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 	}
 	if strings.TrimSpace(s.Name) == "" {
 		return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
 	}
-	taskID = encodeLocalTaskID(s.Name)
+	taskID = taskcommon.EncodeLocalTaskID(s.Name)
 	ov := dto.NewOpenAIVideo()
-	ov.ID = taskID
-	ov.TaskID = taskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -206,7 +199,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		return nil, fmt.Errorf("invalid task_id")
 	}
 
-	upstreamName, err := decodeLocalTaskID(taskID)
+	upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
 	if err != nil {
 		return nil, fmt.Errorf("decode task_id failed: %w", err)
 	}
@@ -232,7 +225,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	var op operationResponse
-	if err := json.Unmarshal(respBody, &op); err != nil {
+	if err := common.Unmarshal(respBody, &op); err != nil {
 		return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
 	}
 
@@ -254,9 +247,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 	ti.Status = model.TaskStatusSuccess
 	ti.Progress = "100%"
 
-	taskID := encodeLocalTaskID(op.Name)
-	ti.TaskID = taskID
-	ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
+	ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name)
+	// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
 
 	// Extract URL from generateVideoResponse if available
 	if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
@@ -269,7 +261,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 }
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
-	upstreamName, err := decodeLocalTaskID(task.TaskID)
+	// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
+	// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
+	upstreamTaskID := task.GetUpstreamTaskID()
+	upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
 	if err != nil {
 		upstreamName = ""
 	}
@@ -297,18 +292,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 // helpers
 // ============================
 
-func encodeLocalTaskID(name string) string {
-	return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
-	b, err := base64.RawURLEncoding.DecodeString(local)
-	if err != nil {
-		return "", err
-	}
-	return string(b), nil
-}
-
 var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
 
 func extractModelFromOperationName(name string) string {

+ 7 - 8
relay/channel/task/hailuo/adaptor.go

@@ -2,7 +2,6 @@ package hailuo
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -65,7 +64,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -86,7 +85,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var hResp VideoResponse
-	if err := json.Unmarshal(responseBody, &hResp); err != nil {
+	if err := common.Unmarshal(responseBody, &hResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -101,8 +100,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = hResp.TaskID
-	ov.TaskID = hResp.TaskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 
@@ -182,7 +181,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := QueryTaskResponse{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 
@@ -224,7 +223,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var hailuoResp QueryTaskResponse
-	if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
 	}
 
@@ -271,7 +270,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
 	}
 
 	var retrieveResp RetrieveFileResponse
-	if err := json.Unmarshal(responseBody, &retrieveResp); err != nil {
+	if err := common.Unmarshal(responseBody, &retrieveResp); err != nil {
 		return ""
 	}
 

+ 10 - 17
relay/channel/task/jimeng/adaptor.go

@@ -6,7 +6,6 @@ import (
 	"crypto/sha256"
 	"encoding/base64"
 	"encoding/hex"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -25,6 +24,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
@@ -168,7 +168,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -191,7 +191,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// Parse Jimeng response
 	var jResp responsePayload
-	if err := json.Unmarshal(responseBody, &jResp); err != nil {
+	if err := common.Unmarshal(responseBody, &jResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -202,8 +202,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = jResp.Data.TaskID
-	ov.TaskID = jResp.Data.TaskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -225,7 +225,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
 		"task_id": taskID,
 	}
-	payloadBytes, err := json.Marshal(payload)
+	payloadBytes, err := common.Marshal(payload)
 	if err != nil {
 		return nil, errors.Wrap(err, "marshal fetch task payload failed")
 	}
@@ -398,13 +398,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 			r.BinaryDataBase64 = req.Images
 		}
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
@@ -432,7 +426,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := responseTask{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 	taskResult := relaycommon.TaskInfo{}
@@ -458,7 +452,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var jimengResp responseTask
-	if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
 	}
 
@@ -477,8 +471,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }
 
 func isNewAPIRelay(apiKey string) bool {

+ 11 - 32
relay/channel/task/kling/adaptor.go

@@ -2,7 +2,6 @@ package kling
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -21,6 +20,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
@@ -156,7 +156,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if body.Image == "" && body.ImageTail == "" {
 		c.Set("action", constant.TaskActionTextGenerate)
 	}
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -180,7 +180,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	var kResp responsePayload
-	err = json.Unmarshal(responseBody, &kResp)
+	err = common.Unmarshal(responseBody, &kResp)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 		return
@@ -190,8 +190,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 	ov := dto.NewOpenAIVideo()
-	ov.ID = kResp.Data.TaskId
-	ov.TaskID = kResp.Data.TaskId
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -251,8 +251,8 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	r := requestPayload{
 		Prompt:         req.Prompt,
 		Image:          req.Image,
-		Mode:           defaultString(req.Mode, "std"),
-		Duration:       fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+		Mode:           taskcommon.DefaultString(req.Mode, "std"),
+		Duration:       fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
 		AspectRatio:    a.getAspectRatio(req.Size),
 		ModelName:      req.Model,
 		Model:          req.Model, // Keep consistent with model_name, double writing improves compatibility
@@ -266,13 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	if r.ModelName == "" {
 		r.ModelName = "kling-v1"
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 	return &r, nil
@@ -291,20 +285,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string {
 	}
 }
 
-func defaultString(s, def string) string {
-	if strings.TrimSpace(s) == "" {
-		return def
-	}
-	return s
-}
-
-func defaultInt(v int, def int) int {
-	if v == 0 {
-		return def
-	}
-	return v
-}
-
 // ============================
 // JWT helpers
 // ============================
@@ -340,7 +320,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	taskInfo := &relaycommon.TaskInfo{}
 	resPayload := responsePayload{}
-	err := json.Unmarshal(respBody, &resPayload)
+	err := common.Unmarshal(respBody, &resPayload)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to unmarshal response body")
 	}
@@ -374,7 +354,7 @@ func isNewAPIRelay(apiKey string) bool {
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var klingResp responsePayload
-	if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &klingResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal kling task data failed")
 	}
 
@@ -401,6 +381,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 			Code:    fmt.Sprintf("%d", klingResp.Code),
 		}
 	}
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 13 - 11
relay/channel/task/sora/adaptor.go

@@ -13,7 +13,6 @@ import (
 	"github.com/QuantumNous/new-api/relay/channel"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
-	"github.com/QuantumNous/new-api/setting/system_setting"
 
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
@@ -116,7 +115,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req
 }
 
 // DoResponse handles upstream response, returns taskID etc.
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -131,17 +130,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
 		return
 	}
 
-	if dResp.ID == "" {
-		if dResp.TaskID == "" {
-			taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
-			return
-		}
-		dResp.ID = dResp.TaskID
-		dResp.TaskID = ""
+	upstreamID := dResp.ID
+	if upstreamID == "" {
+		upstreamID = dResp.TaskID
+	}
+	if upstreamID == "" {
+		taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
+		return
 	}
 
+	// 使用公开 task_xxxx ID 返回给客户端
+	dResp.ID = info.PublicTaskID
+	dResp.TaskID = info.PublicTaskID
 	c.JSON(http.StatusOK, dResp)
-	return dResp.ID, responseBody, nil
+	return upstreamID, responseBody, nil
 }
 
 // FetchTask fetch task status
@@ -192,7 +194,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 		taskResult.Status = model.TaskStatusInProgress
 	case "completed":
 		taskResult.Status = model.TaskStatusSuccess
-		taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID)
+		// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
 	case "failed", "cancelled":
 		taskResult.Status = model.TaskStatusFailure
 		if resTask.Error != nil {

+ 14 - 15
relay/channel/task/suno/adaptor.go

@@ -3,7 +3,6 @@ package suno
 import (
 	"bytes"
 	"context"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -24,8 +23,12 @@ type TaskAdaptor struct {
 	ChannelType int
 }
 
+// ParseTaskResult is not used for Suno tasks.
+// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that
+// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API.
+// This differs from the per-task polling used by video adaptors.
 func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
-	return nil, fmt.Errorf("not implement") // todo implement this method if needed
+	return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable")
 }
 
 func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -81,7 +84,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 			return nil, err
 		}
 	}
-	data, err := json.Marshal(sunoRequest)
+	data, err := common.Marshal(sunoRequest)
 	if err != nil {
 		return nil, err
 	}
@@ -99,7 +102,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 	var sunoResponse dto.TaskResponse[string]
-	err = json.Unmarshal(responseBody, &sunoResponse)
+	err = common.Unmarshal(responseBody, &sunoResponse)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
@@ -109,17 +112,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-
-	_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
-		return
+	// 使用公开 task_xxxx ID 替换上游 ID 返回给客户端
+	publicResponse := dto.TaskResponse[string]{
+		Code:    sunoResponse.Code,
+		Message: sunoResponse.Message,
+		Data:    info.PublicTaskID,
 	}
+	c.JSON(http.StatusOK, publicResponse)
 
 	return sunoResponse.Data, nil, nil
 }
@@ -134,7 +133,7 @@ func (a *TaskAdaptor) GetChannelName() string {
 
 func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
-	byteBody, err := json.Marshal(body)
+	byteBody, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}

+ 70 - 0
relay/channel/task/taskcommon/helpers.go

@@ -0,0 +1,70 @@
+package taskcommon
+
+import (
+	"encoding/base64"
+	"fmt"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/setting/system_setting"
+)
+
+// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
+// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target).
+func UnmarshalMetadata(metadata map[string]any, target any) error {
+	if metadata == nil {
+		return nil
+	}
+	metaBytes, err := common.Marshal(metadata)
+	if err != nil {
+		return fmt.Errorf("marshal metadata failed: %w", err)
+	}
+	if err := common.Unmarshal(metaBytes, target); err != nil {
+		return fmt.Errorf("unmarshal metadata failed: %w", err)
+	}
+	return nil
+}
+
+// DefaultString returns val if non-empty, otherwise fallback.
+func DefaultString(val, fallback string) string {
+	if val == "" {
+		return fallback
+	}
+	return val
+}
+
+// DefaultInt returns val if non-zero, otherwise fallback.
+func DefaultInt(val, fallback int) int {
+	if val == 0 {
+		return fallback
+	}
+	return val
+}
+
+// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string.
+// Used by Gemini/Vertex to store upstream names as task IDs.
+func EncodeLocalTaskID(name string) string {
+	return base64.RawURLEncoding.EncodeToString([]byte(name))
+}
+
+// DecodeLocalTaskID decodes a base64-encoded upstream operation name.
+func DecodeLocalTaskID(id string) (string, error) {
+	b, err := base64.RawURLEncoding.DecodeString(id)
+	if err != nil {
+		return "", err
+	}
+	return string(b), nil
+}
+
+// BuildProxyURL constructs the video proxy URL using the public task ID.
+// e.g., "https://your-server.com/v1/videos/task_xxxx/content"
+func BuildProxyURL(taskID string) string {
+	return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
+}
+
+// Status-to-progress mapping constants for polling updates.
+const (
+	ProgressSubmitted  = "10%"
+	ProgressQueued     = "20%"
+	ProgressInProgress = "30%"
+	ProgressComplete   = "100%"
+)

+ 23 - 27
relay/channel/task/vertex/adaptor.go

@@ -2,13 +2,12 @@ package vertex
 
 import (
 	"bytes"
-	"encoding/base64"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
 	"regexp"
 	"strings"
+	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/model"
@@ -17,6 +16,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
@@ -82,7 +82,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 // BuildRequestURL constructs the upstream URL.
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+	if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
 		return "", fmt.Errorf("failed to decode credentials: %w", err)
 	}
 	modelName := info.OriginModelName
@@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 	req.Header.Set("Accept", "application/json")
 
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+	if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
 		return fmt.Errorf("failed to decode credentials: %w", err)
 	}
 
@@ -184,7 +184,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	// 	info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
 	// }
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -205,14 +205,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var s submitResponse
-	if err := json.Unmarshal(responseBody, &s); err != nil {
+	if err := common.Unmarshal(responseBody, &s); err != nil {
 		return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 	}
 	if strings.TrimSpace(s.Name) == "" {
 		return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
 	}
-	localID := encodeLocalTaskID(s.Name)
-	c.JSON(http.StatusOK, gin.H{"task_id": localID})
+	localID := taskcommon.EncodeLocalTaskID(s.Name)
+	ov := dto.NewOpenAIVideo()
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
+	ov.CreatedAt = time.Now().Unix()
+	ov.Model = info.OriginModelName
+	c.JSON(http.StatusOK, ov)
 	return localID, responseBody, nil
 }
 
@@ -225,7 +230,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
 	}
-	upstreamName, err := decodeLocalTaskID(taskID)
+	upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
 	if err != nil {
 		return nil, fmt.Errorf("decode task_id failed: %w", err)
 	}
@@ -245,12 +250,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
 	}
 	payload := map[string]string{"operationName": upstreamName}
-	data, err := json.Marshal(payload)
+	data, err := common.Marshal(payload)
 	if err != nil {
 		return nil, err
 	}
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(key), adc); err != nil {
+	if err := common.Unmarshal([]byte(key), adc); err != nil {
 		return nil, fmt.Errorf("failed to decode credentials: %w", err)
 	}
 	token, err := vertexcore.AcquireAccessToken(*adc, proxy)
@@ -274,7 +279,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	var op operationResponse
-	if err := json.Unmarshal(respBody, &op); err != nil {
+	if err := common.Unmarshal(respBody, &op); err != nil {
 		return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
 	}
 	ti := &relaycommon.TaskInfo{}
@@ -338,7 +343,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 }
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
-	upstreamName, err := decodeLocalTaskID(task.TaskID)
+	// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
+	// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
+	upstreamTaskID := task.GetUpstreamTaskID()
+	upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
 	if err != nil {
 		upstreamName = ""
 	}
@@ -353,8 +361,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 	v.SetProgressStr(task.Progress)
 	v.CreatedAt = task.CreatedAt
 	v.CompletedAt = task.UpdatedAt
-	if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
-		v.SetMetadata("url", task.FailReason)
+	if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 {
+		v.SetMetadata("url", resultURL)
 	}
 
 	return common.Marshal(v)
@@ -364,18 +372,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 // helpers
 // ============================
 
-func encodeLocalTaskID(name string) string {
-	return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
-	b, err := base64.RawURLEncoding.DecodeString(local)
-	if err != nil {
-		return "", err
-	}
-	return string(b), nil
-}
-
 var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
 
 func extractRegionFromOperationName(name string) string {

+ 12 - 33
relay/channel/task/vidu/adaptor.go

@@ -2,7 +2,6 @@ package vidu
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -16,6 +15,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -127,7 +127,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		}
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -168,7 +168,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	var vResp responsePayload
-	err = json.Unmarshal(responseBody, &vResp)
+	err = common.Unmarshal(responseBody, &vResp)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
 		return
@@ -180,8 +180,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = vResp.TaskId
-	ov.TaskID = vResp.TaskId
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -225,45 +225,25 @@ func (a *TaskAdaptor) GetChannelName() string {
 
 func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
 	r := requestPayload{
-		Model:             defaultString(req.Model, "viduq1"),
+		Model:             taskcommon.DefaultString(req.Model, "viduq1"),
 		Images:            req.Images,
 		Prompt:            req.Prompt,
-		Duration:          defaultInt(req.Duration, 5),
-		Resolution:        defaultString(req.Size, "1080p"),
+		Duration:          taskcommon.DefaultInt(req.Duration, 5),
+		Resolution:        taskcommon.DefaultString(req.Size, "1080p"),
 		MovementAmplitude: "auto",
 		Bgm:               false,
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 	return &r, nil
 }
 
-func defaultString(value, defaultValue string) string {
-	if value == "" {
-		return defaultValue
-	}
-	return value
-}
-
-func defaultInt(value, defaultValue int) int {
-	if value == 0 {
-		return defaultValue
-	}
-	return value
-}
-
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	taskInfo := &relaycommon.TaskInfo{}
 
 	var taskResp taskResultResponse
-	err := json.Unmarshal(respBody, &taskResp)
+	err := common.Unmarshal(respBody, &taskResp)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to unmarshal response body")
 	}
@@ -293,7 +273,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var viduResp taskResultResponse
-	if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &viduResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal vidu task data failed")
 	}
 
@@ -315,6 +295,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 12 - 3
relay/common/relay_info.go

@@ -118,8 +118,12 @@ type RelayInfo struct {
 	SendResponseCount      int
 	ReceivedResponseCount  int
 	FinalPreConsumedQuota  int // 最终预消耗的配额
+	// ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
+	// 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
+	// 必须在提交前锁定全额。
+	ForcePreConsume bool
 	// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
-	// 免费模型和按次计费(MJ/Task)时为 nil。
+	// 免费模型时为 nil。
 	Billing BillingSettler
 	// BillingSource indicates whether this request is billed from wallet quota or subscription.
 	// "" or "wallet" => wallet; "subscription" => subscription
@@ -525,8 +529,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
 		return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
 	case types.RelayFormatTask:
 		info = genBaseRelayInfo(c, nil)
+		info.TaskRelayInfo = &TaskRelayInfo{}
 	case types.RelayFormatMjProxy:
 		info = genBaseRelayInfo(c, nil)
+		info.TaskRelayInfo = &TaskRelayInfo{}
 	default:
 		err = errors.New("invalid relay format")
 	}
@@ -608,6 +614,9 @@ func (info *RelayInfo) HasSendResponse() bool {
 type TaskRelayInfo struct {
 	Action       string
 	OriginTaskID string
+	// PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
+	// 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
+	PublicTaskID string
 
 	ConsumeQuota bool
 }
@@ -667,11 +676,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
 func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
 	metadata := t.Metadata
 	if metadata != nil {
-		metadataBytes, err := json.Marshal(metadata)
+		metadataBytes, err := common.Marshal(metadata)
 		if err != nil {
 			return fmt.Errorf("marshal metadata failed: %w", err)
 		}
-		err = json.Unmarshal(metadataBytes, v)
+		err = common.Unmarshal(metadataBytes, v)
 		if err != nil {
 			return fmt.Errorf("unmarshal metadata to target failed: %w", err)
 		}

+ 13 - 2
relay/helper/price.go

@@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 }
 
 // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
-func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
+func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData {
 	groupRatioInfo := HandleGroupRatio(c, info)
 
 	modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
@@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.
 		}
 	}
 	quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
-	priceData := types.PerCallPriceData{
+
+	// 免费模型检测(与 ModelPriceHelper 对齐)
+	freeModel := false
+	if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
+		if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
+			quota = 0
+			freeModel = true
+		}
+	}
+
+	priceData := types.PriceData{
+		FreeModel:      freeModel,
 		ModelPrice:     modelPrice,
 		Quota:          quota,
 		GroupRatioInfo: groupRatioInfo,

+ 274 - 298
relay/relay_task.go

@@ -2,7 +2,6 @@ package relay
 
 import (
 	"bytes"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -15,29 +14,33 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	relayconstant "github.com/QuantumNous/new-api/relay/constant"
+	"github.com/QuantumNous/new-api/relay/helper"
 	"github.com/QuantumNous/new-api/service"
-	"github.com/QuantumNous/new-api/setting/ratio_setting"
-
 	"github.com/gin-gonic/gin"
 )
 
-/*
-Task 任务通过平台、Action 区分任务
-*/
-func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	info.InitChannelMeta(c)
-	// ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
-	if info.TaskRelayInfo == nil {
-		info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
-	}
+type TaskSubmitResult struct {
+	UpstreamTaskID string
+	TaskData       []byte
+	Platform       constant.TaskPlatform
+	ModelName      string
+	Quota          int
+	//PerCallPrice   types.PriceData
+}
+
+// ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
+// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过
+// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。
+// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
+func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
+	// 检测 remix action
 	path := c.Request.URL.Path
 	if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
 		info.Action = constant.TaskActionRemix
 	}
-
-	// 提取 remix 任务的 video_id
 	if info.Action == constant.TaskActionRemix {
 		videoID := c.Param("video_id")
 		if strings.TrimSpace(videoID) == "" {
@@ -46,241 +49,164 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 		info.OriginTaskID = videoID
 	}
 
-	platform := constant.TaskPlatform(c.GetString("platform"))
+	if info.OriginTaskID == "" {
+		return nil
+	}
 
-	// 获取原始任务信息
-	if info.OriginTaskID != "" {
-		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
-		if err != nil {
-			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
-			return
+	// 查找原始任务
+	originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
+	if err != nil {
+		return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+	}
+	if !exist {
+		return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+	}
+
+	// 从原始任务推导模型名称
+	if info.OriginModelName == "" {
+		if originTask.Properties.OriginModelName != "" {
+			info.OriginModelName = originTask.Properties.OriginModelName
+		} else if originTask.Properties.UpstreamModelName != "" {
+			info.OriginModelName = originTask.Properties.UpstreamModelName
+		} else {
+			var taskData map[string]interface{}
+			_ = common.Unmarshal(originTask.Data, &taskData)
+			if m, ok := taskData["model"].(string); ok && m != "" {
+				info.OriginModelName = m
+			}
 		}
-		if !exist {
-			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
-			return
+	}
+
+	// 锁定到原始任务的渠道(如果与当前选中的不同)
+	if originTask.ChannelId != info.ChannelId {
+		ch, err := model.GetChannelById(originTask.ChannelId, true)
+		if err != nil {
+			return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
 		}
-		if info.OriginModelName == "" {
-			if originTask.Properties.OriginModelName != "" {
-				info.OriginModelName = originTask.Properties.OriginModelName
-			} else if originTask.Properties.UpstreamModelName != "" {
-				info.OriginModelName = originTask.Properties.UpstreamModelName
-			} else {
-				var taskData map[string]interface{}
-				_ = json.Unmarshal(originTask.Data, &taskData)
-				if m, ok := taskData["model"].(string); ok && m != "" {
-					info.OriginModelName = m
-					platform = originTask.Platform
-				}
-			}
+		if ch.Status != common.ChannelStatusEnabled {
+			return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
 		}
-		if originTask.ChannelId != info.ChannelId {
-			channel, err := model.GetChannelById(originTask.ChannelId, true)
-			if err != nil {
-				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
-				return
-			}
-			if channel.Status != common.ChannelStatusEnabled {
-				taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
-				return
-			}
-			key, _, newAPIError := channel.GetNextEnabledKey()
-			if newAPIError != nil {
-				taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
-				return
-			}
-			common.SetContextKey(c, constant.ContextKeyChannelKey, key)
-			common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
-			common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
-			common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
-
-			info.ChannelBaseUrl = channel.GetBaseURL()
-			info.ChannelId = originTask.ChannelId
-			info.ChannelType = channel.Type
-			info.ApiKey = key
-			platform = originTask.Platform
+		key, _, newAPIError := ch.GetNextEnabledKey()
+		if newAPIError != nil {
+			return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
 		}
+		common.SetContextKey(c, constant.ContextKeyChannelKey, key)
+		common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
+		common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
+		common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
 
-		// 使用原始任务的参数
-		if info.Action == constant.TaskActionRemix {
-			var taskData map[string]interface{}
-			_ = json.Unmarshal(originTask.Data, &taskData)
-			secondsStr, _ := taskData["seconds"].(string)
-			seconds, _ := strconv.Atoi(secondsStr)
-			if seconds <= 0 {
-				seconds = 4
-			}
-			sizeStr, _ := taskData["size"].(string)
-			if info.PriceData.OtherRatios == nil {
-				info.PriceData.OtherRatios = map[string]float64{}
-			}
-			info.PriceData.OtherRatios["seconds"] = float64(seconds)
-			info.PriceData.OtherRatios["size"] = 1
-			if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
-				info.PriceData.OtherRatios["size"] = 1.666667
-			}
+		info.ChannelBaseUrl = ch.GetBaseURL()
+		info.ChannelId = originTask.ChannelId
+		info.ChannelType = ch.Type
+		info.ApiKey = key
+	}
+
+	// 渠道已锁定到原始任务 → 禁止重试切换到其他渠道
+	c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId))
+
+	// 提取 remix 参数(时长、分辨率 → OtherRatios)
+	if info.Action == constant.TaskActionRemix {
+		var taskData map[string]interface{}
+		_ = common.Unmarshal(originTask.Data, &taskData)
+		secondsStr, _ := taskData["seconds"].(string)
+		seconds, _ := strconv.Atoi(secondsStr)
+		if seconds <= 0 {
+			seconds = 4
+		}
+		sizeStr, _ := taskData["size"].(string)
+		if info.PriceData.OtherRatios == nil {
+			info.PriceData.OtherRatios = map[string]float64{}
+		}
+		info.PriceData.OtherRatios["seconds"] = float64(seconds)
+		info.PriceData.OtherRatios["size"] = 1
+		if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
+			info.PriceData.OtherRatios["size"] = 1.666667
 		}
 	}
+
+	return nil
+}
+
+// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
+// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 →
+// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。
+// 控制器负责 defer Refund 和成功后 Settle。
+func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
+	info.InitChannelMeta(c)
+
+	// 1. 确定 platform → 创建适配器 → 验证请求
+	platform := constant.TaskPlatform(c.GetString("platform"))
 	if platform == "" {
 		platform = GetTaskPlatform(c)
 	}
-
-	info.InitChannelMeta(c)
 	adaptor := GetTaskAdaptor(platform)
 	if adaptor == nil {
-		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
+		return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
 	}
 	adaptor.Init(info)
-	// get & validate taskRequest 获取并验证文本请求
-	taskErr = adaptor.ValidateRequestAndSetAction(c, info)
-	if taskErr != nil {
-		return
+	if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
+		return nil, taskErr
 	}
 
+	// 2. 确定模型名称
 	modelName := info.OriginModelName
 	if modelName == "" {
 		modelName = service.CoverTaskActionToModelName(platform, info.Action)
 	}
-	modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
-	if !success {
-		defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
-		if !ok {
-			modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
-		} else {
-			modelPrice = defaultPrice
-		}
-	}
 
-	// 处理 auto 分组:从 context 获取实际选中的分组
-	// 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
-	if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
-		if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
-			info.UsingGroup = groupStr
-		}
+	// 3. 预生成公开 task ID(仅首次)
+	if info.PublicTaskID == "" {
+		info.PublicTaskID = model.GenerateTaskID()
 	}
 
-	// 预扣
-	groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
-	var ratio float64
-	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
-	if hasUserGroupRatio {
-		ratio = modelPrice * userGroupRatio
-	} else {
-		ratio = modelPrice * groupRatio
-	}
-	// FIXME: 临时修补,支持任务仅按次计费
+	// 4. 价格计算
+	info.OriginModelName = modelName
+	info.PriceData = helper.ModelPriceHelperPerCall(c, info)
+
 	if !common.StringsContains(constant.TaskPricePatches, modelName) {
-		if len(info.PriceData.OtherRatios) > 0 {
-			for _, ra := range info.PriceData.OtherRatios {
-				if 1.0 != ra {
-					ratio *= ra
-				}
+		for _, ra := range info.PriceData.OtherRatios {
+			if ra != 1.0 {
+				info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
 			}
 		}
 	}
-	println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
-	userQuota, err := model.GetUserQuota(info.UserId, false)
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
-		return
-	}
-	quota := int(ratio * common.QuotaPerUnit)
-	if userQuota-quota < 0 {
-		taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
-		return
+
+	// 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
+	if info.Billing == nil && !info.PriceData.FreeModel {
+		info.ForcePreConsume = true
+		if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
+			return nil, service.TaskErrorFromAPIError(apiErr)
+		}
 	}
 
-	// build body
+	// 6. 构建请求体
 	requestBody, err := adaptor.BuildRequestBody(c, info)
 	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
-		return
+		return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
 	}
-	// do request
+
+	// 7. 发送请求
 	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
-		return
+		return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
-	// handle response
 	if resp != nil && resp.StatusCode != http.StatusOK {
 		responseBody, _ := io.ReadAll(resp.Body)
-		taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
-		return
+		return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
 	}
 
-	defer func() {
-		// release quota
-		if info.ConsumeQuota && taskErr == nil {
-
-			err := service.PostConsumeQuota(info, quota, 0, true)
-			if err != nil {
-				common.SysLog("error consuming token remain quota: " + err.Error())
-			}
-			if quota != 0 {
-				tokenName := c.GetString("token_name")
-				//gRatio := groupRatio
-				//if hasUserGroupRatio {
-				//	gRatio = userGroupRatio
-				//}
-				logContent := fmt.Sprintf("操作 %s", info.Action)
-				// FIXME: 临时修补,支持任务仅按次计费
-				if common.StringsContains(constant.TaskPricePatches, modelName) {
-					logContent = fmt.Sprintf("%s,按次计费", logContent)
-				} else {
-					if len(info.PriceData.OtherRatios) > 0 {
-						var contents []string
-						for key, ra := range info.PriceData.OtherRatios {
-							if 1.0 != ra {
-								contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
-							}
-						}
-						if len(contents) > 0 {
-							logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
-						}
-					}
-				}
-				other := make(map[string]interface{})
-				if c != nil && c.Request != nil && c.Request.URL != nil {
-					other["request_path"] = c.Request.URL.Path
-				}
-				other["model_price"] = modelPrice
-				other["group_ratio"] = groupRatio
-				if hasUserGroupRatio {
-					other["user_group_ratio"] = userGroupRatio
-				}
-				model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
-					ChannelId: info.ChannelId,
-					ModelName: modelName,
-					TokenName: tokenName,
-					Quota:     quota,
-					Content:   logContent,
-					TokenId:   info.TokenId,
-					Group:     info.UsingGroup,
-					Other:     other,
-				})
-				model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
-				model.UpdateChannelUsedQuota(info.ChannelId, quota)
-			}
-		}
-	}()
-
-	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
+	// 8. 解析响应
+	upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
 	if taskErr != nil {
-		return
+		return nil, taskErr
 	}
-	info.ConsumeQuota = true
-	// insert task
-	task := model.InitTask(platform, info)
-	task.TaskID = taskID
-	task.Quota = quota
-	task.Data = taskData
-	task.Action = info.Action
-	err = task.Insert()
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
-		return
-	}
-	return nil
+
+	return &TaskSubmitResult{
+		UpstreamTaskID: upstreamTaskID,
+		TaskData:       taskData,
+		Platform:       platform,
+		ModelName:      modelName,
+	}, nil
 }
 
 var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
@@ -336,7 +262,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta
 	} else {
 		tasks = make([]any, 0)
 	}
-	respBody, err = json.Marshal(dto.TaskResponse[[]any]{
+	respBody, err = common.Marshal(dto.TaskResponse[[]any]{
 		Code: "success",
 		Data: tasks,
 	})
@@ -357,7 +283,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
 		return
 	}
 
-	respBody, err = json.Marshal(dto.TaskResponse[any]{
+	respBody, err = common.Marshal(dto.TaskResponse[any]{
 		Code: "success",
 		Data: TaskModel2Dto(originTask),
 	})
@@ -381,97 +307,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 		return
 	}
 
-	func() {
-		channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
-		if err2 != nil {
-			return
-		}
-		if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
-			return
-		}
-		baseURL := constant.ChannelBaseURLs[channelModel.Type]
-		if channelModel.GetBaseURL() != "" {
-			baseURL = channelModel.GetBaseURL()
-		}
-		proxy := channelModel.GetSetting().Proxy
-		adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
-		if adaptor == nil {
-			return
-		}
-		resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
-			"task_id": originTask.TaskID,
-			"action":  originTask.Action,
-		}, proxy)
-		if err2 != nil || resp == nil {
-			return
-		}
-		defer resp.Body.Close()
-		body, err2 := io.ReadAll(resp.Body)
-		if err2 != nil {
-			return
-		}
-		ti, err2 := adaptor.ParseTaskResult(body)
-		if err2 == nil && ti != nil {
-			if ti.Status != "" {
-				originTask.Status = model.TaskStatus(ti.Status)
-			}
-			if ti.Progress != "" {
-				originTask.Progress = ti.Progress
-			}
-			if ti.Url != "" {
-				if strings.HasPrefix(ti.Url, "data:") {
-				} else {
-					originTask.FailReason = ti.Url
-				}
-			}
-			_ = originTask.Update()
-			var raw map[string]any
-			_ = json.Unmarshal(body, &raw)
-			format := "mp4"
-			if respObj, ok := raw["response"].(map[string]any); ok {
-				if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
-					if v0, ok := vids[0].(map[string]any); ok {
-						if mt, ok := v0["mimeType"].(string); ok && mt != "" {
-							if strings.Contains(mt, "mp4") {
-								format = "mp4"
-							} else {
-								format = mt
-							}
-						}
-					}
-				}
-			}
-			status := "processing"
-			switch originTask.Status {
-			case model.TaskStatusSuccess:
-				status = "succeeded"
-			case model.TaskStatusFailure:
-				status = "failed"
-			case model.TaskStatusQueued, model.TaskStatusSubmitted:
-				status = "queued"
-			}
-			if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
-				out := map[string]any{
-					"error":    nil,
-					"format":   format,
-					"metadata": nil,
-					"status":   status,
-					"task_id":  originTask.TaskID,
-					"url":      originTask.FailReason,
-				}
-				respBody, _ = json.Marshal(dto.TaskResponse[any]{
-					Code: "success",
-					Data: out,
-				})
-			}
-		}
-	}()
+	isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
 
-	if len(respBody) != 0 {
+	// Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
+	if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
+		respBody = realtimeResp
 		return
 	}
 
-	if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
+	// OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
+	if isOpenAIVideoAPI {
 		adaptor := GetTaskAdaptor(originTask.Platform)
 		if adaptor == nil {
 			taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
@@ -486,10 +331,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 			respBody = openAIVideoData
 			return
 		}
-		taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
+		taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
 		return
 	}
-	respBody, err = json.Marshal(dto.TaskResponse[any]{
+
+	// 通用 TaskDto 格式
+	respBody, err = common.Marshal(dto.TaskResponse[any]{
 		Code: "success",
 		Data: TaskModel2Dto(originTask),
 	})
@@ -499,16 +346,145 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 	return
 }
 
+// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
+// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
+// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
+func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
+	channelModel, err := model.GetChannelById(task.ChannelId, true)
+	if err != nil {
+		return nil
+	}
+	if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
+		return nil
+	}
+
+	baseURL := constant.ChannelBaseURLs[channelModel.Type]
+	if channelModel.GetBaseURL() != "" {
+		baseURL = channelModel.GetBaseURL()
+	}
+	proxy := channelModel.GetSetting().Proxy
+	adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
+	if adaptor == nil {
+		return nil
+	}
+
+	resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
+		"task_id": task.GetUpstreamTaskID(),
+		"action":  task.Action,
+	}, proxy)
+	if err != nil || resp == nil {
+		return nil
+	}
+	defer resp.Body.Close()
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil
+	}
+
+	ti, err := adaptor.ParseTaskResult(body)
+	if err != nil || ti == nil {
+		return nil
+	}
+
+	// 将上游最新状态更新到 task
+	if ti.Status != "" {
+		task.Status = model.TaskStatus(ti.Status)
+	}
+	if ti.Progress != "" {
+		task.Progress = ti.Progress
+	}
+	if strings.HasPrefix(ti.Url, "data:") {
+		// data: URI — kept in Data, not ResultURL
+	} else if ti.Url != "" {
+		task.PrivateData.ResultURL = ti.Url
+	} else if task.Status == model.TaskStatusSuccess {
+		// No URL from adaptor — construct proxy URL using public task ID
+		task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
+	}
+	_ = task.Update()
+
+	// OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
+	if isOpenAIVideoAPI {
+		return nil
+	}
+
+	// 非 OpenAI Video API: 构建自定义格式响应
+	format := detectVideoFormat(body)
+	out := map[string]any{
+		"error":    nil,
+		"format":   format,
+		"metadata": nil,
+		"status":   mapTaskStatusToSimple(task.Status),
+		"task_id":  task.TaskID,
+		"url":      task.GetResultURL(),
+	}
+	respBody, _ := common.Marshal(dto.TaskResponse[any]{
+		Code: "success",
+		Data: out,
+	})
+	return respBody
+}
+
+// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
+func detectVideoFormat(rawBody []byte) string {
+	var raw map[string]any
+	if err := common.Unmarshal(rawBody, &raw); err != nil {
+		return "mp4"
+	}
+	respObj, ok := raw["response"].(map[string]any)
+	if !ok {
+		return "mp4"
+	}
+	vids, ok := respObj["videos"].([]any)
+	if !ok || len(vids) == 0 {
+		return "mp4"
+	}
+	v0, ok := vids[0].(map[string]any)
+	if !ok {
+		return "mp4"
+	}
+	mt, ok := v0["mimeType"].(string)
+	if !ok || mt == "" || strings.Contains(mt, "mp4") {
+		return "mp4"
+	}
+	return mt
+}
+
+// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
+func mapTaskStatusToSimple(status model.TaskStatus) string {
+	switch status {
+	case model.TaskStatusSuccess:
+		return "succeeded"
+	case model.TaskStatusFailure:
+		return "failed"
+	case model.TaskStatusQueued, model.TaskStatusSubmitted:
+		return "queued"
+	default:
+		return "processing"
+	}
+}
+
 func TaskModel2Dto(task *model.Task) *dto.TaskDto {
 	return &dto.TaskDto{
+		ID:         task.ID,
+		CreatedAt:  task.CreatedAt,
+		UpdatedAt:  task.UpdatedAt,
 		TaskID:     task.TaskID,
+		Platform:   string(task.Platform),
+		UserId:     task.UserId,
+		Group:      task.Group,
+		ChannelId:  task.ChannelId,
+		Quota:      task.Quota,
 		Action:     task.Action,
 		Status:     string(task.Status),
 		FailReason: task.FailReason,
+		ResultURL:  task.GetResultURL(),
 		SubmitTime: task.SubmitTime,
 		StartTime:  task.StartTime,
 		FinishTime: task.FinishTime,
 		Progress:   task.Progress,
+		Properties: task.Properties,
+		Username:   task.Username,
 		Data:       task.Data,
 	}
 }

+ 7 - 1
router/video-router.go

@@ -8,10 +8,16 @@ import (
 )
 
 func SetVideoRouter(router *gin.Engine) {
+	// Video proxy: accepts either session auth (dashboard) or token auth (API clients)
+	videoProxyRouter := router.Group("/v1")
+	videoProxyRouter.Use(middleware.TokenOrUserAuth())
+	{
+		videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy)
+	}
+
 	videoV1Router := router.Group("/v1")
 	videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
-		videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
 		videoV1Router.POST("/video/generations", controller.RelayTask)
 		videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
 		videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)

+ 5 - 0
service/billing_session.go

@@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
 
 // shouldTrust 统一信任额度检查,适用于钱包和订阅。
 func (s *BillingSession) shouldTrust(c *gin.Context) bool {
+	// 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路
+	if s.relayInfo.ForcePreConsume {
+		return false
+	}
+
 	trustQuota := common.GetTrustQuota()
 	if trustQuota <= 0 {
 		return false

+ 13 - 0
service/error.go

@@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
 
 	return taskError
 }
+
+// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。
+func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError {
+	if apiErr == nil {
+		return nil
+	}
+	return &dto.TaskError{
+		Code:       string(apiErr.GetErrorCode()),
+		Message:    apiErr.Err.Error(),
+		StatusCode: apiErr.StatusCode,
+		Error:      apiErr.Err,
+	}
+}

+ 1 - 1
service/log_info_generate.go

@@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	return info
 }
 
-func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} {
+func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} {
 	other := make(map[string]interface{})
 	other["model_price"] = priceData.ModelPrice
 	other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio

+ 227 - 0
service/task_billing.go

@@ -0,0 +1,227 @@
+package service
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/setting/ratio_setting"
+	"github.com/gin-gonic/gin"
+)
+
+// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
+// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
+func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
+	tokenName := c.GetString("token_name")
+	logContent := fmt.Sprintf("操作 %s", info.Action)
+	// 支持任务仅按次计费
+	if common.StringsContains(constant.TaskPricePatches, modelName) {
+		logContent = fmt.Sprintf("%s,按次计费", logContent)
+	} else {
+		if len(info.PriceData.OtherRatios) > 0 {
+			var contents []string
+			for key, ra := range info.PriceData.OtherRatios {
+				if 1.0 != ra {
+					contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
+				}
+			}
+			if len(contents) > 0 {
+				logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
+			}
+		}
+	}
+	other := make(map[string]interface{})
+	other["request_path"] = c.Request.URL.Path
+	other["model_price"] = info.PriceData.ModelPrice
+	other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
+	if info.PriceData.GroupRatioInfo.HasSpecialRatio {
+		other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
+	}
+	model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
+		ChannelId: info.ChannelId,
+		ModelName: modelName,
+		TokenName: tokenName,
+		Quota:     info.PriceData.Quota,
+		Content:   logContent,
+		TokenId:   info.TokenId,
+		Group:     info.UsingGroup,
+		Other:     other,
+	})
+	model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
+	model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
+}
+
+// ---------------------------------------------------------------------------
+// 异步任务计费辅助函数
+// ---------------------------------------------------------------------------
+
+// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
+// 如果令牌已被删除或查询失败,返回空字符串。
+func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
+	token, err := model.GetTokenById(tokenId)
+	if err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
+		return ""
+	}
+	return token.Key
+}
+
+// taskIsSubscription 判断任务是否通过订阅计费。
+func taskIsSubscription(task *model.Task) bool {
+	return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
+}
+
+// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
+func taskAdjustFunding(task *model.Task, delta int) error {
+	if taskIsSubscription(task) {
+		return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
+	}
+	if delta > 0 {
+		return model.DecreaseUserQuota(task.UserId, delta)
+	}
+	return model.IncreaseUserQuota(task.UserId, -delta, false)
+}
+
+// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
+// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
+func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
+	if task.PrivateData.TokenId <= 0 || delta == 0 {
+		return
+	}
+	tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
+	if tokenKey == "" {
+		return
+	}
+	var err error
+	if delta > 0 {
+		err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
+	} else {
+		err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
+	}
+	if err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
+	}
+}
+
+// RefundTaskQuota 统一的任务失败退款逻辑。
+// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
+func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
+	quota := task.Quota
+	if quota == 0 {
+		return
+	}
+
+	// 1. 退还资金来源(钱包或订阅)
+	if err := taskAdjustFunding(task, -quota); err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
+		return
+	}
+
+	// 2. 退还令牌额度
+	taskAdjustTokenQuota(ctx, task, -quota)
+
+	// 3. 记录日志
+	logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s,原因:%s", task.TaskID, logger.LogQuota(quota), reason)
+	model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+}
+
+// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
+// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
+// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
+func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
+	if totalTokens <= 0 {
+		return
+	}
+
+	// 获取模型名称
+	var taskData map[string]interface{}
+	if err := common.Unmarshal(task.Data, &taskData); err != nil {
+		return
+	}
+	modelName, ok := taskData["model"].(string)
+	if !ok || modelName == "" {
+		return
+	}
+
+	// 获取模型价格和倍率
+	modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
+	// 只有配置了倍率(非固定价格)时才按 token 重新计费
+	if !hasRatioSetting || modelRatio <= 0 {
+		return
+	}
+
+	// 获取用户和组的倍率信息
+	group := task.Group
+	if group == "" {
+		user, err := model.GetUserById(task.UserId, false)
+		if err == nil {
+			group = user.Group
+		}
+	}
+	if group == "" {
+		return
+	}
+
+	groupRatio := ratio_setting.GetGroupRatio(group)
+	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
+
+	var finalGroupRatio float64
+	if hasUserGroupRatio {
+		finalGroupRatio = userGroupRatio
+	} else {
+		finalGroupRatio = groupRatio
+	}
+
+	// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
+	actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
+
+	// 计算差额(正数=需要补扣,负数=需要退还)
+	preConsumedQuota := task.Quota
+	quotaDelta := actualQuota - preConsumedQuota
+
+	if quotaDelta == 0 {
+		logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
+			task.TaskID, logger.LogQuota(actualQuota), totalTokens))
+		return
+	}
+
+	logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)",
+		task.TaskID,
+		logger.LogQuota(quotaDelta),
+		logger.LogQuota(actualQuota),
+		logger.LogQuota(preConsumedQuota),
+		totalTokens,
+	))
+
+	// 调整资金来源
+	if err := taskAdjustFunding(task, quotaDelta); err != nil {
+		logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
+		return
+	}
+
+	// 调整令牌额度
+	taskAdjustTokenQuota(ctx, task, quotaDelta)
+
+	// 更新统计(仅补扣时更新,退还不影响已用统计)
+	if quotaDelta > 0 {
+		model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
+		model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
+	}
+	task.Quota = actualQuota
+
+	var action string
+	if quotaDelta > 0 {
+		action = "补扣费"
+	} else {
+		action = "退还"
+	}
+	logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s",
+		action, modelRatio, finalGroupRatio, totalTokens,
+		logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota))
+	model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+}

+ 446 - 0
service/task_polling.go

@@ -0,0 +1,446 @@
+package service
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/model"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+
+	"github.com/samber/lo"
+)
+
+// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
+type TaskPollingAdaptor interface {
+	Init(info *relaycommon.RelayInfo)
+	FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
+	ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
+}
+
+// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
+// 打破 service -> relay -> relay/channel -> service 的循环依赖。
+var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
+
+// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
+func TaskPollingLoop() {
+	for {
+		time.Sleep(time.Duration(15) * time.Second)
+		common.SysLog("任务进度轮询开始")
+		ctx := context.TODO()
+		allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
+		platformTask := make(map[constant.TaskPlatform][]*model.Task)
+		for _, t := range allTasks {
+			platformTask[t.Platform] = append(platformTask[t.Platform], t)
+		}
+		for platform, tasks := range platformTask {
+			if len(tasks) == 0 {
+				continue
+			}
+			taskChannelM := make(map[int][]string)
+			taskM := make(map[string]*model.Task)
+			nullTaskIds := make([]int64, 0)
+			for _, task := range tasks {
+				upstreamID := task.GetUpstreamTaskID()
+				if upstreamID == "" {
+					// 统计失败的未完成任务
+					nullTaskIds = append(nullTaskIds, task.ID)
+					continue
+				}
+				taskM[upstreamID] = task
+				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
+			}
+			if len(nullTaskIds) > 0 {
+				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
+					"status":   "FAILURE",
+					"progress": "100%",
+				})
+				if err != nil {
+					logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+				} else {
+					logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+				}
+			}
+			if len(taskChannelM) == 0 {
+				continue
+			}
+
+			DispatchPlatformUpdate(platform, taskChannelM, taskM)
+		}
+		common.SysLog("任务进度轮询完成")
+	}
+}
+
+// DispatchPlatformUpdate 按平台分发轮询更新
+func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
+	switch platform {
+	case constant.TaskPlatformMidjourney:
+		// MJ 轮询由其自身处理,这里预留入口
+	case constant.TaskPlatformSuno:
+		_ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
+	default:
+		if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
+			common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
+		}
+	}
+}
+
+// UpdateSunoTasks 按渠道更新所有 Suno 任务
+func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		err := updateSunoTasks(ctx, channelId, taskIds, taskM)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	ch, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+		// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
+		var failedIDs []int64
+		for _, upstreamID := range taskIds {
+			if t, ok := taskM[upstreamID]; ok {
+				failedIDs = append(failedIDs, t.ID)
+			}
+		}
+		err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
+			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if err != nil {
+			common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
+		}
+		return err
+	}
+	adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
+	if adaptor == nil {
+		return errors.New("adaptor not found")
+	}
+	proxy := ch.GetSetting().Proxy
+	resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
+		"ids": taskIds,
+	}, proxy)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
+		return err
+	}
+	if resp.StatusCode != http.StatusOK {
+		logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+		return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
+		return err
+	}
+	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
+	err = common.Unmarshal(responseBody, &responseItems)
+	if err != nil {
+		logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+		return err
+	}
+	if !responseItems.IsSuccess() {
+		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
+		return err
+	}
+
+	for _, responseItem := range responseItems.Data {
+		task := taskM[responseItem.TaskID]
+		if !taskNeedsUpdate(task, responseItem) {
+			continue
+		}
+
+		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
+		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
+		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
+		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
+		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
+		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
+			logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+			task.Progress = "100%"
+			RefundTaskQuota(ctx, task, task.FailReason)
+		}
+		if responseItem.Status == model.TaskStatusSuccess {
+			task.Progress = "100%"
+		}
+		task.Data = responseItem.Data
+
+		err = task.Update()
+		if err != nil {
+			common.SysLog("UpdateSunoTask task error: " + err.Error())
+		}
+	}
+	return nil
+}
+
+// taskNeedsUpdate 检查 Suno 任务是否需要更新
+func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if string(oldTask.Status) != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+
+	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
+		return true
+	}
+
+	oldData, _ := common.Marshal(oldTask.Data)
+	newData, _ := common.Marshal(newTask.Data)
+
+	sort.Slice(oldData, func(i, j int) bool {
+		return oldData[i] < oldData[j]
+	})
+	sort.Slice(newData, func(i, j int) bool {
+		return newData[i] < newData[j]
+	})
+
+	if string(oldData) != string(newData) {
+		return true
+	}
+	return false
+}
+
+// UpdateVideoTasks 按渠道更新所有视频任务
+func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	cacheGetChannel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
+		var failedIDs []int64
+		for _, upstreamID := range taskIds {
+			if t, ok := taskM[upstreamID]; ok {
+				failedIDs = append(failedIDs, t.ID)
+			}
+		}
+		errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
+			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if errUpdate != nil {
+			common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+		}
+		return fmt.Errorf("CacheGetChannel failed: %w", err)
+	}
+	adaptor := GetTaskAdaptorFunc(platform)
+	if adaptor == nil {
+		return fmt.Errorf("video adaptor not found")
+	}
+	info := &relaycommon.RelayInfo{}
+	info.ChannelMeta = &relaycommon.ChannelMeta{
+		ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
+	}
+	info.ApiKey = cacheGetChannel.Key
+	adaptor.Init(info)
+	for _, taskId := range taskIds {
+		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
+	baseURL := constant.ChannelBaseURLs[ch.Type]
+	if ch.GetBaseURL() != "" {
+		baseURL = ch.GetBaseURL()
+	}
+	proxy := ch.GetSetting().Proxy
+
+	task := taskM[taskId]
+	if task == nil {
+		logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+		return fmt.Errorf("task %s not found", taskId)
+	}
+	key := ch.Key
+
+	privateData := task.PrivateData
+	if privateData.Key != "" {
+		key = privateData.Key
+	}
+	resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
+		"task_id": task.GetUpstreamTaskID(),
+		"action":  task.Action,
+	}, proxy)
+	if err != nil {
+		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
+	}
+
+	logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
+
+	taskResult := &relaycommon.TaskInfo{}
+	// try parse as New API response format
+	var responseItems dto.TaskResponse[model.Task]
+	if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
+		logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
+		t := responseItems.Data
+		taskResult.TaskID = t.TaskID
+		taskResult.Status = string(t.Status)
+		taskResult.Url = t.GetResultURL()
+		taskResult.Progress = t.Progress
+		taskResult.Reason = t.FailReason
+		task.Data = t.Data
+	} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
+		return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
+	} else {
+		task.Data = redactVideoResponseBody(responseBody)
+	}
+
+	logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
+
+	now := time.Now().Unix()
+	if taskResult.Status == "" {
+		taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
+	}
+
+	// 记录原本的状态,防止重复退款
+	shouldRefund := false
+	quota := task.Quota
+	preStatus := task.Status
+
+	task.Status = model.TaskStatus(taskResult.Status)
+	switch taskResult.Status {
+	case model.TaskStatusSubmitted:
+		task.Progress = taskcommon.ProgressSubmitted
+	case model.TaskStatusQueued:
+		task.Progress = taskcommon.ProgressQueued
+	case model.TaskStatusInProgress:
+		task.Progress = taskcommon.ProgressInProgress
+		if task.StartTime == 0 {
+			task.StartTime = now
+		}
+	case model.TaskStatusSuccess:
+		task.Progress = taskcommon.ProgressComplete
+		if task.FinishTime == 0 {
+			task.FinishTime = now
+		}
+		if strings.HasPrefix(taskResult.Url, "data:") {
+			// data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
+		} else if taskResult.Url != "" {
+			// Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
+			task.PrivateData.ResultURL = taskResult.Url
+		} else {
+			// No URL from adaptor — construct proxy URL using public task ID
+			task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
+		}
+
+		// 如果返回了 total_tokens,根据模型倍率重新计费
+		if taskResult.TotalTokens > 0 {
+			RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
+		}
+	case model.TaskStatusFailure:
+		logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
+		task.Status = model.TaskStatusFailure
+		task.Progress = taskcommon.ProgressComplete
+		if task.FinishTime == 0 {
+			task.FinishTime = now
+		}
+		task.FailReason = taskResult.Reason
+		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+		taskResult.Progress = taskcommon.ProgressComplete
+		if quota != 0 {
+			if preStatus != model.TaskStatusFailure {
+				shouldRefund = true
+			} else {
+				logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
+			}
+		}
+	default:
+		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
+	}
+	if taskResult.Progress != "" {
+		task.Progress = taskResult.Progress
+	}
+	if err := task.Update(); err != nil {
+		common.SysLog("UpdateVideoTask task error: " + err.Error())
+		shouldRefund = false
+	}
+
+	if shouldRefund {
+		RefundTaskQuota(ctx, task, task.FailReason)
+	}
+
+	return nil
+}
+
+func redactVideoResponseBody(body []byte) []byte {
+	var m map[string]any
+	if err := common.Unmarshal(body, &m); err != nil {
+		return body
+	}
+	resp, _ := m["response"].(map[string]any)
+	if resp != nil {
+		delete(resp, "bytesBase64Encoded")
+		if v, ok := resp["video"].(string); ok {
+			resp["video"] = truncateBase64(v)
+		}
+		if vs, ok := resp["videos"].([]any); ok {
+			for i := range vs {
+				if vm, ok := vs[i].(map[string]any); ok {
+					delete(vm, "bytesBase64Encoded")
+				}
+			}
+		}
+	}
+	b, err := common.Marshal(m)
+	if err != nil {
+		return body
+	}
+	return b
+}
+
+func truncateBase64(s string) string {
+	const maxKeep = 256
+	if len(s) <= maxKeep {
+		return s
+	}
+	return s[:maxKeep] + "..."
+}

+ 2 - 7
types/price_data.go

@@ -22,7 +22,8 @@ type PriceData struct {
 	AudioCompletionRatio float64
 	OtherRatios          map[string]float64
 	UsePrice             bool
-	QuotaToPreConsume    int // 预消耗额度
+	Quota                int // 按次计费的最终额度(MJ / Task)
+	QuotaToPreConsume    int // 按量计费的预消耗额度
 	GroupRatioInfo       GroupRatioInfo
 }
 
@@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) {
 	p.OtherRatios[key] = ratio
 }
 
-type PerCallPriceData struct {
-	ModelPrice     float64
-	Quota          int
-	GroupRatioInfo GroupRatioInfo
-}
-
 func (p *PriceData) ToSetting() string {
 	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
 }

+ 5 - 4
web/src/components/table/task-logs/TaskLogsColumnDefs.jsx

@@ -396,7 +396,7 @@ export const getTaskLogsColumns = ({
       dataIndex: 'fail_reason',
       fixed: 'right',
       render: (text, record, index) => {
-        // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
+        // 视频预览:优先使用 result_url,兼容旧数据 fail_reason 中的 URL
         const isVideoTask =
           record.action === TASK_ACTION_GENERATE ||
           record.action === TASK_ACTION_TEXT_GENERATE ||
@@ -404,14 +404,15 @@ export const getTaskLogsColumns = ({
           record.action === TASK_ACTION_REFERENCE_GENERATE ||
           record.action === TASK_ACTION_REMIX_GENERATE;
         const isSuccess = record.status === 'SUCCESS';
-        const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
-        if (isSuccess && isVideoTask && isUrl) {
+        const resultUrl = record.result_url;
+        const hasResultUrl = typeof resultUrl === 'string' && /^https?:\/\//.test(resultUrl);
+        if (isSuccess && isVideoTask && hasResultUrl) {
           return (
             <a
               href='#'
               onClick={(e) => {
                 e.preventDefault();
-                openVideoModal(text);
+                openVideoModal(resultUrl);
               }}
             >
               {t('点击预览视频')}

+ 0 - 2
web/src/components/table/task-logs/modals/ContentModal.jsx

@@ -144,8 +144,6 @@ const ContentModal = ({
             maxHeight: '100%',
             objectFit: 'contain',
           }}
-          autoPlay
-          crossOrigin='anonymous'
           onError={handleVideoError}
           onLoadedData={handleVideoLoaded}
           onLoadStart={() => setIsLoading(true)}