瀏覽代碼

Merge pull request #2985 from QuantumNous/refactor/async-task-merge

refactor: async task
Calcium-Ion 1 周之前
父節點
當前提交
902661df3f
共有 54 個文件被更改,包括 3056 次插入1421 次删除
  1. 9 1
      common/gin.go
  2. 8 9
      controller/midjourney.go
  3. 112 37
      controller/relay.go
  4. 33 215
      controller/task.go
  5. 0 313
      controller/task_video.go
  6. 30 81
      controller/video_proxy.go
  7. 4 4
      controller/video_proxy_gemini.go
  8. 0 32
      dto/suno.go
  9. 47 0
      dto/task.go
  10. 1 2
      logger/logger.go
  11. 10 0
      main.go
  12. 18 0
      middleware/auth.go
  13. 43 0
      model/log.go
  14. 13 0
      model/midjourney.go
  15. 105 60
      model/task.go
  16. 217 0
      model/task_cas_test.go
  17. 3 3
      model/token.go
  18. 28 2
      relay/channel/adapter.go
  19. 44 24
      relay/channel/task/ali/adaptor.go
  20. 15 16
      relay/channel/task/doubao/adaptor.go
  21. 17 33
      relay/channel/task/gemini/adaptor.go
  22. 13 12
      relay/channel/task/hailuo/adaptor.go
  23. 14 20
      relay/channel/task/jimeng/adaptor.go
  24. 17 36
      relay/channel/task/kling/adaptor.go
  25. 116 15
      relay/channel/task/sora/adaptor.go
  26. 17 19
      relay/channel/task/suno/adaptor.go
  27. 95 0
      relay/channel/task/taskcommon/helpers.go
  28. 47 46
      relay/channel/task/vertex/adaptor.go
  29. 15 35
      relay/channel/task/vidu/adaptor.go
  30. 17 3
      relay/common/relay_info.go
  31. 2 8
      relay/common/relay_utils.go
  32. 13 2
      relay/helper/price.go
  33. 324 280
      relay/relay_task.go
  34. 2 2
      router/relay-router.go
  35. 11 5
      router/video-router.go
  36. 5 0
      service/billing_session.go
  37. 13 0
      service/error.go
  38. 1 1
      service/log_info_generate.go
  39. 285 0
      service/task_billing.go
  40. 712 0
      service/task_billing_test.go
  41. 486 0
      service/task_polling.go
  42. 2 7
      types/price_data.go
  43. 18 27
      web/src/components/table/task-logs/TaskLogsColumnDefs.jsx
  44. 0 2
      web/src/components/table/task-logs/index.jsx
  45. 0 2
      web/src/components/table/task-logs/modals/ContentModal.jsx
  46. 23 7
      web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx
  47. 1 0
      web/src/components/table/usage-logs/UsageLogsFilters.jsx
  48. 21 3
      web/src/hooks/usage-logs/useUsageLogsData.jsx
  49. 5 43
      web/src/i18n/locales/en.json
  50. 5 2
      web/src/i18n/locales/fr.json
  51. 5 2
      web/src/i18n/locales/ja.json
  52. 5 2
      web/src/i18n/locales/ru.json
  53. 4 2
      web/src/i18n/locales/vi.json
  54. 5 6
      web/src/i18n/locales/zh-CN.json

+ 9 - 1
common/gin.go

@@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
 		return nil, err
 	}
 
-	contentType := c.Request.Header.Get("Content-Type")
+	// Use the original Content-Type saved on first call to avoid boundary
+	// mismatch when callers overwrite c.Request.Header after multipart rebuild.
+	var contentType string
+	if saved, ok := c.Get("_original_multipart_ct"); ok {
+		contentType = saved.(string)
+	} else {
+		contentType = c.Request.Header.Get("Content-Type")
+		c.Set("_original_multipart_ct", contentType)
+	}
 	boundary, err := parseBoundary(contentType)
 	if err != nil {
 		return nil, err

+ 8 - 9
controller/midjourney.go

@@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() {
 				if !checkMjTaskNeedUpdate(task, responseItem) {
 					continue
 				}
+				preStatus := task.Status
 				task.Code = 1
 				task.Progress = responseItem.Progress
 				task.PromptEn = responseItem.PromptEn
@@ -172,18 +173,16 @@ func UpdateMidjourneyTaskBulk() {
 						shouldReturnQuota = true
 					}
 				}
-				err = task.Update()
+				won, err := task.UpdateWithStatus(preStatus)
 				if err != nil {
 					logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
-				} else {
-					if shouldReturnQuota {
-						err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
-						if err != nil {
-							logger.LogError(ctx, "fail to increase user quota: "+err.Error())
-						}
-						logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
-						model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+				} else if won && shouldReturnQuota {
+					err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
+					if err != nil {
+						logger.LogError(ctx, "fail to increase user quota: "+err.Error())
 					}
+					logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
+					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 				}
 			}
 		}

+ 112 - 37
controller/relay.go

@@ -450,72 +450,147 @@ func RelayNotFound(c *gin.Context) {
 	})
 }
 
+func RelayTaskFetch(c *gin.Context) {
+	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, &dto.TaskError{
+			Code:       "gen_relay_info_failed",
+			Message:    err.Error(),
+			StatusCode: http.StatusInternalServerError,
+		})
+		return
+	}
+	if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil {
+		respondTaskError(c, taskErr)
+	}
+}
+
 func RelayTask(c *gin.Context) {
-	retryTimes := common.RetryTimes
-	channelId := c.GetInt("channel_id")
-	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
 	if err != nil {
+		c.JSON(http.StatusInternalServerError, &dto.TaskError{
+			Code:       "gen_relay_info_failed",
+			Message:    err.Error(),
+			StatusCode: http.StatusInternalServerError,
+		})
 		return
 	}
-	taskErr := taskRelayHandler(c, relayInfo)
-	if taskErr == nil {
-		retryTimes = 0
+
+	if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil {
+		respondTaskError(c, taskErr)
+		return
 	}
+
+	var result *relay.TaskSubmitResult
+	var taskErr *dto.TaskError
+	defer func() {
+		if taskErr != nil && relayInfo.Billing != nil {
+			relayInfo.Billing.Refund(c)
+		}
+	}()
+
 	retryParam := &service.RetryParam{
 		Ctx:        c,
 		TokenGroup: relayInfo.TokenGroup,
 		ModelName:  relayInfo.OriginModelName,
 		Retry:      common.GetPointer(0),
 	}
-	for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
-		channel, newAPIError := getChannel(c, relayInfo, retryParam)
-		if newAPIError != nil {
-			logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
-			taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
-			break
+
+	for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
+		var channel *model.Channel
+
+		if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil {
+			channel = lockedCh
+			if retryParam.GetRetry() > 0 {
+				if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil {
+					taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError)
+					break
+				}
+			}
+		} else {
+			var channelErr *types.NewAPIError
+			channel, channelErr = getChannel(c, relayInfo, retryParam)
+			if channelErr != nil {
+				logger.LogError(c, channelErr.Error())
+				taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError)
+				break
+			}
 		}
-		channelId = channel.Id
-		useChannel := c.GetStringSlice("use_channel")
-		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
-		c.Set("use_channel", useChannel)
-		logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
-		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
-
-		bodyStorage, err := common.GetBodyStorage(c)
-		if err != nil {
-			if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
-				taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
+
+		addUsedChannel(c, channel.Id)
+		bodyStorage, bodyErr := common.GetBodyStorage(c)
+		if bodyErr != nil {
+			if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
+				taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
 			} else {
-				taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest)
+				taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest)
 			}
 			break
 		}
 		c.Request.Body = io.NopCloser(bodyStorage)
-		taskErr = taskRelayHandler(c, relayInfo)
+
+		result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
+		if taskErr == nil {
+			break
+		}
+
+		if !taskErr.LocalError {
+			processChannelError(c,
+				*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
+					common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
+				types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
+		}
+
+		if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) {
+			break
+		}
 	}
+
 	useChannel := c.GetStringSlice("use_channel")
 	if len(useChannel) > 1 {
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
 		logger.LogInfo(c, retryLogStr)
 	}
-	if taskErr != nil {
-		if taskErr.StatusCode == http.StatusTooManyRequests {
-			taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
+
+	// ── 成功:结算 + 日志 + 插入任务 ──
+	if taskErr == nil {
+		if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
+			common.SysError("settle task billing error: " + settleErr.Error())
 		}
-		c.JSON(taskErr.StatusCode, taskErr)
+		service.LogTaskConsumption(c, relayInfo)
+
+		task := model.InitTask(result.Platform, relayInfo)
+		task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
+		task.PrivateData.BillingSource = relayInfo.BillingSource
+		task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
+		task.PrivateData.TokenId = relayInfo.TokenId
+		task.PrivateData.BillingContext = &model.TaskBillingContext{
+			ModelPrice:      relayInfo.PriceData.ModelPrice,
+			GroupRatio:      relayInfo.PriceData.GroupRatioInfo.GroupRatio,
+			ModelRatio:      relayInfo.PriceData.ModelRatio,
+			OtherRatios:     relayInfo.PriceData.OtherRatios,
+			OriginModelName: relayInfo.OriginModelName,
+			PerCallBilling:  common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
+		}
+		task.Quota = result.Quota
+		task.Data = result.TaskData
+		task.Action = relayInfo.Action
+		if insertErr := task.Insert(); insertErr != nil {
+			common.SysError("insert task error: " + insertErr.Error())
+		}
+	}
+
+	if taskErr != nil {
+		respondTaskError(c, taskErr)
 	}
 }
 
-func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
-	var err *dto.TaskError
-	switch relayInfo.RelayMode {
-	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
-		err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
-	default:
-		err = relay.RelayTaskSubmit(c, relayInfo)
+// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写)
+func respondTaskError(c *gin.Context, taskErr *dto.TaskError) {
+	if taskErr.StatusCode == http.StatusTooManyRequests {
+		taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
 	}
-	return err
+	c.JSON(taskErr.StatusCode, taskErr)
 }
 
 func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {

+ 33 - 215
controller/task.go

@@ -1,231 +1,22 @@
 package controller
 
 import (
-	"context"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"io"
-	"net/http"
-	"sort"
 	"strconv"
-	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
-	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay"
+	"github.com/QuantumNous/new-api/service"
+	"github.com/QuantumNous/new-api/types"
 
 	"github.com/gin-gonic/gin"
-	"github.com/samber/lo"
 )
 
+// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层
 func UpdateTaskBulk() {
-	//revocer
-	//imageModel := "midjourney"
-	for {
-		time.Sleep(time.Duration(15) * time.Second)
-		common.SysLog("任务进度轮询开始")
-		ctx := context.TODO()
-		allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
-		platformTask := make(map[constant.TaskPlatform][]*model.Task)
-		for _, t := range allTasks {
-			platformTask[t.Platform] = append(platformTask[t.Platform], t)
-		}
-		for platform, tasks := range platformTask {
-			if len(tasks) == 0 {
-				continue
-			}
-			taskChannelM := make(map[int][]string)
-			taskM := make(map[string]*model.Task)
-			nullTaskIds := make([]int64, 0)
-			for _, task := range tasks {
-				if task.TaskID == "" {
-					// 统计失败的未完成任务
-					nullTaskIds = append(nullTaskIds, task.ID)
-					continue
-				}
-				taskM[task.TaskID] = task
-				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
-			}
-			if len(nullTaskIds) > 0 {
-				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
-					"status":   "FAILURE",
-					"progress": "100%",
-				})
-				if err != nil {
-					logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
-				} else {
-					logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
-				}
-			}
-			if len(taskChannelM) == 0 {
-				continue
-			}
-
-			UpdateTaskByPlatform(platform, taskChannelM, taskM)
-		}
-		common.SysLog("任务进度轮询完成")
-	}
-}
-
-func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
-	switch platform {
-	case constant.TaskPlatformMidjourney:
-		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
-	case constant.TaskPlatformSuno:
-		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
-	default:
-		if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
-			common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
-		}
-	}
-}
-
-func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
-	for channelId, taskIds := range taskChannelM {
-		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
-		if err != nil {
-			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
-	logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
-	if len(taskIds) == 0 {
-		return nil
-	}
-	channel, err := model.CacheGetChannel(channelId)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
-		err = model.TaskBulkUpdate(taskIds, map[string]any{
-			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
-			"status":      "FAILURE",
-			"progress":    "100%",
-		})
-		if err != nil {
-			common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
-		}
-		return err
-	}
-	adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
-	if adaptor == nil {
-		return errors.New("adaptor not found")
-	}
-	proxy := channel.GetSetting().Proxy
-	resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
-		"ids": taskIds,
-	}, proxy)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
-		return err
-	}
-	if resp.StatusCode != http.StatusOK {
-		logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
-		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
-	}
-	defer resp.Body.Close()
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
-		return err
-	}
-	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
-	err = json.Unmarshal(responseBody, &responseItems)
-	if err != nil {
-		logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
-		return err
-	}
-	if !responseItems.IsSuccess() {
-		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
-		return err
-	}
-
-	for _, responseItem := range responseItems.Data {
-		task := taskM[responseItem.TaskID]
-		if !checkTaskNeedUpdate(task, responseItem) {
-			continue
-		}
-
-		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
-		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
-		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
-		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
-		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
-		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
-			logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
-			task.Progress = "100%"
-			//err = model.CacheUpdateUserQuota(task.UserId) ?
-			if err != nil {
-				logger.LogError(ctx, "error update user quota cache: "+err.Error())
-			} else {
-				quota := task.Quota
-				if quota != 0 {
-					err = model.IncreaseUserQuota(task.UserId, quota, false)
-					if err != nil {
-						logger.LogError(ctx, "fail to increase user quota: "+err.Error())
-					}
-					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
-					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-				}
-			}
-		}
-		if responseItem.Status == model.TaskStatusSuccess {
-			task.Progress = "100%"
-		}
-		task.Data = responseItem.Data
-
-		err = task.Update()
-		if err != nil {
-			common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
-		}
-	}
-	return nil
-}
-
-func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
-
-	if oldTask.SubmitTime != newTask.SubmitTime {
-		return true
-	}
-	if oldTask.StartTime != newTask.StartTime {
-		return true
-	}
-	if oldTask.FinishTime != newTask.FinishTime {
-		return true
-	}
-	if string(oldTask.Status) != newTask.Status {
-		return true
-	}
-	if oldTask.FailReason != newTask.FailReason {
-		return true
-	}
-	if oldTask.FinishTime != newTask.FinishTime {
-		return true
-	}
-
-	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
-		return true
-	}
-
-	oldData, _ := json.Marshal(oldTask.Data)
-	newData, _ := json.Marshal(newTask.Data)
-
-	sort.Slice(oldData, func(i, j int) bool {
-		return oldData[i] < oldData[j]
-	})
-	sort.Slice(newData, func(i, j int) bool {
-		return newData[i] < newData[j]
-	})
-
-	if string(oldData) != string(newData) {
-		return true
-	}
-	return false
+	service.TaskPollingLoop()
 }
 
 func GetAllTask(c *gin.Context) {
@@ -247,7 +38,7 @@ func GetAllTask(c *gin.Context) {
 	items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
 	total := model.TaskCountAllTasks(queryParams)
 	pageInfo.SetTotal(int(total))
-	pageInfo.SetItems(items)
+	pageInfo.SetItems(tasksToDto(items, true))
 	common.ApiSuccess(c, pageInfo)
 }
 
@@ -271,6 +62,33 @@ func GetUserTask(c *gin.Context) {
 	items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
 	total := model.TaskCountAllUserTask(userId, queryParams)
 	pageInfo.SetTotal(int(total))
-	pageInfo.SetItems(items)
+	pageInfo.SetItems(tasksToDto(items, false))
 	common.ApiSuccess(c, pageInfo)
 }
+
+func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto {
+	var userIdMap map[int]*model.UserBase
+	if fillUser {
+		userIdMap = make(map[int]*model.UserBase)
+		userIds := types.NewSet[int]()
+		for _, task := range tasks {
+			userIds.Add(task.UserId)
+		}
+		for _, userId := range userIds.Items() {
+			cacheUser, err := model.GetUserCache(userId)
+			if err == nil {
+				userIdMap[userId] = cacheUser
+			}
+		}
+	}
+	result := make([]*dto.TaskDto, len(tasks))
+	for i, task := range tasks {
+		if fillUser {
+			if user, ok := userIdMap[task.UserId]; ok {
+				task.Username = user.Username
+			}
+		}
+		result[i] = relay.TaskModel2Dto(task)
+	}
+	return result
+}

+ 0 - 313
controller/task_video.go

@@ -1,313 +0,0 @@
-package controller
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"io"
-	"time"
-
-	"github.com/QuantumNous/new-api/common"
-	"github.com/QuantumNous/new-api/constant"
-	"github.com/QuantumNous/new-api/dto"
-	"github.com/QuantumNous/new-api/logger"
-	"github.com/QuantumNous/new-api/model"
-	"github.com/QuantumNous/new-api/relay"
-	"github.com/QuantumNous/new-api/relay/channel"
-	relaycommon "github.com/QuantumNous/new-api/relay/common"
-	"github.com/QuantumNous/new-api/setting/ratio_setting"
-)
-
-func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
-	for channelId, taskIds := range taskChannelM {
-		if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
-			logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
-	logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
-	if len(taskIds) == 0 {
-		return nil
-	}
-	cacheGetChannel, err := model.CacheGetChannel(channelId)
-	if err != nil {
-		errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
-			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
-			"status":      "FAILURE",
-			"progress":    "100%",
-		})
-		if errUpdate != nil {
-			common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
-		}
-		return fmt.Errorf("CacheGetChannel failed: %w", err)
-	}
-	adaptor := relay.GetTaskAdaptor(platform)
-	if adaptor == nil {
-		return fmt.Errorf("video adaptor not found")
-	}
-	info := &relaycommon.RelayInfo{}
-	info.ChannelMeta = &relaycommon.ChannelMeta{
-		ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
-	}
-	info.ApiKey = cacheGetChannel.Key
-	adaptor.Init(info)
-	for _, taskId := range taskIds {
-		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
-			logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
-	baseURL := constant.ChannelBaseURLs[channel.Type]
-	if channel.GetBaseURL() != "" {
-		baseURL = channel.GetBaseURL()
-	}
-	proxy := channel.GetSetting().Proxy
-
-	task := taskM[taskId]
-	if task == nil {
-		logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
-		return fmt.Errorf("task %s not found", taskId)
-	}
-	key := channel.Key
-
-	privateData := task.PrivateData
-	if privateData.Key != "" {
-		key = privateData.Key
-	}
-	resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
-		"task_id": taskId,
-		"action":  task.Action,
-	}, proxy)
-	if err != nil {
-		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
-	}
-	//if resp.StatusCode != http.StatusOK {
-	//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
-	//}
-	defer resp.Body.Close()
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
-	}
-
-	logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
-
-	taskResult := &relaycommon.TaskInfo{}
-	// try parse as New API response format
-	var responseItems dto.TaskResponse[model.Task]
-	if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
-		logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
-		t := responseItems.Data
-		taskResult.TaskID = t.TaskID
-		taskResult.Status = string(t.Status)
-		taskResult.Url = t.FailReason
-		taskResult.Progress = t.Progress
-		taskResult.Reason = t.FailReason
-		task.Data = t.Data
-	} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
-		return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
-	} else {
-		task.Data = redactVideoResponseBody(responseBody)
-	}
-
-	logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
-
-	now := time.Now().Unix()
-	if taskResult.Status == "" {
-		//return fmt.Errorf("task %s status is empty", taskId)
-		taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
-	}
-
-	// 记录原本的状态,防止重复退款
-	shouldRefund := false
-	quota := task.Quota
-	preStatus := task.Status
-
-	task.Status = model.TaskStatus(taskResult.Status)
-	switch taskResult.Status {
-	case model.TaskStatusSubmitted:
-		task.Progress = "10%"
-	case model.TaskStatusQueued:
-		task.Progress = "20%"
-	case model.TaskStatusInProgress:
-		task.Progress = "30%"
-		if task.StartTime == 0 {
-			task.StartTime = now
-		}
-	case model.TaskStatusSuccess:
-		task.Progress = "100%"
-		if task.FinishTime == 0 {
-			task.FinishTime = now
-		}
-		if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
-			task.FailReason = taskResult.Url
-		}
-
-		// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
-		if taskResult.TotalTokens > 0 {
-			// 获取模型名称
-			var taskData map[string]interface{}
-			if err := json.Unmarshal(task.Data, &taskData); err == nil {
-				if modelName, ok := taskData["model"].(string); ok && modelName != "" {
-					// 获取模型价格和倍率
-					modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
-					// 只有配置了倍率(非固定价格)时才按 token 重新计费
-					if hasRatioSetting && modelRatio > 0 {
-						// 获取用户和组的倍率信息
-						group := task.Group
-						if group == "" {
-							user, err := model.GetUserById(task.UserId, false)
-							if err == nil {
-								group = user.Group
-							}
-						}
-						if group != "" {
-							groupRatio := ratio_setting.GetGroupRatio(group)
-							userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
-
-							var finalGroupRatio float64
-							if hasUserGroupRatio {
-								finalGroupRatio = userGroupRatio
-							} else {
-								finalGroupRatio = groupRatio
-							}
-
-							// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
-							actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
-
-							// 计算差额
-							preConsumedQuota := task.Quota
-							quotaDelta := actualQuota - preConsumedQuota
-
-							if quotaDelta > 0 {
-								// 需要补扣费
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
-									task.TaskID,
-									logger.LogQuota(quotaDelta),
-									logger.LogQuota(actualQuota),
-									logger.LogQuota(preConsumedQuota),
-									taskResult.TotalTokens,
-								))
-								if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
-									logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
-								} else {
-									model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
-									model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
-									task.Quota = actualQuota // 更新任务记录的实际扣费额度
-
-									// 记录消费日志
-									logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
-										modelRatio, finalGroupRatio, taskResult.TotalTokens,
-										logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
-									model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-								}
-							} else if quotaDelta < 0 {
-								// 需要退还多扣的费用
-								refundQuota := -quotaDelta
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
-									task.TaskID,
-									logger.LogQuota(refundQuota),
-									logger.LogQuota(actualQuota),
-									logger.LogQuota(preConsumedQuota),
-									taskResult.TotalTokens,
-								))
-								if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
-									logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
-								} else {
-									task.Quota = actualQuota // 更新任务记录的实际扣费额度
-
-									// 记录退款日志
-									logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
-										modelRatio, finalGroupRatio, taskResult.TotalTokens,
-										logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
-									model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-								}
-							} else {
-								// quotaDelta == 0, 预扣费刚好准确
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
-									task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
-							}
-						}
-					}
-				}
-			}
-		}
-	case model.TaskStatusFailure:
-		logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
-		task.Status = model.TaskStatusFailure
-		task.Progress = "100%"
-		if task.FinishTime == 0 {
-			task.FinishTime = now
-		}
-		task.FailReason = taskResult.Reason
-		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
-		taskResult.Progress = "100%"
-		if quota != 0 {
-			if preStatus != model.TaskStatusFailure {
-				shouldRefund = true
-			} else {
-				logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
-			}
-		}
-	default:
-		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
-	}
-	if taskResult.Progress != "" {
-		task.Progress = taskResult.Progress
-	}
-	if err := task.Update(); err != nil {
-		common.SysLog("UpdateVideoTask task error: " + err.Error())
-		shouldRefund = false
-	}
-
-	if shouldRefund {
-		// 任务失败且之前状态不是失败才退还额度,防止重复退还
-		if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
-			logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
-		}
-		logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
-		model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-	}
-
-	return nil
-}
-
-func redactVideoResponseBody(body []byte) []byte {
-	var m map[string]any
-	if err := json.Unmarshal(body, &m); err != nil {
-		return body
-	}
-	resp, _ := m["response"].(map[string]any)
-	if resp != nil {
-		delete(resp, "bytesBase64Encoded")
-		if v, ok := resp["video"].(string); ok {
-			resp["video"] = truncateBase64(v)
-		}
-		if vs, ok := resp["videos"].([]any); ok {
-			for i := range vs {
-				if vm, ok := vs[i].(map[string]any); ok {
-					delete(vm, "bytesBase64Encoded")
-				}
-			}
-		}
-	}
-	b, err := json.Marshal(m)
-	if err != nil {
-		return body
-	}
-	return b
-}
-
-func truncateBase64(s string) string {
-	const maxKeep = 256
-	if len(s) <= maxKeep {
-		return s
-	}
-	return s[:maxKeep] + "..."
-}

+ 30 - 81
controller/video_proxy.go

@@ -16,59 +16,44 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+// videoProxyError returns a standardized OpenAI-style error response.
+func videoProxyError(c *gin.Context, status int, errType, message string) {
+	c.JSON(status, gin.H{
+		"error": gin.H{
+			"message": message,
+			"type":    errType,
+		},
+	})
+}
+
 func VideoProxy(c *gin.Context) {
 	taskID := c.Param("task_id")
 	if taskID == "" {
-		c.JSON(http.StatusBadRequest, gin.H{
-			"error": gin.H{
-				"message": "task_id is required",
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required")
 		return
 	}
 
 	task, exists, err := model.GetByOnlyTaskId(taskID)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to query task",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
 		return
 	}
 	if !exists || task == nil {
-		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
-		c.JSON(http.StatusNotFound, gin.H{
-			"error": gin.H{
-				"message": "Task not found",
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found")
 		return
 	}
 
 	if task.Status != model.TaskStatusSuccess {
-		c.JSON(http.StatusBadRequest, gin.H{
-			"error": gin.H{
-				"message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadRequest, "invalid_request_error",
+			fmt.Sprintf("Task is not completed yet, current status: %s", task.Status))
 		return
 	}
 
 	channel, err := model.CacheGetChannel(task.ChannelId)
 	if err != nil {
-		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to retrieve channel information",
-				"type":    "server_error",
-			},
-		})
+		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error()))
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information")
 		return
 	}
 	baseURL := channel.GetBaseURL()
@@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) {
 	client, err := service.GetHttpClientWithProxy(proxy)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy client",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client")
 		return
 	}
 
@@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) {
 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy request",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
 		return
 	}
 
@@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) {
 		apiKey := task.PrivateData.Key
 		if apiKey == "" {
 			logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
-			c.JSON(http.StatusInternalServerError, gin.H{
-				"error": gin.H{
-					"message": "API key not stored for task",
-					"type":    "server_error",
-				},
-			})
+			videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task")
 			return
 		}
-
 		videoURL, err = getGeminiVideoURL(channel, task, apiKey)
 		if err != nil {
 			logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
-			c.JSON(http.StatusBadGateway, gin.H{
-				"error": gin.H{
-					"message": "Failed to resolve Gemini video URL",
-					"type":    "server_error",
-				},
-			})
+			videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL")
 			return
 		}
 		req.Header.Set("x-goog-api-key", apiKey)
 	case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
-		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
+		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
 		req.Header.Set("Authorization", "Bearer "+channel.Key)
 	default:
-		// Video URL is directly in task.FailReason
-		videoURL = task.FailReason
+		// Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data)
+		videoURL = task.GetResultURL()
 	}
 
 	req.URL, err = url.Parse(videoURL)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy request",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
 		return
 	}
 
 	resp, err := client.Do(req)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
-		c.JSON(http.StatusBadGateway, gin.H{
-			"error": gin.H{
-				"message": "Failed to fetch video content",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
 		return
 	}
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
-		c.JSON(http.StatusBadGateway, gin.H{
-			"error": gin.H{
-				"message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadGateway, "server_error",
+			fmt.Sprintf("Upstream service returned status %d", resp.StatusCode))
 		return
 	}
 
@@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) {
 		}
 	}
 
-	c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
+	c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
+	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
 	}
 }

+ 4 - 4
controller/video_proxy_gemini.go

@@ -1,12 +1,12 @@
 package controller
 
 import (
-	"encoding/json"
 	"fmt"
 	"io"
 	"strconv"
 	"strings"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay"
@@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
 
 	proxy := channel.GetSetting().Proxy
 	resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
-		"task_id": task.TaskID,
+		"task_id": task.GetUpstreamTaskID(),
 		"action":  task.Action,
 	}, proxy)
 	if err != nil {
@@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
 		return ""
 	}
 	var payload map[string]any
-	if err := json.Unmarshal(task.Data, &payload); err != nil {
+	if err := common.Unmarshal(task.Data, &payload); err != nil {
 		return ""
 	}
 	return extractGeminiVideoURLFromMap(payload)
@@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
 
 func extractGeminiVideoURLFromPayload(body []byte) string {
 	var payload map[string]any
-	if err := json.Unmarshal(body, &payload); err != nil {
+	if err := common.Unmarshal(body, &payload); err != nil {
 		return ""
 	}
 	return extractGeminiVideoURLFromMap(payload)

+ 0 - 32
dto/suno.go

@@ -4,10 +4,6 @@ import (
 	"encoding/json"
 )
 
-type TaskData interface {
-	SunoDataResponse | []SunoDataResponse | string | any
-}
-
 type SunoSubmitReq struct {
 	GptDescriptionPrompt string  `json:"gpt_description_prompt,omitempty"`
 	Prompt               string  `json:"prompt,omitempty"`
@@ -20,10 +16,6 @@ type SunoSubmitReq struct {
 	MakeInstrumental     bool    `json:"make_instrumental"`
 }
 
-type FetchReq struct {
-	IDs []string `json:"ids"`
-}
-
 type SunoDataResponse struct {
 	TaskID     string          `json:"task_id" gorm:"type:varchar(50);index"`
 	Action     string          `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
@@ -66,30 +58,6 @@ type SunoLyrics struct {
 	Text   string `json:"text"`
 }
 
-const TaskSuccessCode = "success"
-
-type TaskResponse[T TaskData] struct {
-	Code    string `json:"code"`
-	Message string `json:"message"`
-	Data    T      `json:"data"`
-}
-
-func (t *TaskResponse[T]) IsSuccess() bool {
-	return t.Code == TaskSuccessCode
-}
-
-type TaskDto struct {
-	TaskID     string          `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
-	Action     string          `json:"action"`  // 任务类型, song, lyrics, description-mode
-	Status     string          `json:"status"`  // 任务状态, submitted, queueing, processing, success, failed
-	FailReason string          `json:"fail_reason"`
-	SubmitTime int64           `json:"submit_time"`
-	StartTime  int64           `json:"start_time"`
-	FinishTime int64           `json:"finish_time"`
-	Progress   string          `json:"progress"`
-	Data       json.RawMessage `json:"data"`
-}
-
 type SunoGoAPISubmitReq struct {
 	CustomMode bool `json:"custom_mode"`
 

+ 47 - 0
dto/task.go

@@ -1,5 +1,9 @@
 package dto
 
+import (
+	"encoding/json"
+)
+
 type TaskError struct {
 	Code       string `json:"code"`
 	Message    string `json:"message"`
@@ -8,3 +12,46 @@ type TaskError struct {
 	LocalError bool   `json:"-"`
 	Error      error  `json:"-"`
 }
+
+type TaskData interface {
+	SunoDataResponse | []SunoDataResponse | string | any
+}
+
+const TaskSuccessCode = "success"
+
+type TaskResponse[T TaskData] struct {
+	Code    string `json:"code"`
+	Message string `json:"message"`
+	Data    T      `json:"data"`
+}
+
+func (t *TaskResponse[T]) IsSuccess() bool {
+	return t.Code == TaskSuccessCode
+}
+
+type TaskDto struct {
+	ID         int64           `json:"id"`
+	CreatedAt  int64           `json:"created_at"`
+	UpdatedAt  int64           `json:"updated_at"`
+	TaskID     string          `json:"task_id"`
+	Platform   string          `json:"platform"`
+	UserId     int             `json:"user_id"`
+	Group      string          `json:"group"`
+	ChannelId  int             `json:"channel_id"`
+	Quota      int             `json:"quota"`
+	Action     string          `json:"action"`
+	Status     string          `json:"status"`
+	FailReason string          `json:"fail_reason"`
+	ResultURL  string          `json:"result_url,omitempty"` // 任务结果 URL(视频地址等)
+	SubmitTime int64           `json:"submit_time"`
+	StartTime  int64           `json:"start_time"`
+	FinishTime int64           `json:"finish_time"`
+	Progress   string          `json:"progress"`
+	Properties any             `json:"properties"`
+	Username   string          `json:"username,omitempty"`
+	Data       json.RawMessage `json:"data"`
+}
+
+type FetchReq struct {
+	IDs []string `json:"ids"`
+}

+ 1 - 2
logger/logger.go

@@ -2,7 +2,6 @@ package logger
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"io"
 	"log"
@@ -151,7 +150,7 @@ func FormatQuota(quota int) string {
 
 // LogJson 仅供测试使用 only for test
 func LogJson(ctx context.Context, msg string, obj any) {
-	jsonStr, err := json.Marshal(obj)
+	jsonStr, err := common.Marshal(obj)
 	if err != nil {
 		LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
 		return

+ 10 - 0
main.go

@@ -19,6 +19,7 @@ import (
 	"github.com/QuantumNous/new-api/middleware"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/oauth"
+	"github.com/QuantumNous/new-api/relay"
 	"github.com/QuantumNous/new-api/router"
 	"github.com/QuantumNous/new-api/service"
 	_ "github.com/QuantumNous/new-api/setting/performance_setting"
@@ -111,6 +112,15 @@ func main() {
 	// Subscription quota reset task (daily/weekly/monthly/custom)
 	service.StartSubscriptionQuotaResetTask()
 
+	// Wire task polling adaptor factory (breaks service -> relay import cycle)
+	service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor {
+		a := relay.GetTaskAdaptor(platform)
+		if a == nil {
+			return nil
+		}
+		return a
+	}
+
 	if common.IsMasterNode && constant.UpdateTask {
 		gopool.Go(func() {
 			controller.UpdateMidjourneyTaskBulk()

+ 18 - 0
middleware/auth.go

@@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) {
 
 }
 
+// TokenOrUserAuth allows either session-based user auth or API token auth.
+// Used for endpoints that need to be accessible from both the dashboard and API clients.
+func TokenOrUserAuth() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		// Try session auth first (dashboard users)
+		session := sessions.Default(c)
+		if id := session.Get("id"); id != nil {
+			if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled {
+				c.Set("id", id)
+				c.Next()
+				return
+			}
+		}
+		// Fall back to token auth (API clients)
+		TokenAuth()(c)
+	}
+}
+
 // TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
 // 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
 // 即使令牌已过期、已耗尽或已禁用,也允许访问。

+ 43 - 0
model/log.go

@@ -199,6 +199,49 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
 	}
 }
 
+type RecordTaskBillingLogParams struct {
+	UserId    int
+	LogType   int
+	Content   string
+	ChannelId int
+	ModelName string
+	Quota     int
+	TokenId   int
+	Group     string
+	Other     map[string]interface{}
+}
+
+func RecordTaskBillingLog(params RecordTaskBillingLogParams) {
+	if params.LogType == LogTypeConsume && !common.LogConsumeEnabled {
+		return
+	}
+	username, _ := GetUsernameById(params.UserId, false)
+	tokenName := ""
+	if params.TokenId > 0 {
+		if token, err := GetTokenById(params.TokenId); err == nil {
+			tokenName = token.Name
+		}
+	}
+	log := &Log{
+		UserId:    params.UserId,
+		Username:  username,
+		CreatedAt: common.GetTimestamp(),
+		Type:      params.LogType,
+		Content:   params.Content,
+		TokenName: tokenName,
+		ModelName: params.ModelName,
+		Quota:     params.Quota,
+		ChannelId: params.ChannelId,
+		TokenId:   params.TokenId,
+		Group:     params.Group,
+		Other:     common.MapToJsonStr(params.Other),
+	}
+	err := LOG_DB.Create(log).Error
+	if err != nil {
+		common.SysLog("failed to record task billing log: " + err.Error())
+	}
+}
+
 func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) {
 	var tx *gorm.DB
 	if logType == LogTypeUnknown {

+ 13 - 0
model/midjourney.go

@@ -157,6 +157,19 @@ func (midjourney *Midjourney) Update() error {
 	return err
 }
 
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Returns (true, nil) if this caller won the update, (false, nil) if
+// another process already moved the task out of fromStatus.
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback.
+func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) {
+	result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney)
+	if result.Error != nil {
+		return false, result.Error
+	}
+	return result.RowsAffected > 0, nil
+}
+
 func MjBulkUpdate(mjIds []string, params map[string]any) error {
 	return DB.Model(&Midjourney{}).
 		Where("mj_id in (?)", mjIds).

+ 105 - 60
model/task.go

@@ -1,10 +1,12 @@
 package model
 
 import (
+	"bytes"
 	"database/sql/driver"
 	"encoding/json"
 	"time"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	commonRelay "github.com/QuantumNous/new-api/relay/common"
@@ -64,13 +66,12 @@ type Task struct {
 }
 
 func (t *Task) SetData(data any) {
-	b, _ := json.Marshal(data)
+	b, _ := common.Marshal(data)
 	t.Data = json.RawMessage(b)
 }
 
 func (t *Task) GetData(v any) error {
-	err := json.Unmarshal(t.Data, &v)
-	return err
+	return common.Unmarshal(t.Data, &v)
 }
 
 type Properties struct {
@@ -85,18 +86,59 @@ func (m *Properties) Scan(val interface{}) error {
 		*m = Properties{}
 		return nil
 	}
-	return json.Unmarshal(bytesValue, m)
+	return common.Unmarshal(bytesValue, m)
 }
 
 func (m Properties) Value() (driver.Value, error) {
 	if m == (Properties{}) {
 		return nil, nil
 	}
-	return json.Marshal(m)
+	return common.Marshal(m)
 }
 
 type TaskPrivateData struct {
-	Key string `json:"key,omitempty"`
+	Key            string `json:"key,omitempty"`
+	UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID
+	ResultURL      string `json:"result_url,omitempty"`       // 任务成功后的结果 URL(视频地址等)
+	// 计费上下文:用于异步退款/差额结算(轮询阶段读取)
+	BillingSource  string              `json:"billing_source,omitempty"`  // "wallet" 或 "subscription"
+	SubscriptionId int                 `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
+	TokenId        int                 `json:"token_id,omitempty"`        // 令牌 ID,用于令牌额度退款
+	BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算)
+}
+
+// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
+type TaskBillingContext struct {
+	ModelPrice      float64            `json:"model_price,omitempty"`       // 模型单价
+	GroupRatio      float64            `json:"group_ratio,omitempty"`       // 分组倍率
+	ModelRatio      float64            `json:"model_ratio,omitempty"`       // 模型倍率
+	OtherRatios     map[string]float64 `json:"other_ratios,omitempty"`      // 附加倍率(时长、分辨率等)
+	OriginModelName string             `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName
+	PerCallBilling  bool               `json:"per_call_billing,omitempty"`  // 按次计费:跳过轮询阶段的差额结算
+}
+
+// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
+// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID
+func (t *Task) GetUpstreamTaskID() string {
+	if t.PrivateData.UpstreamTaskID != "" {
+		return t.PrivateData.UpstreamTaskID
+	}
+	return t.TaskID
+}
+
+// GetResultURL 获取任务结果 URL(视频地址等)
+// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容)
+func (t *Task) GetResultURL() string {
+	if t.PrivateData.ResultURL != "" {
+		return t.PrivateData.ResultURL
+	}
+	return t.FailReason
+}
+
+// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID
+func GenerateTaskID() string {
+	key, _ := common.GenerateRandomCharsKey(32)
+	return "task_" + key
 }
 
 func (p *TaskPrivateData) Scan(val interface{}) error {
@@ -104,14 +146,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error {
 	if len(bytesValue) == 0 {
 		return nil
 	}
-	return json.Unmarshal(bytesValue, p)
+	return common.Unmarshal(bytesValue, p)
 }
 
 func (p TaskPrivateData) Value() (driver.Value, error) {
 	if (p == TaskPrivateData{}) {
 		return nil, nil
 	}
-	return json.Marshal(p)
+	return common.Marshal(p)
 }
 
 // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
@@ -142,7 +184,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo)
 		}
 	}
 
+	// 使用预生成的公开 ID(如果有),否则新生成
+	taskID := ""
+	if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" {
+		taskID = relayInfo.TaskRelayInfo.PublicTaskID
+	} else {
+		taskID = GenerateTaskID()
+	}
+
 	t := &Task{
+		TaskID:      taskID,
 		UserId:      relayInfo.UserId,
 		Group:       relayInfo.UsingGroup,
 		SubmitTime:  time.Now().Unix(),
@@ -291,38 +342,63 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
 	return task, nil
 }
 
-func TaskUpdateProgress(id int64, progress string) error {
-	return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
-}
-
 func (Task *Task) Insert() error {
 	var err error
 	err = DB.Create(Task).Error
 	return err
 }
 
+type taskSnapshot struct {
+	Status     TaskStatus
+	Progress   string
+	StartTime  int64
+	FinishTime int64
+	FailReason string
+	ResultURL  string
+	Data       json.RawMessage
+}
+
+func (s taskSnapshot) Equal(other taskSnapshot) bool {
+	return s.Status == other.Status &&
+		s.Progress == other.Progress &&
+		s.StartTime == other.StartTime &&
+		s.FinishTime == other.FinishTime &&
+		s.FailReason == other.FailReason &&
+		s.ResultURL == other.ResultURL &&
+		bytes.Equal(s.Data, other.Data)
+}
+
+func (t *Task) Snapshot() taskSnapshot {
+	return taskSnapshot{
+		Status:     t.Status,
+		Progress:   t.Progress,
+		StartTime:  t.StartTime,
+		FinishTime: t.FinishTime,
+		FailReason: t.FailReason,
+		ResultURL:  t.PrivateData.ResultURL,
+		Data:       t.Data,
+	}
+}
+
 func (Task *Task) Update() error {
 	var err error
 	err = DB.Save(Task).Error
 	return err
 }
 
-func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
-	if len(TaskIds) == 0 {
-		return nil
-	}
-	return DB.Model(&Task{}).
-		Where("task_id in (?)", TaskIds).
-		Updates(params).Error
-}
-
-func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
-	if len(taskIDs) == 0 {
-		return nil
-	}
-	return DB.Model(&Task{}).
-		Where("id in (?)", taskIDs).
-		Updates(params).Error
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Returns (true, nil) if this caller won the update, (false, nil) if
+// another process already moved the task out of fromStatus.
+//
+// Uses Model().Select("*").Updates() instead of Save() because GORM's Save
+// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches
+// zero rows, which silently bypasses the CAS guard.
+func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
+	result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t)
+	if result.Error != nil {
+		return false, result.Error
+	}
+	return result.RowsAffected > 0, nil
 }
 
 func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
@@ -339,37 +415,6 @@ type TaskQuotaUsage struct {
 	Count float64 `json:"count"`
 }
 
-func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
-	query := DB.Model(Task{})
-	// 添加过滤条件
-	if queryParams.ChannelID != "" {
-		query = query.Where("channel_id = ?", queryParams.ChannelID)
-	}
-	if queryParams.UserID != "" {
-		query = query.Where("user_id = ?", queryParams.UserID)
-	}
-	if len(queryParams.UserIDs) != 0 {
-		query = query.Where("user_id in (?)", queryParams.UserIDs)
-	}
-	if queryParams.TaskID != "" {
-		query = query.Where("task_id = ?", queryParams.TaskID)
-	}
-	if queryParams.Action != "" {
-		query = query.Where("action = ?", queryParams.Action)
-	}
-	if queryParams.Status != "" {
-		query = query.Where("status = ?", queryParams.Status)
-	}
-	if queryParams.StartTimestamp != 0 {
-		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
-	}
-	if queryParams.EndTimestamp != 0 {
-		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
-	}
-	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
-	return stat, err
-}
-
 // TaskCountAllTasks returns total tasks that match the given query params (admin usage)
 func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
 	var total int64
@@ -438,6 +483,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo {
 	openAIVideo.SetProgressStr(t.Progress)
 	openAIVideo.CreatedAt = t.CreatedAt
 	openAIVideo.CompletedAt = t.UpdatedAt
-	openAIVideo.SetMetadata("url", t.FailReason)
+	openAIVideo.SetMetadata("url", t.GetResultURL())
 	return openAIVideo
 }

+ 217 - 0
model/task_cas_test.go

@@ -0,0 +1,217 @@
+package model
+
+import (
+	"encoding/json"
+	"os"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/glebarez/sqlite"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
+)
+
+func TestMain(m *testing.M) {
+	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+	if err != nil {
+		panic("failed to open test db: " + err.Error())
+	}
+	DB = db
+	LOG_DB = db
+
+	common.UsingSQLite = true
+	common.RedisEnabled = false
+	common.BatchUpdateEnabled = false
+	common.LogConsumeEnabled = true
+
+	sqlDB, err := db.DB()
+	if err != nil {
+		panic("failed to get sql.DB: " + err.Error())
+	}
+	sqlDB.SetMaxOpenConns(1)
+
+	if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
+		panic("failed to migrate: " + err.Error())
+	}
+
+	os.Exit(m.Run())
+}
+
+func truncateTables(t *testing.T) {
+	t.Helper()
+	t.Cleanup(func() {
+		DB.Exec("DELETE FROM tasks")
+		DB.Exec("DELETE FROM users")
+		DB.Exec("DELETE FROM tokens")
+		DB.Exec("DELETE FROM logs")
+		DB.Exec("DELETE FROM channels")
+	})
+}
+
+func insertTask(t *testing.T, task *Task) {
+	t.Helper()
+	task.CreatedAt = time.Now().Unix()
+	task.UpdatedAt = time.Now().Unix()
+	require.NoError(t, DB.Create(task).Error)
+}
+
+// ---------------------------------------------------------------------------
+// Snapshot / Equal — pure logic tests (no DB)
+// ---------------------------------------------------------------------------
+
+func TestSnapshotEqual_Same(t *testing.T) {
+	s := taskSnapshot{
+		Status:     TaskStatusInProgress,
+		Progress:   "50%",
+		StartTime:  1000,
+		FinishTime: 0,
+		FailReason: "",
+		ResultURL:  "",
+		Data:       json.RawMessage(`{"key":"value"}`),
+	}
+	assert.True(t, s.Equal(s))
+}
+
+func TestSnapshotEqual_DifferentStatus(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
+	b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_DifferentProgress(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
+	b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_DifferentData(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
+	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
+	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
+	// bytes.Equal(nil, []byte{}) == true
+	assert.True(t, a.Equal(b))
+}
+
+func TestSnapshot_Roundtrip(t *testing.T) {
+	task := &Task{
+		Status:     TaskStatusInProgress,
+		Progress:   "42%",
+		StartTime:  1234,
+		FinishTime: 5678,
+		FailReason: "timeout",
+		PrivateData: TaskPrivateData{
+			ResultURL: "https://example.com/result.mp4",
+		},
+		Data: json.RawMessage(`{"model":"test-model"}`),
+	}
+	snap := task.Snapshot()
+	assert.Equal(t, task.Status, snap.Status)
+	assert.Equal(t, task.Progress, snap.Progress)
+	assert.Equal(t, task.StartTime, snap.StartTime)
+	assert.Equal(t, task.FinishTime, snap.FinishTime)
+	assert.Equal(t, task.FailReason, snap.FailReason)
+	assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
+	assert.JSONEq(t, string(task.Data), string(snap.Data))
+}
+
+// ---------------------------------------------------------------------------
+// UpdateWithStatus CAS — DB integration tests
+// ---------------------------------------------------------------------------
+
+func TestUpdateWithStatus_Win(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID:   "task_cas_win",
+		Status:   TaskStatusInProgress,
+		Progress: "50%",
+		Data:     json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	task.Status = TaskStatusSuccess
+	task.Progress = "100%"
+	won, err := task.UpdateWithStatus(TaskStatusInProgress)
+	require.NoError(t, err)
+	assert.True(t, won)
+
+	var reloaded Task
+	require.NoError(t, DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
+	assert.Equal(t, "100%", reloaded.Progress)
+}
+
+func TestUpdateWithStatus_Lose(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID: "task_cas_lose",
+		Status: TaskStatusFailure,
+		Data:   json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	task.Status = TaskStatusSuccess
+	won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
+	require.NoError(t, err)
+	assert.False(t, won)
+
+	var reloaded Task
+	require.NoError(t, DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
+}
+
+func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID: "task_cas_race",
+		Status: TaskStatusInProgress,
+		Quota:  1000,
+		Data:   json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	const goroutines = 5
+	wins := make([]bool, goroutines)
+	var wg sync.WaitGroup
+	wg.Add(goroutines)
+
+	for i := 0; i < goroutines; i++ {
+		go func(idx int) {
+			defer wg.Done()
+			t := &Task{}
+			*t = Task{
+				ID:       task.ID,
+				TaskID:   task.TaskID,
+				Status:   TaskStatusSuccess,
+				Progress: "100%",
+				Quota:    task.Quota,
+				Data:     json.RawMessage(`{}`),
+			}
+			t.CreatedAt = task.CreatedAt
+			t.UpdatedAt = time.Now().Unix()
+			won, err := t.UpdateWithStatus(TaskStatusInProgress)
+			if err == nil {
+				wins[idx] = won
+			}
+		}(i)
+	}
+	wg.Wait()
+
+	winCount := 0
+	for _, w := range wins {
+		if w {
+			winCount++
+		}
+	}
+	assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
+}

+ 3 - 3
model/token.go

@@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) {
 	return token.Delete()
 }
 
-func IncreaseTokenQuota(id int, key string, quota int) (err error) {
+func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
@@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
 		})
 	}
 	if common.BatchUpdateEnabled {
-		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
+		addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota)
 		return nil
 	}
-	return increaseTokenQuota(id, quota)
+	return increaseTokenQuota(tokenId, quota)
 }
 
 func increaseTokenQuota(id int, quota int) (err error) {

+ 28 - 2
relay/channel/adapter.go

@@ -36,6 +36,32 @@ type TaskAdaptor interface {
 
 	ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
 
+	// ── Billing ──────────────────────────────────────────────────────
+
+	// EstimateBilling returns OtherRatios for pre-charge based on user request.
+	// Called after ValidateRequestAndSetAction, before price calculation.
+	// Adaptors should extract duration, resolution, etc. from the parsed request
+	// and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}).
+	// Return nil to use the base model price without extra ratios.
+	EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64
+
+	// AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream
+	// submit response. Called after a successful DoResponse.
+	// If the upstream returned actual parameters that differ from the estimate
+	// (e.g. actual seconds), return updated ratios so the caller can recalculate
+	// the quota and settle the delta with the pre-charge.
+	// Return nil if no adjustment is needed.
+	AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64
+
+	// AdjustBillingOnComplete returns the actual quota when a task reaches a
+	// terminal state (success/failure) during polling.
+	// Called by the polling loop after ParseTaskResult.
+	// Return a positive value to trigger delta settlement (supplement / refund).
+	// Return 0 to keep the pre-charged amount unchanged.
+	AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
+
+	// ── Request / Response ───────────────────────────────────────────
+
 	BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
 	BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 	BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
@@ -46,9 +72,9 @@ type TaskAdaptor interface {
 	GetModelList() []string
 	GetChannelName() string
 
-	// FetchTask
-	FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
+	// ── Polling ──────────────────────────────────────────────────────
 
+	FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
 	ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
 }
 

+ 44 - 24
relay/channel/task/ali/adaptor.go

@@ -13,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/samber/lo"
@@ -108,10 +109,10 @@ type AliMetadata struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
-	aliReq      *AliVideoRequest
 }
 
 func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	// 阿里通义万相支持 JSON 格式,不使用 multipart
-	var taskReq relaycommon.TaskSubmitReq
-	if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
-		return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
-	}
-	aliReq, err := a.convertToAliRequest(info, taskReq)
-	if err != nil {
-		return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
-	}
-	a.aliReq = aliReq
-	logger.LogJson(c, "ali video request body", aliReq)
+	// ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
 	return relaycommon.ValidateMultipartDirect(c, info)
 }
 
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 }
 
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
-	bodyBytes, err := common.Marshal(a.aliReq)
+	taskReq, err := relaycommon.GetTaskRequest(c)
 	if err != nil {
-		return nil, errors.Wrap(err, "marshal_ali_request_failed")
+		return nil, errors.Wrap(err, "get_task_request_failed")
 	}
 
+	aliReq, err := a.convertToAliRequest(info, taskReq)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert_to_ali_request_failed")
+	}
+	logger.LogJson(c, "ali video request body", aliReq)
+
+	bodyBytes, err := common.Marshal(aliReq)
+	if err != nil {
+		return nil, errors.Wrap(err, "marshal_ali_request_failed")
+	}
 	return bytes.NewReader(bodyBytes), nil
 }
 
@@ -252,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
 }
 
 func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
+	upstreamModel := req.Model
+	if info.IsModelMapped {
+		upstreamModel = info.UpstreamModelName
+	}
 	aliReq := &AliVideoRequest{
-		Model: req.Model,
+		Model: upstreamModel,
 		Input: AliVideoInput{
 			Prompt: req.Prompt,
 			ImgURL: req.InputReference,
@@ -331,23 +336,37 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
 		}
 	}
 
-	if aliReq.Model != req.Model {
+	if aliReq.Model != upstreamModel {
 		return nil, errors.New("can't change model with metadata")
 	}
 
-	info.PriceData.OtherRatios = map[string]float64{
-		"seconds": float64(aliReq.Parameters.Duration),
+	return aliReq, nil
+}
+
+// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
+// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
+	taskReq, err := relaycommon.GetTaskRequest(c)
+	if err != nil {
+		return nil
+	}
+
+	aliReq, err := a.convertToAliRequest(info, taskReq)
+	if err != nil {
+		return nil
 	}
 
+	otherRatios := map[string]float64{
+		"seconds": float64(aliReq.Parameters.Duration),
+	}
 	ratios, err := ProcessAliOtherRatios(aliReq)
 	if err != nil {
-		return nil, err
+		return otherRatios
 	}
-	for s, f := range ratios {
-		info.PriceData.OtherRatios[s] = f
+	for k, v := range ratios {
+		otherRatios[k] = v
 	}
-
-	return aliReq, nil
+	return otherRatios
 }
 
 // DoRequest delegates to common helper
@@ -384,7 +403,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// 转换为 OpenAI 格式响应
 	openAIResp := dto.NewOpenAIVideo()
-	openAIResp.ID = aliResp.Output.TaskID
+	openAIResp.ID = info.PublicTaskID
+	openAIResp.TaskID = info.PublicTaskID
 	openAIResp.Model = c.GetString("model")
 	if openAIResp.Model == "" && info != nil {
 		openAIResp.Model = info.OriginModelName

+ 15 - 16
relay/channel/task/doubao/adaptor.go

@@ -2,7 +2,6 @@ package doubao
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -14,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -89,6 +89,7 @@ type responseTask struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -130,8 +131,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
-	info.UpstreamModelName = body.Model
-	data, err := json.Marshal(body)
+	if info.IsModelMapped {
+		body.Model = info.UpstreamModelName
+	} else {
+		info.UpstreamModelName = body.Model
+	}
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -154,7 +159,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// Parse Doubao response
 	var dResp responsePayload
-	if err := json.Unmarshal(responseBody, &dResp); err != nil {
+	if err := common.Unmarshal(responseBody, &dResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -165,8 +170,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = dResp.ID
-	ov.TaskID = dResp.ID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 
@@ -234,12 +239,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 
 	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
@@ -248,7 +248,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := responseTask{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 
@@ -286,7 +286,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var dResp responseTask
-	if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal doubao task data failed")
 	}
 
@@ -307,6 +307,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 17 - 33
relay/channel/task/gemini/adaptor.go

@@ -2,8 +2,6 @@ package gemini
 
 import (
 	"bytes"
-	"encoding/base64"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -16,10 +14,10 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/QuantumNous/new-api/setting/model_setting"
-	"github.com/QuantumNous/new-api/setting/system_setting"
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
 )
@@ -87,6 +85,7 @@ type operationResponse struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -106,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 
 // BuildRequestURL constructs the upstream URL.
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	modelName := info.OriginModelName
+	modelName := info.UpstreamModelName
 	version := model_setting.GetGeminiVersionSetting(modelName)
 
 	return fmt.Sprintf(
@@ -145,16 +144,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 
 	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &body.Parameters)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -175,16 +169,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var s submitResponse
-	if err := json.Unmarshal(responseBody, &s); err != nil {
+	if err := common.Unmarshal(responseBody, &s); err != nil {
 		return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 	}
 	if strings.TrimSpace(s.Name) == "" {
 		return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
 	}
-	taskID = encodeLocalTaskID(s.Name)
+	taskID = taskcommon.EncodeLocalTaskID(s.Name)
 	ov := dto.NewOpenAIVideo()
-	ov.ID = taskID
-	ov.TaskID = taskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -206,7 +200,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		return nil, fmt.Errorf("invalid task_id")
 	}
 
-	upstreamName, err := decodeLocalTaskID(taskID)
+	upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
 	if err != nil {
 		return nil, fmt.Errorf("decode task_id failed: %w", err)
 	}
@@ -232,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	var op operationResponse
-	if err := json.Unmarshal(respBody, &op); err != nil {
+	if err := common.Unmarshal(respBody, &op); err != nil {
 		return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
 	}
 
@@ -254,9 +248,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 	ti.Status = model.TaskStatusSuccess
 	ti.Progress = "100%"
 
-	taskID := encodeLocalTaskID(op.Name)
-	ti.TaskID = taskID
-	ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
+	ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name)
+	// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
 
 	// Extract URL from generateVideoResponse if available
 	if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
@@ -269,7 +262,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 }
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
-	upstreamName, err := decodeLocalTaskID(task.TaskID)
+	// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
+	// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
+	upstreamTaskID := task.GetUpstreamTaskID()
+	upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
 	if err != nil {
 		upstreamName = ""
 	}
@@ -297,18 +293,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 // helpers
 // ============================
 
-func encodeLocalTaskID(name string) string {
-	return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
-	b, err := base64.RawURLEncoding.DecodeString(local)
-	if err != nil {
-		return "", err
-	}
-	return string(b), nil
-}
-
 var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
 
 func extractModelFromOperationName(name string) string {

+ 13 - 12
relay/channel/task/hailuo/adaptor.go

@@ -2,7 +2,6 @@ package hailuo
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -18,12 +17,14 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
 
 // https://platform.minimaxi.com/docs/api-reference/video-generation-intro
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -60,12 +61,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		return nil, fmt.Errorf("invalid request type in context")
 	}
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -86,7 +87,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var hResp VideoResponse
-	if err := json.Unmarshal(responseBody, &hResp); err != nil {
+	if err := common.Unmarshal(responseBody, &hResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -101,8 +102,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = hResp.TaskID
-	ov.TaskID = hResp.TaskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 
@@ -141,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
 	return ChannelName
 }
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
-	modelConfig := GetModelConfig(req.Model)
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
+	modelConfig := GetModelConfig(info.UpstreamModelName)
 	duration := DefaultDuration
 	if req.Duration > 0 {
 		duration = req.Duration
@@ -153,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 
 	videoRequest := &VideoRequest{
-		Model:      req.Model,
+		Model:      info.UpstreamModelName,
 		Prompt:     req.Prompt,
 		Duration:   &duration,
 		Resolution: resolution,
@@ -182,7 +183,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := QueryTaskResponse{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 
@@ -224,7 +225,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var hailuoResp QueryTaskResponse
-	if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
 	}
 
@@ -271,7 +272,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
 	}
 
 	var retrieveResp RetrieveFileResponse
-	if err := json.Unmarshal(responseBody, &retrieveResp); err != nil {
+	if err := common.Unmarshal(responseBody, &retrieveResp); err != nil {
 		return ""
 	}
 

+ 14 - 20
relay/channel/task/jimeng/adaptor.go

@@ -6,7 +6,6 @@ import (
 	"crypto/sha256"
 	"encoding/base64"
 	"encoding/hex"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -25,6 +24,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
@@ -77,6 +77,7 @@ const (
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	accessKey   string
 	secretKey   string
@@ -164,11 +165,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		}
 	}
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -191,7 +192,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// Parse Jimeng response
 	var jResp responsePayload
-	if err := json.Unmarshal(responseBody, &jResp); err != nil {
+	if err := common.Unmarshal(responseBody, &jResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -202,8 +203,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = jResp.Data.TaskID
-	ov.TaskID = jResp.Data.TaskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -225,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
 		"task_id": taskID,
 	}
-	payloadBytes, err := json.Marshal(payload)
+	payloadBytes, err := common.Marshal(payload)
 	if err != nil {
 		return nil, errors.Wrap(err, "marshal fetch task payload failed")
 	}
@@ -377,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
 	return h.Sum(nil)
 }
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
-		ReqKey: req.Model,
+		ReqKey: info.UpstreamModelName,
 		Prompt: req.Prompt,
 	}
 
@@ -398,13 +399,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 			r.BinaryDataBase64 = req.Images
 		}
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
@@ -432,7 +427,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := responseTask{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 	taskResult := relaycommon.TaskInfo{}
@@ -458,7 +453,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var jimengResp responseTask
-	if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
 	}
 
@@ -477,8 +472,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }
 
 func isNewAPIRelay(apiKey string) bool {

+ 17 - 36
relay/channel/task/kling/adaptor.go

@@ -2,7 +2,6 @@ package kling
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -21,6 +20,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
@@ -97,6 +97,7 @@ type responsePayload struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -149,14 +150,14 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 	req := v.(relaycommon.TaskSubmitReq)
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, err
 	}
 	if body.Image == "" && body.ImageTail == "" {
 		c.Set("action", constant.TaskActionTextGenerate)
 	}
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -180,7 +181,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	var kResp responsePayload
-	err = json.Unmarshal(responseBody, &kResp)
+	err = common.Unmarshal(responseBody, &kResp)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 		return
@@ -190,8 +191,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 	ov := dto.NewOpenAIVideo()
-	ov.ID = kResp.Data.TaskId
-	ov.TaskID = kResp.Data.TaskId
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -247,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
 		Prompt:         req.Prompt,
 		Image:          req.Image,
-		Mode:           defaultString(req.Mode, "std"),
-		Duration:       fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+		Mode:           taskcommon.DefaultString(req.Mode, "std"),
+		Duration:       fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
 		AspectRatio:    a.getAspectRatio(req.Size),
-		ModelName:      req.Model,
-		Model:          req.Model, // Keep consistent with model_name, double writing improves compatibility
+		ModelName:      info.UpstreamModelName,
+		Model:          info.UpstreamModelName,
 		CfgScale:       0.5,
 		StaticMask:     "",
 		DynamicMasks:   []DynamicMask{},
@@ -265,14 +266,9 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 	if r.ModelName == "" {
 		r.ModelName = "kling-v1"
+		r.Model = "kling-v1"
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 	return &r, nil
@@ -291,20 +287,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string {
 	}
 }
 
-func defaultString(s, def string) string {
-	if strings.TrimSpace(s) == "" {
-		return def
-	}
-	return s
-}
-
-func defaultInt(v int, def int) int {
-	if v == 0 {
-		return def
-	}
-	return v
-}
-
 // ============================
 // JWT helpers
 // ============================
@@ -340,7 +322,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	taskInfo := &relaycommon.TaskInfo{}
 	resPayload := responsePayload{}
-	err := json.Unmarshal(respBody, &resPayload)
+	err := common.Unmarshal(respBody, &resPayload)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to unmarshal response body")
 	}
@@ -374,7 +356,7 @@ func isNewAPIRelay(apiKey string) bool {
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var klingResp responsePayload
-	if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &klingResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal kling task data failed")
 	}
 
@@ -401,6 +383,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 			Code:    fmt.Sprintf("%d", klingResp.Code),
 		}
 	}
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 116 - 15
relay/channel/task/sora/adaptor.go

@@ -1,9 +1,12 @@
 package sora
 
 import (
+	"bytes"
 	"fmt"
 	"io"
+	"mime/multipart"
 	"net/http"
+	"strconv"
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
@@ -11,12 +14,13 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
-	"github.com/QuantumNous/new-api/setting/system_setting"
 
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
+	"github.com/tidwall/sjson"
 )
 
 // ============================
@@ -57,6 +61,7 @@ type responseTask struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -69,15 +74,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func validateRemixRequest(c *gin.Context) *dto.TaskError {
-	var req struct {
-		Prompt string `json:"prompt"`
-	}
+	var req relaycommon.TaskSubmitReq
 	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
 		return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
 	}
 	if strings.TrimSpace(req.Prompt) == "" {
 		return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
 	}
+	// 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
+	c.Set("task_request", req)
 	return nil
 }
 
@@ -88,6 +93,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 	return relaycommon.ValidateMultipartDirect(c, info)
 }
 
+// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
+	// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
+	if info.Action == constant.TaskActionRemix {
+		return nil
+	}
+
+	req, err := relaycommon.GetTaskRequest(c)
+	if err != nil {
+		return nil
+	}
+
+	seconds, _ := strconv.Atoi(req.Seconds)
+	if seconds == 0 {
+		seconds = req.Duration
+	}
+	if seconds <= 0 {
+		seconds = 4
+	}
+
+	size := req.Size
+	if size == "" {
+		size = "720x1280"
+	}
+
+	ratios := map[string]float64{
+		"seconds": float64(seconds),
+		"size":    1,
+	}
+	if size == "1792x1024" || size == "1024x1792" {
+		ratios["size"] = 1.666667
+	}
+	return ratios
+}
+
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if info.Action == constant.TaskActionRemix {
 		return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
@@ -107,6 +147,59 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, errors.Wrap(err, "get_request_body_failed")
 	}
+	cachedBody, err := storage.Bytes()
+	if err != nil {
+		return nil, errors.Wrap(err, "read_body_bytes_failed")
+	}
+	contentType := c.GetHeader("Content-Type")
+
+	if strings.HasPrefix(contentType, "application/json") {
+		var bodyMap map[string]interface{}
+		if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
+			bodyMap["model"] = info.UpstreamModelName
+			if newBody, err := common.Marshal(bodyMap); err == nil {
+				return bytes.NewReader(newBody), nil
+			}
+		}
+		return bytes.NewReader(cachedBody), nil
+	}
+
+	if strings.Contains(contentType, "multipart/form-data") {
+		formData, err := common.ParseMultipartFormReusable(c)
+		if err != nil {
+			return bytes.NewReader(cachedBody), nil
+		}
+		var buf bytes.Buffer
+		writer := multipart.NewWriter(&buf)
+		writer.WriteField("model", info.UpstreamModelName)
+		for key, values := range formData.Value {
+			if key == "model" {
+				continue
+			}
+			for _, v := range values {
+				writer.WriteField(key, v)
+			}
+		}
+		for fieldName, fileHeaders := range formData.File {
+			for _, fh := range fileHeaders {
+				f, err := fh.Open()
+				if err != nil {
+					continue
+				}
+				part, err := writer.CreateFormFile(fieldName, fh.Filename)
+				if err != nil {
+					f.Close()
+					continue
+				}
+				io.Copy(part, f)
+				f.Close()
+			}
+		}
+		writer.Close()
+		c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+		return &buf, nil
+	}
+
 	return common.ReaderOnly(storage), nil
 }
 
@@ -116,7 +209,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req
 }
 
 // DoResponse handles upstream response, returns taskID etc.
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -131,17 +224,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
 		return
 	}
 
-	if dResp.ID == "" {
-		if dResp.TaskID == "" {
-			taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
-			return
-		}
-		dResp.ID = dResp.TaskID
-		dResp.TaskID = ""
+	upstreamID := dResp.ID
+	if upstreamID == "" {
+		upstreamID = dResp.TaskID
+	}
+	if upstreamID == "" {
+		taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
+		return
 	}
 
+	// 使用公开 task_xxxx ID 返回给客户端
+	dResp.ID = info.PublicTaskID
+	dResp.TaskID = info.PublicTaskID
 	c.JSON(http.StatusOK, dResp)
-	return dResp.ID, responseBody, nil
+	return upstreamID, responseBody, nil
 }
 
 // FetchTask fetch task status
@@ -192,7 +288,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 		taskResult.Status = model.TaskStatusInProgress
 	case "completed":
 		taskResult.Status = model.TaskStatusSuccess
-		taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID)
+		// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
 	case "failed", "cancelled":
 		taskResult.Status = model.TaskStatusFailure
 		if resTask.Error != nil {
@@ -210,5 +306,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 }
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
-	return task.Data, nil
+	data := task.Data
+	var err error
+	if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil {
+		return nil, errors.Wrap(err, "set id failed")
+	}
+	return data, nil
 }

+ 17 - 19
relay/channel/task/suno/adaptor.go

@@ -3,7 +3,6 @@ package suno
 import (
 	"bytes"
 	"context"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -14,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -21,11 +21,16 @@ import (
 )
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 }
 
+// ParseTaskResult is not used for Suno tasks.
+// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that
+// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API.
+// This differs from the per-task polling used by video adaptors.
 func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
-	return nil, fmt.Errorf("not implement") // todo implement this method if needed
+	return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable")
 }
 
 func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -76,12 +81,9 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
 	sunoRequest, ok := c.Get("task_request")
 	if !ok {
-		err := common.UnmarshalBodyReusable(c, &sunoRequest)
-		if err != nil {
-			return nil, err
-		}
+		return nil, fmt.Errorf("task_request not found in context")
 	}
-	data, err := json.Marshal(sunoRequest)
+	data, err := common.Marshal(sunoRequest)
 	if err != nil {
 		return nil, err
 	}
@@ -99,7 +101,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 	var sunoResponse dto.TaskResponse[string]
-	err = json.Unmarshal(responseBody, &sunoResponse)
+	err = common.Unmarshal(responseBody, &sunoResponse)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
@@ -109,17 +111,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-
-	_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
-		return
+	// 使用公开 task_xxxx ID 替换上游 ID 返回给客户端
+	publicResponse := dto.TaskResponse[string]{
+		Code:    sunoResponse.Code,
+		Message: sunoResponse.Message,
+		Data:    info.PublicTaskID,
 	}
+	c.JSON(http.StatusOK, publicResponse)
 
 	return sunoResponse.Data, nil, nil
 }
@@ -134,7 +132,7 @@ func (a *TaskAdaptor) GetChannelName() string {
 
 func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
-	byteBody, err := json.Marshal(body)
+	byteBody, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}

+ 95 - 0
relay/channel/task/taskcommon/helpers.go

@@ -0,0 +1,95 @@
+package taskcommon
+
+import (
+	"encoding/base64"
+	"fmt"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/setting/system_setting"
+	"github.com/gin-gonic/gin"
+)
+
+// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
+// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target).
+func UnmarshalMetadata(metadata map[string]any, target any) error {
+	if metadata == nil {
+		return nil
+	}
+	metaBytes, err := common.Marshal(metadata)
+	if err != nil {
+		return fmt.Errorf("marshal metadata failed: %w", err)
+	}
+	if err := common.Unmarshal(metaBytes, target); err != nil {
+		return fmt.Errorf("unmarshal metadata failed: %w", err)
+	}
+	return nil
+}
+
+// DefaultString returns val if non-empty, otherwise fallback.
+func DefaultString(val, fallback string) string {
+	if val == "" {
+		return fallback
+	}
+	return val
+}
+
+// DefaultInt returns val if non-zero, otherwise fallback.
+func DefaultInt(val, fallback int) int {
+	if val == 0 {
+		return fallback
+	}
+	return val
+}
+
+// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string.
+// Used by Gemini/Vertex to store upstream names as task IDs.
+func EncodeLocalTaskID(name string) string {
+	return base64.RawURLEncoding.EncodeToString([]byte(name))
+}
+
+// DecodeLocalTaskID decodes a base64-encoded upstream operation name.
+func DecodeLocalTaskID(id string) (string, error) {
+	b, err := base64.RawURLEncoding.DecodeString(id)
+	if err != nil {
+		return "", err
+	}
+	return string(b), nil
+}
+
+// BuildProxyURL constructs the video proxy URL using the public task ID.
+// e.g., "https://your-server.com/v1/videos/task_xxxx/content"
+func BuildProxyURL(taskID string) string {
+	return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
+}
+
+// Status-to-progress mapping constants for polling updates.
+const (
+	ProgressSubmitted  = "10%"
+	ProgressQueued     = "20%"
+	ProgressInProgress = "30%"
+	ProgressComplete   = "100%"
+)
+
+// ---------------------------------------------------------------------------
+// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
+// Adaptors that do not need custom billing can embed this struct directly.
+// ---------------------------------------------------------------------------
+
+type BaseBilling struct{}
+
+// EstimateBilling returns nil (no extra ratios; use base model price).
+func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
+	return nil
+}
+
+// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
+func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
+	return nil
+}
+
+// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
+func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+	return 0
+}

+ 47 - 46
relay/channel/task/vertex/adaptor.go

@@ -2,13 +2,12 @@ package vertex
 
 import (
 	"bytes"
-	"encoding/base64"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
 	"regexp"
 	"strings"
+	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/model"
@@ -17,6 +16,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
@@ -62,6 +62,7 @@ type operationResponse struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -82,10 +83,10 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 // BuildRequestURL constructs the upstream URL.
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+	if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
 		return "", fmt.Errorf("failed to decode credentials: %w", err)
 	}
-	modelName := info.OriginModelName
+	modelName := info.UpstreamModelName
 	if modelName == "" {
 		modelName = "veo-3.0-generate-001"
 	}
@@ -116,7 +117,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 	req.Header.Set("Accept", "application/json")
 
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+	if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
 		return fmt.Errorf("failed to decode credentials: %w", err)
 	}
 
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 	return nil
 }
 
+// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
+	sampleCount := 1
+	v, ok := c.Get("task_request")
+	if ok {
+		req := v.(relaycommon.TaskSubmitReq)
+		if req.Metadata != nil {
+			if sc, exists := req.Metadata["sampleCount"]; exists {
+				if i, ok := sc.(int); ok && i > 0 {
+					sampleCount = i
+				}
+				if f, ok := sc.(float64); ok && int(f) > 0 {
+					sampleCount = int(f)
+				}
+			}
+		}
+	}
+	return map[string]float64{
+		"sampleCount": float64(sampleCount),
+	}
+}
+
 // BuildRequestBody converts request into Vertex specific format.
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
 	v, ok := c.Get("task_request")
@@ -166,25 +189,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		return nil, fmt.Errorf("sampleCount must be greater than 0")
 	}
 
-	// if req.Duration > 0 {
-	// 	body.Parameters["durationSeconds"] = req.Duration
-	// } else if req.Seconds != "" {
-	// 	seconds, err := strconv.Atoi(req.Seconds)
-	// 	if err != nil {
-	// 		return nil, errors.Wrap(err, "convert seconds to int failed")
-	// 	}
-	// 	body.Parameters["durationSeconds"] = seconds
-	// }
-
-	info.PriceData.OtherRatios = map[string]float64{
-		"sampleCount": float64(body.Parameters["sampleCount"].(int)),
-	}
-
-	// if v, ok := body.Parameters["durationSeconds"]; ok {
-	// 	info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
-	// }
-
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -205,14 +210,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var s submitResponse
-	if err := json.Unmarshal(responseBody, &s); err != nil {
+	if err := common.Unmarshal(responseBody, &s); err != nil {
 		return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 	}
 	if strings.TrimSpace(s.Name) == "" {
 		return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
 	}
-	localID := encodeLocalTaskID(s.Name)
-	c.JSON(http.StatusOK, gin.H{"task_id": localID})
+	localID := taskcommon.EncodeLocalTaskID(s.Name)
+	ov := dto.NewOpenAIVideo()
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
+	ov.CreatedAt = time.Now().Unix()
+	ov.Model = info.OriginModelName
+	c.JSON(http.StatusOK, ov)
 	return localID, responseBody, nil
 }
 
@@ -225,7 +235,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
 	}
-	upstreamName, err := decodeLocalTaskID(taskID)
+	upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
 	if err != nil {
 		return nil, fmt.Errorf("decode task_id failed: %w", err)
 	}
@@ -245,12 +255,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
 	}
 	payload := map[string]string{"operationName": upstreamName}
-	data, err := json.Marshal(payload)
+	data, err := common.Marshal(payload)
 	if err != nil {
 		return nil, err
 	}
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(key), adc); err != nil {
+	if err := common.Unmarshal([]byte(key), adc); err != nil {
 		return nil, fmt.Errorf("failed to decode credentials: %w", err)
 	}
 	token, err := vertexcore.AcquireAccessToken(*adc, proxy)
@@ -274,7 +284,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	var op operationResponse
-	if err := json.Unmarshal(respBody, &op); err != nil {
+	if err := common.Unmarshal(respBody, &op); err != nil {
 		return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
 	}
 	ti := &relaycommon.TaskInfo{}
@@ -338,7 +348,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 }
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
-	upstreamName, err := decodeLocalTaskID(task.TaskID)
+	// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
+	// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
+	upstreamTaskID := task.GetUpstreamTaskID()
+	upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
 	if err != nil {
 		upstreamName = ""
 	}
@@ -353,8 +366,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 	v.SetProgressStr(task.Progress)
 	v.CreatedAt = task.CreatedAt
 	v.CompletedAt = task.UpdatedAt
-	if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
-		v.SetMetadata("url", task.FailReason)
+	if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 {
+		v.SetMetadata("url", resultURL)
 	}
 
 	return common.Marshal(v)
@@ -364,18 +377,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 // helpers
 // ============================
 
-func encodeLocalTaskID(name string) string {
-	return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
-	b, err := base64.RawURLEncoding.DecodeString(local)
-	if err != nil {
-		return "", err
-	}
-	return string(b), nil
-}
-
 var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
 
 func extractRegionFromOperationName(name string) string {

+ 15 - 35
relay/channel/task/vidu/adaptor.go

@@ -2,7 +2,6 @@ package vidu
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -16,6 +15,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -73,6 +73,7 @@ type creation struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	baseURL     string
 }
@@ -115,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 	req := v.(relaycommon.TaskSubmitReq)
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, err
 	}
@@ -127,7 +128,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		}
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -168,7 +169,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	var vResp responsePayload
-	err = json.Unmarshal(responseBody, &vResp)
+	err = common.Unmarshal(responseBody, &vResp)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
 		return
@@ -180,8 +181,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = vResp.TaskId
-	ov.TaskID = vResp.TaskId
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -223,47 +224,27 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
-		Model:             defaultString(req.Model, "viduq1"),
+		Model:             taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
 		Images:            req.Images,
 		Prompt:            req.Prompt,
-		Duration:          defaultInt(req.Duration, 5),
-		Resolution:        defaultString(req.Size, "1080p"),
+		Duration:          taskcommon.DefaultInt(req.Duration, 5),
+		Resolution:        taskcommon.DefaultString(req.Size, "1080p"),
 		MovementAmplitude: "auto",
 		Bgm:               false,
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 	return &r, nil
 }
 
-func defaultString(value, defaultValue string) string {
-	if value == "" {
-		return defaultValue
-	}
-	return value
-}
-
-func defaultInt(value, defaultValue int) int {
-	if value == 0 {
-		return defaultValue
-	}
-	return value
-}
-
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	taskInfo := &relaycommon.TaskInfo{}
 
 	var taskResp taskResultResponse
-	err := json.Unmarshal(respBody, &taskResp)
+	err := common.Unmarshal(respBody, &taskResp)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to unmarshal response body")
 	}
@@ -293,7 +274,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var viduResp taskResultResponse
-	if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &viduResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal vidu task data failed")
 	}
 
@@ -315,6 +296,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 17 - 3
relay/common/relay_info.go

@@ -118,8 +118,12 @@ type RelayInfo struct {
 	SendResponseCount      int
 	ReceivedResponseCount  int
 	FinalPreConsumedQuota  int // 最终预消耗的配额
+	// ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
+	// 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
+	// 必须在提交前锁定全额。
+	ForcePreConsume bool
 	// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
-	// 免费模型和按次计费(MJ/Task)时为 nil。
+	// 免费模型时为 nil。
 	Billing BillingSettler
 	// BillingSource indicates whether this request is billed from wallet quota or subscription.
 	// "" or "wallet" => wallet; "subscription" => subscription
@@ -525,8 +529,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
 		return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
 	case types.RelayFormatTask:
 		info = genBaseRelayInfo(c, nil)
+		info.TaskRelayInfo = &TaskRelayInfo{}
 	case types.RelayFormatMjProxy:
 		info = genBaseRelayInfo(c, nil)
+		info.TaskRelayInfo = &TaskRelayInfo{}
 	default:
 		err = errors.New("invalid relay format")
 	}
@@ -608,8 +614,16 @@ func (info *RelayInfo) HasSendResponse() bool {
 type TaskRelayInfo struct {
 	Action       string
 	OriginTaskID string
+	// PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
+	// 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
+	PublicTaskID string
 
 	ConsumeQuota bool
+
+	// LockedChannel holds the full channel object when the request is bound to
+	// a specific channel (e.g., remix on origin task's channel). Stored as any
+	// to avoid an import cycle with model; callers type-assert to *model.Channel.
+	LockedChannel any
 }
 
 type TaskSubmitReq struct {
@@ -667,11 +681,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
 func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
 	metadata := t.Metadata
 	if metadata != nil {
-		metadataBytes, err := json.Marshal(metadata)
+		metadataBytes, err := common.Marshal(metadata)
 		if err != nil {
 			return fmt.Errorf("marshal metadata failed: %w", err)
 		}
-		err = json.Unmarshal(metadataBytes, v)
+		err = common.Unmarshal(metadataBytes, v)
 		if err != nil {
 			return fmt.Errorf("unmarshal metadata to target failed: %w", err)
 		}

+ 2 - 8
relay/common/relay_utils.go

@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
 		if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
 			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
 		}
-		info.PriceData.OtherRatios = map[string]float64{
-			"seconds": float64(seconds),
-			"size":    1,
-		}
-		if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
-			info.PriceData.OtherRatios["size"] = 1.666667
-		}
+		// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
 	}
 
-	info.Action = action
+	storeTaskRequest(c, info, action, req)
 
 	return nil
 }

+ 13 - 2
relay/helper/price.go

@@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 }
 
 // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
-func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
+func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData {
 	groupRatioInfo := HandleGroupRatio(c, info)
 
 	modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
@@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.
 		}
 	}
 	quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
-	priceData := types.PerCallPriceData{
+
+	// 免费模型检测(与 ModelPriceHelper 对齐)
+	freeModel := false
+	if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
+		if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
+			quota = 0
+			freeModel = true
+		}
+	}
+
+	priceData := types.PriceData{
+		FreeModel:      freeModel,
 		ModelPrice:     modelPrice,
 		Quota:          quota,
 		GroupRatioInfo: groupRatioInfo,

+ 324 - 280
relay/relay_task.go

@@ -2,7 +2,6 @@ package relay
 
 import (
 	"bytes"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -15,29 +14,33 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	relayconstant "github.com/QuantumNous/new-api/relay/constant"
+	"github.com/QuantumNous/new-api/relay/helper"
 	"github.com/QuantumNous/new-api/service"
-	"github.com/QuantumNous/new-api/setting/ratio_setting"
-
 	"github.com/gin-gonic/gin"
 )
 
-/*
-Task 任务通过平台、Action 区分任务
-*/
-func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	info.InitChannelMeta(c)
-	// ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
-	if info.TaskRelayInfo == nil {
-		info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
-	}
+type TaskSubmitResult struct {
+	UpstreamTaskID string
+	TaskData       []byte
+	Platform       constant.TaskPlatform
+	Quota          int
+	//PerCallPrice   types.PriceData
+}
+
+// ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
+// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
+// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key),
+// 以及提取 OtherRatios(时长、分辨率)。
+// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
+func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
+	// 检测 remix action
 	path := c.Request.URL.Path
 	if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
 		info.Action = constant.TaskActionRemix
 	}
-
-	// 提取 remix 任务的 video_id
 	if info.Action == constant.TaskActionRemix {
 		videoID := c.Param("video_id")
 		if strings.TrimSpace(videoID) == "" {
@@ -46,64 +49,71 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 		info.OriginTaskID = videoID
 	}
 
-	platform := constant.TaskPlatform(c.GetString("platform"))
+	if info.OriginTaskID == "" {
+		return nil
+	}
 
-	// 获取原始任务信息
-	if info.OriginTaskID != "" {
-		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
-		if err != nil {
-			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
-			return
-		}
-		if !exist {
-			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
-			return
-		}
-		if info.OriginModelName == "" {
-			if originTask.Properties.OriginModelName != "" {
-				info.OriginModelName = originTask.Properties.OriginModelName
-			} else if originTask.Properties.UpstreamModelName != "" {
-				info.OriginModelName = originTask.Properties.UpstreamModelName
-			} else {
-				var taskData map[string]interface{}
-				_ = json.Unmarshal(originTask.Data, &taskData)
-				if m, ok := taskData["model"].(string); ok && m != "" {
-					info.OriginModelName = m
-					platform = originTask.Platform
-				}
+	// 查找原始任务
+	originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
+	if err != nil {
+		return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+	}
+	if !exist {
+		return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+	}
+
+	// 从原始任务推导模型名称
+	if info.OriginModelName == "" {
+		if originTask.Properties.OriginModelName != "" {
+			info.OriginModelName = originTask.Properties.OriginModelName
+		} else if originTask.Properties.UpstreamModelName != "" {
+			info.OriginModelName = originTask.Properties.UpstreamModelName
+		} else {
+			var taskData map[string]interface{}
+			_ = common.Unmarshal(originTask.Data, &taskData)
+			if m, ok := taskData["model"].(string); ok && m != "" {
+				info.OriginModelName = m
 			}
 		}
-		if originTask.ChannelId != info.ChannelId {
-			channel, err := model.GetChannelById(originTask.ChannelId, true)
-			if err != nil {
-				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
-				return
-			}
-			if channel.Status != common.ChannelStatusEnabled {
-				taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
-				return
-			}
-			key, _, newAPIError := channel.GetNextEnabledKey()
-			if newAPIError != nil {
-				taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
-				return
-			}
-			common.SetContextKey(c, constant.ContextKeyChannelKey, key)
-			common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
-			common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
-			common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
-
-			info.ChannelBaseUrl = channel.GetBaseURL()
-			info.ChannelId = originTask.ChannelId
-			info.ChannelType = channel.Type
-			info.ApiKey = key
-			platform = originTask.Platform
+	}
+
+	// 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key)
+	ch, err := model.GetChannelById(originTask.ChannelId, true)
+	if err != nil {
+		return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+	}
+	if ch.Status != common.ChannelStatusEnabled {
+		return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
+	}
+	info.LockedChannel = ch
+
+	if originTask.ChannelId != info.ChannelId {
+		key, _, newAPIError := ch.GetNextEnabledKey()
+		if newAPIError != nil {
+			return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
 		}
+		common.SetContextKey(c, constant.ContextKeyChannelKey, key)
+		common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
+		common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
+		common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
+
+		info.ChannelBaseUrl = ch.GetBaseURL()
+		info.ChannelId = originTask.ChannelId
+		info.ChannelType = ch.Type
+		info.ApiKey = key
+	}
 
-		// 使用原始任务的参数
-		if info.Action == constant.TaskActionRemix {
+	// 提取 remix 参数(时长、分辨率 → OtherRatios)
+	if info.Action == constant.TaskActionRemix {
+		if originTask.PrivateData.BillingContext != nil {
+			// 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在)
+			for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
+				info.PriceData.AddOtherRatio(s, f)
+			}
+		} else {
+			// 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在)
 			var taskData map[string]interface{}
-			_ = json.Unmarshal(originTask.Data, &taskData)
+			_ = common.Unmarshal(originTask.Data, &taskData)
 			secondsStr, _ := taskData["seconds"].(string)
 			seconds, _ := strconv.Atoi(secondsStr)
 			if seconds <= 0 {
@@ -120,167 +130,146 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 			}
 		}
 	}
+
+	return nil
+}
+
+// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
+// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
+// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
+// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
+// 控制器负责 defer Refund 和成功后 Settle。
+func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
+	info.InitChannelMeta(c)
+
+	// 1. 确定 platform → 创建适配器 → 验证请求
+	platform := constant.TaskPlatform(c.GetString("platform"))
 	if platform == "" {
 		platform = GetTaskPlatform(c)
 	}
-
-	info.InitChannelMeta(c)
 	adaptor := GetTaskAdaptor(platform)
 	if adaptor == nil {
-		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
+		return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
 	}
 	adaptor.Init(info)
-	// get & validate taskRequest 获取并验证文本请求
-	taskErr = adaptor.ValidateRequestAndSetAction(c, info)
-	if taskErr != nil {
-		return
+	if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
+		return nil, taskErr
 	}
 
+	// 2. 确定模型名称
 	modelName := info.OriginModelName
 	if modelName == "" {
 		modelName = service.CoverTaskActionToModelName(platform, info.Action)
 	}
-	modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
-	if !success {
-		defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
-		if !ok {
-			modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
-		} else {
-			modelPrice = defaultPrice
-		}
+
+	// 2.5 应用渠道的模型映射(与同步任务对齐)
+	info.OriginModelName = modelName
+	info.UpstreamModelName = modelName
+	if err := helper.ModelMappedHelper(c, info, nil); err != nil {
+		return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
 	}
 
-	// 处理 auto 分组:从 context 获取实际选中的分组
-	// 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
-	if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
-		if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
-			info.UsingGroup = groupStr
-		}
+	// 3. 预生成公开 task ID(仅首次)
+	if info.PublicTaskID == "" {
+		info.PublicTaskID = model.GenerateTaskID()
 	}
 
-	// 预扣
-	groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
-	var ratio float64
-	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
-	if hasUserGroupRatio {
-		ratio = modelPrice * userGroupRatio
-	} else {
-		ratio = modelPrice * groupRatio
+	// 4. 价格计算:基础模型价格
+	info.OriginModelName = modelName
+	info.PriceData = helper.ModelPriceHelperPerCall(c, info)
+
+	// 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
+	//    必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
+	//    ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
+	if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
+		for k, v := range estimatedRatios {
+			info.PriceData.AddOtherRatio(k, v)
+		}
 	}
-	// FIXME: 临时修补,支持任务仅按次计费
+
+	// 6. 将 OtherRatios 应用到基础额度
 	if !common.StringsContains(constant.TaskPricePatches, modelName) {
-		if len(info.PriceData.OtherRatios) > 0 {
-			for _, ra := range info.PriceData.OtherRatios {
-				if 1.0 != ra {
-					ratio *= ra
-				}
+		for _, ra := range info.PriceData.OtherRatios {
+			if ra != 1.0 {
+				info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
 			}
 		}
 	}
-	println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
-	userQuota, err := model.GetUserQuota(info.UserId, false)
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
-		return
-	}
-	quota := int(ratio * common.QuotaPerUnit)
-	if userQuota-quota < 0 {
-		taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
-		return
+
+	// 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
+	if info.Billing == nil && !info.PriceData.FreeModel {
+		info.ForcePreConsume = true
+		if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
+			return nil, service.TaskErrorFromAPIError(apiErr)
+		}
 	}
 
-	// build body
+	// 8. 构建请求体
 	requestBody, err := adaptor.BuildRequestBody(c, info)
 	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
-		return
+		return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
 	}
-	// do request
+
+	// 9. 发送请求
 	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
-		return
+		return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
-	// handle response
 	if resp != nil && resp.StatusCode != http.StatusOK {
 		responseBody, _ := io.ReadAll(resp.Body)
-		taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
-		return
+		return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
 	}
 
-	defer func() {
-		// release quota
-		if info.ConsumeQuota && taskErr == nil {
-
-			err := service.PostConsumeQuota(info, quota, 0, true)
-			if err != nil {
-				common.SysLog("error consuming token remain quota: " + err.Error())
-			}
-			if quota != 0 {
-				tokenName := c.GetString("token_name")
-				//gRatio := groupRatio
-				//if hasUserGroupRatio {
-				//	gRatio = userGroupRatio
-				//}
-				logContent := fmt.Sprintf("操作 %s", info.Action)
-				// FIXME: 临时修补,支持任务仅按次计费
-				if common.StringsContains(constant.TaskPricePatches, modelName) {
-					logContent = fmt.Sprintf("%s,按次计费", logContent)
-				} else {
-					if len(info.PriceData.OtherRatios) > 0 {
-						var contents []string
-						for key, ra := range info.PriceData.OtherRatios {
-							if 1.0 != ra {
-								contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
-							}
-						}
-						if len(contents) > 0 {
-							logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
-						}
-					}
-				}
-				other := make(map[string]interface{})
-				if c != nil && c.Request != nil && c.Request.URL != nil {
-					other["request_path"] = c.Request.URL.Path
-				}
-				other["model_price"] = modelPrice
-				other["group_ratio"] = groupRatio
-				if hasUserGroupRatio {
-					other["user_group_ratio"] = userGroupRatio
-				}
-				model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
-					ChannelId: info.ChannelId,
-					ModelName: modelName,
-					TokenName: tokenName,
-					Quota:     quota,
-					Content:   logContent,
-					TokenId:   info.TokenId,
-					Group:     info.UsingGroup,
-					Other:     other,
-				})
-				model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
-				model.UpdateChannelUsedQuota(info.ChannelId, quota)
-			}
-		}
-	}()
+	// 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
+	otherRatios := info.PriceData.OtherRatios
+	if otherRatios == nil {
+		otherRatios = map[string]float64{}
+	}
+	ratiosJSON, _ := common.Marshal(otherRatios)
+	c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
 
-	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
+	// 11. 解析响应
+	upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
 	if taskErr != nil {
-		return
+		return nil, taskErr
 	}
-	info.ConsumeQuota = true
-	// insert task
-	task := model.InitTask(platform, info)
-	task.TaskID = taskID
-	task.Quota = quota
-	task.Data = taskData
-	task.Action = info.Action
-	err = task.Insert()
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
-		return
+
+	// 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
+	finalQuota := info.PriceData.Quota
+	if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
+		// 基于调整后的 ratios 重新计算 quota
+		finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
+		info.PriceData.OtherRatios = adjustedRatios
+		info.PriceData.Quota = finalQuota
 	}
-	return nil
+
+	return &TaskSubmitResult{
+		UpstreamTaskID: upstreamTaskID,
+		TaskData:       taskData,
+		Platform:       platform,
+		Quota:          finalQuota,
+	}, nil
+}
+
+// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
+// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
+func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
+	// 从 PriceData 获取不含 OtherRatios 的基础价格
+	baseQuota := info.PriceData.Quota
+	// 先除掉原有的 OtherRatios 恢复基础额度
+	for _, ra := range info.PriceData.OtherRatios {
+		if ra != 1.0 && ra > 0 {
+			baseQuota = int(float64(baseQuota) / ra)
+		}
+	}
+	// 应用新的 ratios
+	result := float64(baseQuota)
+	for _, ra := range ratios {
+		if ra != 1.0 {
+			result *= ra
+		}
+	}
+	return int(result)
 }
 
 var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
@@ -336,7 +325,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta
 	} else {
 		tasks = make([]any, 0)
 	}
-	respBody, err = json.Marshal(dto.TaskResponse[[]any]{
+	respBody, err = common.Marshal(dto.TaskResponse[[]any]{
 		Code: "success",
 		Data: tasks,
 	})
@@ -357,7 +346,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
 		return
 	}
 
-	respBody, err = json.Marshal(dto.TaskResponse[any]{
+	respBody, err = common.Marshal(dto.TaskResponse[any]{
 		Code: "success",
 		Data: TaskModel2Dto(originTask),
 	})
@@ -381,97 +370,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 		return
 	}
 
-	func() {
-		channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
-		if err2 != nil {
-			return
-		}
-		if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
-			return
-		}
-		baseURL := constant.ChannelBaseURLs[channelModel.Type]
-		if channelModel.GetBaseURL() != "" {
-			baseURL = channelModel.GetBaseURL()
-		}
-		proxy := channelModel.GetSetting().Proxy
-		adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
-		if adaptor == nil {
-			return
-		}
-		resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
-			"task_id": originTask.TaskID,
-			"action":  originTask.Action,
-		}, proxy)
-		if err2 != nil || resp == nil {
-			return
-		}
-		defer resp.Body.Close()
-		body, err2 := io.ReadAll(resp.Body)
-		if err2 != nil {
-			return
-		}
-		ti, err2 := adaptor.ParseTaskResult(body)
-		if err2 == nil && ti != nil {
-			if ti.Status != "" {
-				originTask.Status = model.TaskStatus(ti.Status)
-			}
-			if ti.Progress != "" {
-				originTask.Progress = ti.Progress
-			}
-			if ti.Url != "" {
-				if strings.HasPrefix(ti.Url, "data:") {
-				} else {
-					originTask.FailReason = ti.Url
-				}
-			}
-			_ = originTask.Update()
-			var raw map[string]any
-			_ = json.Unmarshal(body, &raw)
-			format := "mp4"
-			if respObj, ok := raw["response"].(map[string]any); ok {
-				if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
-					if v0, ok := vids[0].(map[string]any); ok {
-						if mt, ok := v0["mimeType"].(string); ok && mt != "" {
-							if strings.Contains(mt, "mp4") {
-								format = "mp4"
-							} else {
-								format = mt
-							}
-						}
-					}
-				}
-			}
-			status := "processing"
-			switch originTask.Status {
-			case model.TaskStatusSuccess:
-				status = "succeeded"
-			case model.TaskStatusFailure:
-				status = "failed"
-			case model.TaskStatusQueued, model.TaskStatusSubmitted:
-				status = "queued"
-			}
-			if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
-				out := map[string]any{
-					"error":    nil,
-					"format":   format,
-					"metadata": nil,
-					"status":   status,
-					"task_id":  originTask.TaskID,
-					"url":      originTask.FailReason,
-				}
-				respBody, _ = json.Marshal(dto.TaskResponse[any]{
-					Code: "success",
-					Data: out,
-				})
-			}
-		}
-	}()
+	isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
 
-	if len(respBody) != 0 {
+	// Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
+	if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
+		respBody = realtimeResp
 		return
 	}
 
-	if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
+	// OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
+	if isOpenAIVideoAPI {
 		adaptor := GetTaskAdaptor(originTask.Platform)
 		if adaptor == nil {
 			taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
@@ -486,10 +394,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 			respBody = openAIVideoData
 			return
 		}
-		taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
+		taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
 		return
 	}
-	respBody, err = json.Marshal(dto.TaskResponse[any]{
+
+	// 通用 TaskDto 格式
+	respBody, err = common.Marshal(dto.TaskResponse[any]{
 		Code: "success",
 		Data: TaskModel2Dto(originTask),
 	})
@@ -499,16 +409,150 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 	return
 }
 
+// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
+// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
+// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
+func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
+	channelModel, err := model.GetChannelById(task.ChannelId, true)
+	if err != nil {
+		return nil
+	}
+	if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
+		return nil
+	}
+
+	baseURL := constant.ChannelBaseURLs[channelModel.Type]
+	if channelModel.GetBaseURL() != "" {
+		baseURL = channelModel.GetBaseURL()
+	}
+	proxy := channelModel.GetSetting().Proxy
+	adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
+	if adaptor == nil {
+		return nil
+	}
+
+	resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
+		"task_id": task.GetUpstreamTaskID(),
+		"action":  task.Action,
+	}, proxy)
+	if err != nil || resp == nil {
+		return nil
+	}
+	defer resp.Body.Close()
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil
+	}
+
+	ti, err := adaptor.ParseTaskResult(body)
+	if err != nil || ti == nil {
+		return nil
+	}
+
+	snap := task.Snapshot()
+
+	// 将上游最新状态更新到 task
+	if ti.Status != "" {
+		task.Status = model.TaskStatus(ti.Status)
+	}
+	if ti.Progress != "" {
+		task.Progress = ti.Progress
+	}
+	if strings.HasPrefix(ti.Url, "data:") {
+		// data: URI — kept in Data, not ResultURL
+	} else if ti.Url != "" {
+		task.PrivateData.ResultURL = ti.Url
+	} else if task.Status == model.TaskStatusSuccess {
+		// No URL from adaptor — construct proxy URL using public task ID
+		task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
+	}
+
+	if !snap.Equal(task.Snapshot()) {
+		_, _ = task.UpdateWithStatus(snap.Status)
+	}
+
+	// OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
+	if isOpenAIVideoAPI {
+		return nil
+	}
+
+	// 非 OpenAI Video API: 构建自定义格式响应
+	format := detectVideoFormat(body)
+	out := map[string]any{
+		"error":    nil,
+		"format":   format,
+		"metadata": nil,
+		"status":   mapTaskStatusToSimple(task.Status),
+		"task_id":  task.TaskID,
+		"url":      task.GetResultURL(),
+	}
+	respBody, _ := common.Marshal(dto.TaskResponse[any]{
+		Code: "success",
+		Data: out,
+	})
+	return respBody
+}
+
+// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
+func detectVideoFormat(rawBody []byte) string {
+	var raw map[string]any
+	if err := common.Unmarshal(rawBody, &raw); err != nil {
+		return "mp4"
+	}
+	respObj, ok := raw["response"].(map[string]any)
+	if !ok {
+		return "mp4"
+	}
+	vids, ok := respObj["videos"].([]any)
+	if !ok || len(vids) == 0 {
+		return "mp4"
+	}
+	v0, ok := vids[0].(map[string]any)
+	if !ok {
+		return "mp4"
+	}
+	mt, ok := v0["mimeType"].(string)
+	if !ok || mt == "" || strings.Contains(mt, "mp4") {
+		return "mp4"
+	}
+	return mt
+}
+
+// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
+func mapTaskStatusToSimple(status model.TaskStatus) string {
+	switch status {
+	case model.TaskStatusSuccess:
+		return "succeeded"
+	case model.TaskStatusFailure:
+		return "failed"
+	case model.TaskStatusQueued, model.TaskStatusSubmitted:
+		return "queued"
+	default:
+		return "processing"
+	}
+}
+
 func TaskModel2Dto(task *model.Task) *dto.TaskDto {
 	return &dto.TaskDto{
+		ID:         task.ID,
+		CreatedAt:  task.CreatedAt,
+		UpdatedAt:  task.UpdatedAt,
 		TaskID:     task.TaskID,
+		Platform:   string(task.Platform),
+		UserId:     task.UserId,
+		Group:      task.Group,
+		ChannelId:  task.ChannelId,
+		Quota:      task.Quota,
 		Action:     task.Action,
 		Status:     string(task.Status),
 		FailReason: task.FailReason,
+		ResultURL:  task.GetResultURL(),
 		SubmitTime: task.SubmitTime,
 		StartTime:  task.StartTime,
 		FinishTime: task.FinishTime,
 		Progress:   task.Progress,
+		Properties: task.Properties,
+		Username:   task.Username,
 		Data:       task.Data,
 	}
 }

+ 2 - 2
router/relay-router.go

@@ -174,8 +174,8 @@ func SetRelayRouter(router *gin.Engine) {
 	relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
 		relaySunoRouter.POST("/submit/:action", controller.RelayTask)
-		relaySunoRouter.POST("/fetch", controller.RelayTask)
-		relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
+		relaySunoRouter.POST("/fetch", controller.RelayTaskFetch)
+		relaySunoRouter.GET("/fetch/:id", controller.RelayTaskFetch)
 	}
 
 	relayGeminiRouter := router.Group("/v1beta")

+ 11 - 5
router/video-router.go

@@ -8,19 +8,25 @@ import (
 )
 
 func SetVideoRouter(router *gin.Engine) {
+	// Video proxy: accepts either session auth (dashboard) or token auth (API clients)
+	videoProxyRouter := router.Group("/v1")
+	videoProxyRouter.Use(middleware.TokenOrUserAuth())
+	{
+		videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy)
+	}
+
 	videoV1Router := router.Group("/v1")
 	videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
-		videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
 		videoV1Router.POST("/video/generations", controller.RelayTask)
-		videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+		videoV1Router.GET("/video/generations/:task_id", controller.RelayTaskFetch)
 		videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
 	}
 	// openai compatible API video routes
 	// docs: https://platform.openai.com/docs/api-reference/videos/create
 	{
 		videoV1Router.POST("/videos", controller.RelayTask)
-		videoV1Router.GET("/videos/:task_id", controller.RelayTask)
+		videoV1Router.GET("/videos/:task_id", controller.RelayTaskFetch)
 	}
 
 	klingV1Router := router.Group("/kling/v1")
@@ -28,8 +34,8 @@ func SetVideoRouter(router *gin.Engine) {
 	{
 		klingV1Router.POST("/videos/text2video", controller.RelayTask)
 		klingV1Router.POST("/videos/image2video", controller.RelayTask)
-		klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask)
-		klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask)
+		klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTaskFetch)
+		klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTaskFetch)
 	}
 
 	// Jimeng official API routes - direct mapping to official API format

+ 5 - 0
service/billing_session.go

@@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
 
 // shouldTrust 统一信任额度检查,适用于钱包和订阅。
 func (s *BillingSession) shouldTrust(c *gin.Context) bool {
+	// 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路
+	if s.relayInfo.ForcePreConsume {
+		return false
+	}
+
 	trustQuota := common.GetTrustQuota()
 	if trustQuota <= 0 {
 		return false

+ 13 - 0
service/error.go

@@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
 
 	return taskError
 }
+
+// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。
+func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError {
+	if apiErr == nil {
+		return nil
+	}
+	return &dto.TaskError{
+		Code:       string(apiErr.GetErrorCode()),
+		Message:    apiErr.Err.Error(),
+		StatusCode: apiErr.StatusCode,
+		Error:      apiErr.Err,
+	}
+}

+ 1 - 1
service/log_info_generate.go

@@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	return info
 }
 
-func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} {
+func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} {
 	other := make(map[string]interface{})
 	other["model_price"] = priceData.ModelPrice
 	other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio

+ 285 - 0
service/task_billing.go

@@ -0,0 +1,285 @@
+package service
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/setting/ratio_setting"
+	"github.com/gin-gonic/gin"
+)
+
+// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
+// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
+func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
+	tokenName := c.GetString("token_name")
+	logContent := fmt.Sprintf("操作 %s", info.Action)
+	// 支持任务仅按次计费
+	if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
+		logContent = fmt.Sprintf("%s,按次计费", logContent)
+	} else {
+		if len(info.PriceData.OtherRatios) > 0 {
+			var contents []string
+			for key, ra := range info.PriceData.OtherRatios {
+				if 1.0 != ra {
+					contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
+				}
+			}
+			if len(contents) > 0 {
+				logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
+			}
+		}
+	}
+	other := make(map[string]interface{})
+	other["request_path"] = c.Request.URL.Path
+	other["model_price"] = info.PriceData.ModelPrice
+	other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
+	if info.PriceData.GroupRatioInfo.HasSpecialRatio {
+		other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
+	}
+	if info.IsModelMapped {
+		other["is_model_mapped"] = true
+		other["upstream_model_name"] = info.UpstreamModelName
+	}
+	model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
+		ChannelId: info.ChannelId,
+		ModelName: info.OriginModelName,
+		TokenName: tokenName,
+		Quota:     info.PriceData.Quota,
+		Content:   logContent,
+		TokenId:   info.TokenId,
+		Group:     info.UsingGroup,
+		Other:     other,
+	})
+	model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
+	model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
+}
+
+// ---------------------------------------------------------------------------
+// 异步任务计费辅助函数
+// ---------------------------------------------------------------------------
+
+// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
+// 如果令牌已被删除或查询失败,返回空字符串。
+func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
+	token, err := model.GetTokenById(tokenId)
+	if err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
+		return ""
+	}
+	return token.Key
+}
+
+// taskIsSubscription 判断任务是否通过订阅计费。
+func taskIsSubscription(task *model.Task) bool {
+	return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
+}
+
+// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
+func taskAdjustFunding(task *model.Task, delta int) error {
+	if taskIsSubscription(task) {
+		return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
+	}
+	if delta > 0 {
+		return model.DecreaseUserQuota(task.UserId, delta)
+	}
+	return model.IncreaseUserQuota(task.UserId, -delta, false)
+}
+
+// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
+// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
+func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
+	if task.PrivateData.TokenId <= 0 || delta == 0 {
+		return
+	}
+	tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
+	if tokenKey == "" {
+		return
+	}
+	var err error
+	if delta > 0 {
+		err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
+	} else {
+		err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
+	}
+	if err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
+	}
+}
+
+// taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。
+func taskBillingOther(task *model.Task) map[string]interface{} {
+	other := make(map[string]interface{})
+	if bc := task.PrivateData.BillingContext; bc != nil {
+		other["model_price"] = bc.ModelPrice
+		other["group_ratio"] = bc.GroupRatio
+		if len(bc.OtherRatios) > 0 {
+			for k, v := range bc.OtherRatios {
+				other[k] = v
+			}
+		}
+	}
+	props := task.Properties
+	if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
+		other["is_model_mapped"] = true
+		other["upstream_model_name"] = props.UpstreamModelName
+	}
+	return other
+}
+
+// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
+func taskModelName(task *model.Task) string {
+	if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
+		return bc.OriginModelName
+	}
+	return task.Properties.OriginModelName
+}
+
+// RefundTaskQuota 统一的任务失败退款逻辑。
+// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
+func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
+	quota := task.Quota
+	if quota == 0 {
+		return
+	}
+
+	// 1. 退还资金来源(钱包或订阅)
+	if err := taskAdjustFunding(task, -quota); err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
+		return
+	}
+
+	// 2. 退还令牌额度
+	taskAdjustTokenQuota(ctx, task, -quota)
+
+	// 3. 记录日志
+	other := taskBillingOther(task)
+	other["task_id"] = task.TaskID
+	other["reason"] = reason
+	model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
+		UserId:    task.UserId,
+		LogType:   model.LogTypeRefund,
+		Content:   "",
+		ChannelId: task.ChannelId,
+		ModelName: taskModelName(task),
+		Quota:     quota,
+		TokenId:   task.PrivateData.TokenId,
+		Group:     task.Group,
+		Other:     other,
+	})
+}
+
+// RecalculateTaskQuota 通用的异步差额结算。
+// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
+// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
+func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
+	if actualQuota <= 0 {
+		return
+	}
+	preConsumedQuota := task.Quota
+	quotaDelta := actualQuota - preConsumedQuota
+
+	if quotaDelta == 0 {
+		logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
+			task.TaskID, logger.LogQuota(actualQuota), reason))
+		return
+	}
+
+	logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
+		task.TaskID,
+		logger.LogQuota(quotaDelta),
+		logger.LogQuota(actualQuota),
+		logger.LogQuota(preConsumedQuota),
+		reason,
+	))
+
+	// 调整资金来源
+	if err := taskAdjustFunding(task, quotaDelta); err != nil {
+		logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
+		return
+	}
+
+	// 调整令牌额度
+	taskAdjustTokenQuota(ctx, task, quotaDelta)
+
+	task.Quota = actualQuota
+
+	var logType int
+	var logQuota int
+	if quotaDelta > 0 {
+		logType = model.LogTypeConsume
+		logQuota = quotaDelta
+		model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
+		model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
+	} else {
+		logType = model.LogTypeRefund
+		logQuota = -quotaDelta
+	}
+	other := taskBillingOther(task)
+	other["task_id"] = task.TaskID
+	other["reason"] = reason
+	other["pre_consumed_quota"] = preConsumedQuota
+	other["actual_quota"] = actualQuota
+	model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
+		UserId:    task.UserId,
+		LogType:   logType,
+		Content:   "",
+		ChannelId: task.ChannelId,
+		ModelName: taskModelName(task),
+		Quota:     logQuota,
+		TokenId:   task.PrivateData.TokenId,
+		Group:     task.Group,
+		Other:     other,
+	})
+}
+
+// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
+// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
+// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
+func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
+	if totalTokens <= 0 {
+		return
+	}
+
+	modelName := taskModelName(task)
+
+	// 获取模型价格和倍率
+	modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
+	// 只有配置了倍率(非固定价格)时才按 token 重新计费
+	if !hasRatioSetting || modelRatio <= 0 {
+		return
+	}
+
+	// 获取用户和组的倍率信息
+	group := task.Group
+	if group == "" {
+		user, err := model.GetUserById(task.UserId, false)
+		if err == nil {
+			group = user.Group
+		}
+	}
+	if group == "" {
+		return
+	}
+
+	groupRatio := ratio_setting.GetGroupRatio(group)
+	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
+
+	var finalGroupRatio float64
+	if hasUserGroupRatio {
+		finalGroupRatio = userGroupRatio
+	} else {
+		finalGroupRatio = groupRatio
+	}
+
+	// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
+	actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
+
+	reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
+	RecalculateTaskQuota(ctx, task, actualQuota, reason)
+}

+ 712 - 0
service/task_billing_test.go

@@ -0,0 +1,712 @@
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/glebarez/sqlite"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
+)
+
+func TestMain(m *testing.M) {
+	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+	if err != nil {
+		panic("failed to open test db: " + err.Error())
+	}
+	sqlDB, err := db.DB()
+	if err != nil {
+		panic("failed to get sql.DB: " + err.Error())
+	}
+	sqlDB.SetMaxOpenConns(1)
+
+	model.DB = db
+	model.LOG_DB = db
+
+	common.UsingSQLite = true
+	common.RedisEnabled = false
+	common.BatchUpdateEnabled = false
+	common.LogConsumeEnabled = true
+
+	if err := db.AutoMigrate(
+		&model.Task{},
+		&model.User{},
+		&model.Token{},
+		&model.Log{},
+		&model.Channel{},
+		&model.UserSubscription{},
+	); err != nil {
+		panic("failed to migrate: " + err.Error())
+	}
+
+	os.Exit(m.Run())
+}
+
+// ---------------------------------------------------------------------------
+// Seed helpers
+// ---------------------------------------------------------------------------
+
+func truncate(t *testing.T) {
+	t.Helper()
+	t.Cleanup(func() {
+		model.DB.Exec("DELETE FROM tasks")
+		model.DB.Exec("DELETE FROM users")
+		model.DB.Exec("DELETE FROM tokens")
+		model.DB.Exec("DELETE FROM logs")
+		model.DB.Exec("DELETE FROM channels")
+		model.DB.Exec("DELETE FROM user_subscriptions")
+	})
+}
+
+func seedUser(t *testing.T, id int, quota int) {
+	t.Helper()
+	user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
+	require.NoError(t, model.DB.Create(user).Error)
+}
+
+func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
+	t.Helper()
+	token := &model.Token{
+		Id:          id,
+		UserId:      userId,
+		Key:         key,
+		Name:        "test_token",
+		Status:      common.TokenStatusEnabled,
+		RemainQuota: remainQuota,
+		UsedQuota:   0,
+	}
+	require.NoError(t, model.DB.Create(token).Error)
+}
+
+func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
+	t.Helper()
+	sub := &model.UserSubscription{
+		Id:          id,
+		UserId:      userId,
+		AmountTotal: amountTotal,
+		AmountUsed:  amountUsed,
+		Status:      "active",
+		StartTime:   time.Now().Unix(),
+		EndTime:     time.Now().Add(30 * 24 * time.Hour).Unix(),
+	}
+	require.NoError(t, model.DB.Create(sub).Error)
+}
+
+func seedChannel(t *testing.T, id int) {
+	t.Helper()
+	ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
+	require.NoError(t, model.DB.Create(ch).Error)
+}
+
+func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
+	return &model.Task{
+		TaskID:    "task_" + time.Now().Format("150405.000"),
+		UserId:    userId,
+		ChannelId: channelId,
+		Quota:     quota,
+		Status:    model.TaskStatus(model.TaskStatusInProgress),
+		Group:     "default",
+		Data:      json.RawMessage(`{}`),
+		CreatedAt: time.Now().Unix(),
+		UpdatedAt: time.Now().Unix(),
+		Properties: model.Properties{
+			OriginModelName: "test-model",
+		},
+		PrivateData: model.TaskPrivateData{
+			BillingSource:  billingSource,
+			SubscriptionId: subscriptionId,
+			TokenId:        tokenId,
+			BillingContext: &model.TaskBillingContext{
+				ModelPrice: 0.02,
+				GroupRatio: 1.0,
+				OriginModelName: "test-model",
+			},
+		},
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Read-back helpers
+// ---------------------------------------------------------------------------
+
+func getUserQuota(t *testing.T, id int) int {
+	t.Helper()
+	var user model.User
+	require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
+	return user.Quota
+}
+
+func getTokenRemainQuota(t *testing.T, id int) int {
+	t.Helper()
+	var token model.Token
+	require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
+	return token.RemainQuota
+}
+
+func getTokenUsedQuota(t *testing.T, id int) int {
+	t.Helper()
+	var token model.Token
+	require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
+	return token.UsedQuota
+}
+
+func getSubscriptionUsed(t *testing.T, id int) int64 {
+	t.Helper()
+	var sub model.UserSubscription
+	require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
+	return sub.AmountUsed
+}
+
+func getLastLog(t *testing.T) *model.Log {
+	t.Helper()
+	var log model.Log
+	err := model.LOG_DB.Order("id desc").First(&log).Error
+	if err != nil {
+		return nil
+	}
+	return &log
+}
+
+func countLogs(t *testing.T) int64 {
+	t.Helper()
+	var count int64
+	model.LOG_DB.Model(&model.Log{}).Count(&count)
+	return count
+}
+
+// ===========================================================================
+// RefundTaskQuota tests
+// ===========================================================================
+
+func TestRefundTaskQuota_Wallet(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 1, 1, 1
+	const initQuota, preConsumed = 10000, 3000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RefundTaskQuota(ctx, task, "task failed: upstream error")
+
+	// User quota should increase by preConsumed
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+
+	// Token remain_quota should increase, used_quota should decrease
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
+
+	// A refund log should be created
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+	assert.Equal(t, preConsumed, log.Quota)
+	assert.Equal(t, "test-model", log.ModelName)
+}
+
+func TestRefundTaskQuota_Subscription(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID, subID = 2, 2, 2, 1
+	const preConsumed = 2000
+	const subTotal, subUsed int64 = 100000, 50000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, 0)
+	seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
+	seedChannel(t, channelID)
+	seedSubscription(t, subID, userID, subTotal, subUsed)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
+
+	RefundTaskQuota(ctx, task, "subscription task failed")
+
+	// Subscription used should decrease by preConsumed
+	assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
+
+	// Token should also be refunded
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 3
+	seedUser(t, userID, 5000)
+
+	task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
+
+	RefundTaskQuota(ctx, task, "zero quota task")
+
+	// No change to user quota
+	assert.Equal(t, 5000, getUserQuota(t, userID))
+
+	// No log created
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRefundTaskQuota_NoToken(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, channelID = 4, 4
+	const initQuota, preConsumed = 10000, 1500
+
+	seedUser(t, userID, initQuota)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
+
+	RefundTaskQuota(ctx, task, "no token task failed")
+
+	// User quota refunded
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+
+	// Log created
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+// ===========================================================================
+// RecalculateTaskQuota tests
+// ===========================================================================
+
+func TestRecalculate_PositiveDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 10, 10, 10
+	const initQuota, preConsumed = 10000, 2000
+	const actualQuota = 3000 // under-charged by 1000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
+
+	// User quota should decrease by the delta (1000 additional charge)
+	assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
+
+	// Token should also be charged the delta
+	assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota should be updated to actualQuota
+	assert.Equal(t, actualQuota, task.Quota)
+
+	// Log type should be Consume (additional charge)
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeConsume, log.Type)
+	assert.Equal(t, actualQuota-preConsumed, log.Quota)
+}
+
+func TestRecalculate_NegativeDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 11, 11, 11
+	const initQuota, preConsumed = 10000, 5000
+	const actualQuota = 3000 // over-charged by 2000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
+
+	// User quota should increase by abs(delta) = 2000 (refund overpayment)
+	assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
+
+	// Token should be refunded the difference
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota updated
+	assert.Equal(t, actualQuota, task.Quota)
+
+	// Log type should be Refund
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+	assert.Equal(t, preConsumed-actualQuota, log.Quota)
+}
+
+func TestRecalculate_ZeroDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 12
+	const initQuota, preConsumed = 10000, 3000
+
+	seedUser(t, userID, initQuota)
+
+	task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
+
+	// No change to user quota
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+
+	// No log created (delta is zero)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRecalculate_ActualQuotaZero(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 13
+	const initQuota = 10000
+
+	seedUser(t, userID, initQuota)
+
+	task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, 0, "zero actual")
+
+	// No change (early return)
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID, subID = 14, 14, 14, 2
+	const preConsumed = 5000
+	const actualQuota = 2000 // over-charged by 3000
+	const subTotal, subUsed int64 = 100000, 50000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, 0)
+	seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
+	seedChannel(t, channelID)
+	seedSubscription(t, subID, userID, subTotal, subUsed)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
+
+	// Subscription used should decrease by delta (refund 3000)
+	assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
+
+	// Token refunded
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	assert.Equal(t, actualQuota, task.Quota)
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+// ===========================================================================
+// CAS + Billing integration tests
+// Simulates the flow in updateVideoSingleTask (service/task_polling.go)
+// ===========================================================================
+
+// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
+// It takes a persisted task (already in DB), applies the new status, and performs
+// the conditional update + billing exactly as the polling loop does.
+func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
+	snap := task.Snapshot()
+
+	shouldRefund := false
+	shouldSettle := false
+	quota := task.Quota
+
+	task.Status = newStatus
+	switch string(newStatus) {
+	case model.TaskStatusSuccess:
+		task.Progress = "100%"
+		task.FinishTime = 9999
+		shouldSettle = true
+	case model.TaskStatusFailure:
+		task.Progress = "100%"
+		task.FinishTime = 9999
+		task.FailReason = "upstream error"
+		if quota != 0 {
+			shouldRefund = true
+		}
+	default:
+		task.Progress = "50%"
+	}
+
+	isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
+	if isDone && snap.Status != task.Status {
+		won, err := task.UpdateWithStatus(snap.Status)
+		if err != nil {
+			shouldRefund = false
+			shouldSettle = false
+		} else if !won {
+			shouldRefund = false
+			shouldSettle = false
+		}
+	} else if !snap.Equal(task.Snapshot()) {
+		_, _ = task.UpdateWithStatus(snap.Status)
+	}
+
+	if shouldSettle && actualQuota > 0 {
+		RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
+	}
+	if shouldRefund {
+		RefundTaskQuota(ctx, task, task.FailReason)
+	}
+}
+
+func TestCASGuardedRefund_Win(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 20, 20, 20
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 6000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
+
+	// CAS wins: task in DB should now be FAILURE
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
+
+	// Refund should have happened
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+func TestCASGuardedRefund_Lose(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 21, 21, 21
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 6000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
+	seedChannel(t, channelID)
+
+	// Create task with IN_PROGRESS in DB
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	// Simulate another process already transitioning to FAILURE
+	model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
+
+	// Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
+	// task.Status is still IN_PROGRESS in the snapshot
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
+
+	// CAS lost: user quota should NOT change (no double refund)
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+
+	// No billing log should be created
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestCASGuardedSettle_Win(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 22, 22, 22
+	const initQuota, preConsumed = 10000, 5000
+	const actualQuota = 3000 // over-charged, should get partial refund
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
+
+	// CAS wins: task should be SUCCESS
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
+
+	// Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
+	assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota should be updated to actualQuota
+	assert.Equal(t, actualQuota, task.Quota)
+}
+
+func TestNonTerminalUpdate_NoBilling(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, channelID = 23, 23
+	const initQuota, preConsumed = 10000, 3000
+
+	seedUser(t, userID, initQuota)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	task.Progress = "20%"
+	require.NoError(t, model.DB.Create(task).Error)
+
+	// Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
+
+	// User quota should NOT change
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+
+	// No billing log
+	assert.Equal(t, int64(0), countLogs(t))
+
+	// Task progress should be updated in DB
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.Equal(t, "50%", reloaded.Progress)
+}
+
+// ===========================================================================
+// Mock adaptor for settleTaskBillingOnComplete tests
+// ===========================================================================
+
+type mockAdaptor struct {
+	adjustReturn int
+}
+
+func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo)                                            {}
+func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error)  { return nil, nil }
+func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error)                     { return nil, nil }
+func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+	return m.adjustReturn
+}
+
+// ===========================================================================
+// PerCallBilling tests — settleTaskBillingOnComplete
+// ===========================================================================
+
+func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 30, 30, 30
+	const initQuota, preConsumed = 10000, 5000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.PrivateData.BillingContext.PerCallBilling = true
+
+	adaptor := &mockAdaptor{adjustReturn: 2000}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Per-call: no adjustment despite adaptor returning 2000
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, preConsumed, task.Quota)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 31, 31, 31
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 7000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.PrivateData.BillingContext.PerCallBilling = true
+
+	adaptor := &mockAdaptor{adjustReturn: 0}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Per-call: no recalculation by tokens
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, preConsumed, task.Quota)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 32, 32, 32
+	const initQuota, preConsumed = 10000, 5000
+	const adaptorQuota = 3000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	// PerCallBilling defaults to false
+
+	adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Non-per-call: adaptor adjustment applies (refund 2000)
+	assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, adaptorQuota, task.Quota)
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}

+ 486 - 0
service/task_polling.go

@@ -0,0 +1,486 @@
+package service
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/model"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+
+	"github.com/samber/lo"
+)
+
+// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
+type TaskPollingAdaptor interface {
+	Init(info *relaycommon.RelayInfo)
+	FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
+	ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
+	// AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
+	// 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
+	AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
+}
+
+// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
+// 打破 service -> relay -> relay/channel -> service 的循环依赖。
+var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
+
+// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
+func TaskPollingLoop() {
+	for {
+		time.Sleep(time.Duration(15) * time.Second)
+		common.SysLog("任务进度轮询开始")
+		ctx := context.TODO()
+		allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
+		platformTask := make(map[constant.TaskPlatform][]*model.Task)
+		for _, t := range allTasks {
+			platformTask[t.Platform] = append(platformTask[t.Platform], t)
+		}
+		for platform, tasks := range platformTask {
+			if len(tasks) == 0 {
+				continue
+			}
+			taskChannelM := make(map[int][]string)
+			taskM := make(map[string]*model.Task)
+			nullTaskIds := make([]int64, 0)
+			for _, task := range tasks {
+				upstreamID := task.GetUpstreamTaskID()
+				if upstreamID == "" {
+					// 统计失败的未完成任务
+					nullTaskIds = append(nullTaskIds, task.ID)
+					continue
+				}
+				taskM[upstreamID] = task
+				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
+			}
+			if len(nullTaskIds) > 0 {
+				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
+					"status":   "FAILURE",
+					"progress": "100%",
+				})
+				if err != nil {
+					logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+				} else {
+					logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+				}
+			}
+			if len(taskChannelM) == 0 {
+				continue
+			}
+
+			DispatchPlatformUpdate(platform, taskChannelM, taskM)
+		}
+		common.SysLog("任务进度轮询完成")
+	}
+}
+
+// DispatchPlatformUpdate 按平台分发轮询更新
+func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
+	switch platform {
+	case constant.TaskPlatformMidjourney:
+		// MJ 轮询由其自身处理,这里预留入口
+	case constant.TaskPlatformSuno:
+		_ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
+	default:
+		if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
+			common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
+		}
+	}
+}
+
+// UpdateSunoTasks 按渠道更新所有 Suno 任务
+func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		err := updateSunoTasks(ctx, channelId, taskIds, taskM)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	ch, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+		// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
+		var failedIDs []int64
+		for _, upstreamID := range taskIds {
+			if t, ok := taskM[upstreamID]; ok {
+				failedIDs = append(failedIDs, t.ID)
+			}
+		}
+		err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
+			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if err != nil {
+			common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
+		}
+		return err
+	}
+	adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
+	if adaptor == nil {
+		return errors.New("adaptor not found")
+	}
+	proxy := ch.GetSetting().Proxy
+	resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
+		"ids": taskIds,
+	}, proxy)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
+		return err
+	}
+	if resp.StatusCode != http.StatusOK {
+		logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+		return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
+		return err
+	}
+	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
+	err = common.Unmarshal(responseBody, &responseItems)
+	if err != nil {
+		logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+		return err
+	}
+	if !responseItems.IsSuccess() {
+		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
+		return err
+	}
+
+	for _, responseItem := range responseItems.Data {
+		task := taskM[responseItem.TaskID]
+		if !taskNeedsUpdate(task, responseItem) {
+			continue
+		}
+
+		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
+		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
+		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
+		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
+		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
+		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
+			logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+			task.Progress = "100%"
+			RefundTaskQuota(ctx, task, task.FailReason)
+		}
+		if responseItem.Status == model.TaskStatusSuccess {
+			task.Progress = "100%"
+		}
+		task.Data = responseItem.Data
+
+		err = task.Update()
+		if err != nil {
+			common.SysLog("UpdateSunoTask task error: " + err.Error())
+		}
+	}
+	return nil
+}
+
+// taskNeedsUpdate 检查 Suno 任务是否需要更新
+func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if string(oldTask.Status) != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+
+	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
+		return true
+	}
+
+	oldData, _ := common.Marshal(oldTask.Data)
+	newData, _ := common.Marshal(newTask.Data)
+
+	sort.Slice(oldData, func(i, j int) bool {
+		return oldData[i] < oldData[j]
+	})
+	sort.Slice(newData, func(i, j int) bool {
+		return newData[i] < newData[j]
+	})
+
+	if string(oldData) != string(newData) {
+		return true
+	}
+	return false
+}
+
+// UpdateVideoTasks 按渠道更新所有视频任务
+func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	cacheGetChannel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
+		var failedIDs []int64
+		for _, upstreamID := range taskIds {
+			if t, ok := taskM[upstreamID]; ok {
+				failedIDs = append(failedIDs, t.ID)
+			}
+		}
+		errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
+			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if errUpdate != nil {
+			common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+		}
+		return fmt.Errorf("CacheGetChannel failed: %w", err)
+	}
+	adaptor := GetTaskAdaptorFunc(platform)
+	if adaptor == nil {
+		return fmt.Errorf("video adaptor not found")
+	}
+	info := &relaycommon.RelayInfo{}
+	info.ChannelMeta = &relaycommon.ChannelMeta{
+		ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
+	}
+	info.ApiKey = cacheGetChannel.Key
+	adaptor.Init(info)
+	for _, taskId := range taskIds {
+		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
+	baseURL := constant.ChannelBaseURLs[ch.Type]
+	if ch.GetBaseURL() != "" {
+		baseURL = ch.GetBaseURL()
+	}
+	proxy := ch.GetSetting().Proxy
+
+	task := taskM[taskId]
+	if task == nil {
+		logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+		return fmt.Errorf("task %s not found", taskId)
+	}
+	key := ch.Key
+
+	privateData := task.PrivateData
+	if privateData.Key != "" {
+		key = privateData.Key
+	}
+	resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
+		"task_id": task.GetUpstreamTaskID(),
+		"action":  task.Action,
+	}, proxy)
+	if err != nil {
+		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
+	}
+
+	logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
+
+	snap := task.Snapshot()
+
+	taskResult := &relaycommon.TaskInfo{}
+	// try parse as New API response format
+	var responseItems dto.TaskResponse[model.Task]
+	if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
+		logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
+		t := responseItems.Data
+		taskResult.TaskID = t.TaskID
+		taskResult.Status = string(t.Status)
+		taskResult.Url = t.GetResultURL()
+		taskResult.Progress = t.Progress
+		taskResult.Reason = t.FailReason
+		task.Data = t.Data
+	} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
+		return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
+	} else {
+		task.Data = redactVideoResponseBody(responseBody)
+	}
+
+	logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
+
+	now := time.Now().Unix()
+	if taskResult.Status == "" {
+		taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
+	}
+
+	shouldRefund := false
+	shouldSettle := false
+	quota := task.Quota
+
+	task.Status = model.TaskStatus(taskResult.Status)
+	switch taskResult.Status {
+	case model.TaskStatusSubmitted:
+		task.Progress = taskcommon.ProgressSubmitted
+	case model.TaskStatusQueued:
+		task.Progress = taskcommon.ProgressQueued
+	case model.TaskStatusInProgress:
+		task.Progress = taskcommon.ProgressInProgress
+		if task.StartTime == 0 {
+			task.StartTime = now
+		}
+	case model.TaskStatusSuccess:
+		task.Progress = taskcommon.ProgressComplete
+		if task.FinishTime == 0 {
+			task.FinishTime = now
+		}
+		if strings.HasPrefix(taskResult.Url, "data:") {
+			// data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
+		} else if taskResult.Url != "" {
+			// Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
+			task.PrivateData.ResultURL = taskResult.Url
+		} else {
+			// No URL from adaptor — construct proxy URL using public task ID
+			task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
+		}
+		shouldSettle = true
+	case model.TaskStatusFailure:
+		logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
+		task.Status = model.TaskStatusFailure
+		task.Progress = taskcommon.ProgressComplete
+		if task.FinishTime == 0 {
+			task.FinishTime = now
+		}
+		task.FailReason = taskResult.Reason
+		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+		taskResult.Progress = taskcommon.ProgressComplete
+		if quota != 0 {
+			shouldRefund = true
+		}
+	default:
+		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID)
+	}
+	if taskResult.Progress != "" {
+		task.Progress = taskResult.Progress
+	}
+
+	isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure
+	if isDone && snap.Status != task.Status {
+		won, err := task.UpdateWithStatus(snap.Status)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error()))
+			shouldRefund = false
+			shouldSettle = false
+		} else if !won {
+			logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID))
+			shouldRefund = false
+			shouldSettle = false
+		}
+	} else if !snap.Equal(task.Snapshot()) {
+		if _, err := task.UpdateWithStatus(snap.Status); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error()))
+		}
+	} else {
+		// No changes, skip update
+		logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID))
+	}
+
+	if shouldSettle {
+		settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+	}
+	if shouldRefund {
+		RefundTaskQuota(ctx, task, task.FailReason)
+	}
+
+	return nil
+}
+
+func redactVideoResponseBody(body []byte) []byte {
+	var m map[string]any
+	if err := common.Unmarshal(body, &m); err != nil {
+		return body
+	}
+	resp, _ := m["response"].(map[string]any)
+	if resp != nil {
+		delete(resp, "bytesBase64Encoded")
+		if v, ok := resp["video"].(string); ok {
+			resp["video"] = truncateBase64(v)
+		}
+		if vs, ok := resp["videos"].([]any); ok {
+			for i := range vs {
+				if vm, ok := vs[i].(map[string]any); ok {
+					delete(vm, "bytesBase64Encoded")
+				}
+			}
+		}
+	}
+	b, err := common.Marshal(m)
+	if err != nil {
+		return body
+	}
+	return b
+}
+
+func truncateBase64(s string) string {
+	const maxKeep = 256
+	if len(s) <= maxKeep {
+		return s
+	}
+	return s[:maxKeep] + "..."
+}
+
+// settleTaskBillingOnComplete 任务完成时的统一计费调整。
+// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度
+//
+//  2. taskResult.TotalTokens > 0 → 按 token 重算
+//  3. 都不满足 → 保持预扣额度不变
+func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
+	// 0. 按次计费的任务不做差额结算
+	if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
+		logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
+		return
+	}
+	// 1. 优先让 adaptor 决定最终额度
+	if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
+		RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
+		return
+	}
+	// 2. 回退到 token 重算
+	if taskResult.TotalTokens > 0 {
+		RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
+		return
+	}
+	// 3. 无调整,保持预扣额度
+}

+ 2 - 7
types/price_data.go

@@ -22,7 +22,8 @@ type PriceData struct {
 	AudioCompletionRatio float64
 	OtherRatios          map[string]float64
 	UsePrice             bool
-	QuotaToPreConsume    int // 预消耗额度
+	Quota                int // 按次计费的最终额度(MJ / Task)
+	QuotaToPreConsume    int // 按量计费的预消耗额度
 	GroupRatioInfo       GroupRatioInfo
 }
 
@@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) {
 	p.OtherRatios[key] = ratio
 }
 
-type PerCallPriceData struct {
-	ModelPrice     float64
-	Quota          int
-	GroupRatioInfo GroupRatioInfo
-}
-
 func (p *PriceData) ToSetting() string {
 	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
 }

+ 18 - 27
web/src/components/table/task-logs/TaskLogsColumnDefs.jsx

@@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) {
 
   // 返回带有样式的颜色标签
   return (
-    <Tag color={color} shape='circle' prefixIcon={<Clock size={14} />}>
-      {durationSec} 
+    <Tag color={color} shape='circle'>
+      {durationSec} s
     </Tag>
   );
 }
@@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => {
   );
   if (option) {
     return (
-      <Tag color={option.color} shape='circle' prefixIcon={<Video size={14} />}>
+      <Tag color={option.color} shape='circle'>
         {option.label}
       </Tag>
     );
@@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => {
   switch (platform) {
     case 'suno':
       return (
-        <Tag color='green' shape='circle' prefixIcon={<Music size={14} />}>
+        <Tag color='green' shape='circle'>
           Suno
         </Tag>
       );
     default:
       return (
-        <Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
+        <Tag color='white' shape='circle'>
           {t('未知')}
         </Tag>
       );
@@ -240,7 +240,6 @@ export const getTaskLogsColumns = ({
   openContentModal,
   isAdminUser,
   openVideoModal,
-  showUserInfoFunc,
 }) => {
   return [
     {
@@ -278,7 +277,6 @@ export const getTaskLogsColumns = ({
               color={colors[parseInt(text) % colors.length]}
               size='large'
               shape='circle'
-              prefixIcon={<Hash size={14} />}
               onClick={() => {
                 copyText(text);
               }}
@@ -294,7 +292,7 @@ export const getTaskLogsColumns = ({
     {
       key: COLUMN_KEYS.USERNAME,
       title: t('用户'),
-      dataIndex: 'user_id',
+      dataIndex: 'username',
       render: (userId, record, index) => {
         if (!isAdminUser) {
           return <></>;
@@ -302,22 +300,14 @@ export const getTaskLogsColumns = ({
         const displayText = String(record.username || userId || '?');
         return (
           <Space>
-            <Tooltip content={displayText}>
-              <Avatar
-                size='extra-small'
-                color={stringToColor(displayText)}
-                style={{ cursor: 'pointer' }}
-                onClick={() => showUserInfoFunc && showUserInfoFunc(userId)}
-              >
-                {displayText.slice(0, 1)}
-              </Avatar>
-            </Tooltip>
-            <Typography.Text
-              ellipsis={{ showTooltip: true }}
-              style={{ cursor: 'pointer', color: 'var(--semi-color-primary)' }}
-              onClick={() => showUserInfoFunc && showUserInfoFunc(userId)}
+            <Avatar
+              size='extra-small'
+              color={stringToColor(displayText)}
             >
-              {userId}
+              {displayText.slice(0, 1)}
+            </Avatar>
+            <Typography.Text>
+              {displayText}
             </Typography.Text>
           </Space>
         );
@@ -396,7 +386,7 @@ export const getTaskLogsColumns = ({
       dataIndex: 'fail_reason',
       fixed: 'right',
       render: (text, record, index) => {
-        // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
+        // 视频预览:优先使用 result_url,兼容旧数据 fail_reason 中的 URL
         const isVideoTask =
           record.action === TASK_ACTION_GENERATE ||
           record.action === TASK_ACTION_TEXT_GENERATE ||
@@ -404,14 +394,15 @@ export const getTaskLogsColumns = ({
           record.action === TASK_ACTION_REFERENCE_GENERATE ||
           record.action === TASK_ACTION_REMIX_GENERATE;
         const isSuccess = record.status === 'SUCCESS';
-        const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
-        if (isSuccess && isVideoTask && isUrl) {
+        const resultUrl = record.result_url;
+        const hasResultUrl = typeof resultUrl === 'string' && /^https?:\/\//.test(resultUrl);
+        if (isSuccess && isVideoTask && hasResultUrl) {
           return (
             <a
               href='#'
               onClick={(e) => {
                 e.preventDefault();
-                openVideoModal(text);
+                openVideoModal(resultUrl);
               }}
             >
               {t('点击预览视频')}

+ 0 - 2
web/src/components/table/task-logs/index.jsx

@@ -25,7 +25,6 @@ import TaskLogsActions from './TaskLogsActions';
 import TaskLogsFilters from './TaskLogsFilters';
 import ColumnSelectorModal from './modals/ColumnSelectorModal';
 import ContentModal from './modals/ContentModal';
-import UserInfoModal from '../usage-logs/modals/UserInfoModal';
 import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData';
 import { useIsMobile } from '../../../hooks/common/useIsMobile';
 import { createCardProPagination } from '../../../helpers/utils';
@@ -46,7 +45,6 @@ const TaskLogsPage = () => {
         modalContent={taskLogsData.videoUrl}
         isVideo={true}
       />
-      <UserInfoModal {...taskLogsData} />
 
       <Layout>
         <CardPro

+ 0 - 2
web/src/components/table/task-logs/modals/ContentModal.jsx

@@ -144,8 +144,6 @@ const ContentModal = ({
             maxHeight: '100%',
             objectFit: 'contain',
           }}
-          autoPlay
-          crossOrigin='anonymous'
           onError={handleVideoError}
           onLoadedData={handleVideoLoaded}
           onLoadStart={() => setIsLoading(true)}

+ 23 - 7
web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx

@@ -133,6 +133,12 @@ function renderType(type, t) {
           {t('错误')}
         </Tag>
       );
+    case 6:
+      return (
+        <Tag color='teal' shape='circle'>
+          {t('退款')}
+        </Tag>
+      );
     default:
       return (
         <Tag color='grey' shape='circle'>
@@ -368,7 +374,7 @@ export const getLogsColumns = ({
         }
 
         return isAdminUser &&
-          (record.type === 0 || record.type === 2 || record.type === 5) ? (
+          (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) ? (
           <Space>
             <span style={{ position: 'relative', display: 'inline-block' }}>
               <Tooltip content={record.channel_name || t('未知渠道')}>
@@ -459,7 +465,7 @@ export const getLogsColumns = ({
       title: t('令牌'),
       dataIndex: 'token_name',
       render: (text, record, index) => {
-        return record.type === 0 || record.type === 2 || record.type === 5 ? (
+        return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
           <div>
             <Tag
               color='grey'
@@ -482,7 +488,7 @@ export const getLogsColumns = ({
       title: t('分组'),
       dataIndex: 'group',
       render: (text, record, index) => {
-        if (record.type === 0 || record.type === 2 || record.type === 5) {
+        if (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) {
           if (record.group) {
             return <>{renderGroup(record.group)}</>;
           } else {
@@ -522,7 +528,7 @@ export const getLogsColumns = ({
       title: t('模型'),
       dataIndex: 'model_name',
       render: (text, record, index) => {
-        return record.type === 0 || record.type === 2 || record.type === 5 ? (
+        return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
           <>{renderModelName(record, copyText, t)}</>
         ) : (
           <></>
@@ -589,7 +595,7 @@ export const getLogsColumns = ({
           cacheText = `${t('缓存写')} ${formatTokenCount(cacheSummary.cacheWriteTokens)}`;
         }
 
-        return record.type === 0 || record.type === 2 || record.type === 5 ? (
+        return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
           <div
             style={{
               display: 'inline-flex',
@@ -623,7 +629,7 @@ export const getLogsColumns = ({
       dataIndex: 'completion_tokens',
       render: (text, record, index) => {
         return parseInt(text) > 0 &&
-          (record.type === 0 || record.type === 2 || record.type === 5) ? (
+          (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) ? (
           <>{<span> {text} </span>}</>
         ) : (
           <></>
@@ -635,7 +641,7 @@ export const getLogsColumns = ({
       title: t('花费'),
       dataIndex: 'quota',
       render: (text, record, index) => {
-        if (!(record.type === 0 || record.type === 2 || record.type === 5)) {
+        if (!(record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6)) {
           return <></>;
         }
         const other = getLogOther(record.other);
@@ -722,6 +728,16 @@ export const getLogsColumns = ({
       fixed: 'right',
       render: (text, record, index) => {
         let other = getLogOther(record.other);
+        if (record.type === 6) {
+          return (
+            <Typography.Paragraph
+              ellipsis={{ rows: 2 }}
+              style={{ maxWidth: 240 }}
+            >
+              {t('异步任务退款')}
+            </Typography.Paragraph>
+          );
+        }
         if (other == null || record.type !== 2) {
           return (
             <Typography.Paragraph

+ 1 - 0
web/src/components/table/usage-logs/UsageLogsFilters.jsx

@@ -148,6 +148,7 @@ const LogsFilters = ({
               <Form.Select.Option value='3'>{t('管理')}</Form.Select.Option>
               <Form.Select.Option value='4'>{t('系统')}</Form.Select.Option>
               <Form.Select.Option value='5'>{t('错误')}</Form.Select.Option>
+              <Form.Select.Option value='6'>{t('退款')}</Form.Select.Option>
             </Form.Select>
           </div>
 

+ 21 - 3
web/src/hooks/usage-logs/useUsageLogsData.jsx

@@ -344,7 +344,7 @@ export const useLogsData = () => {
       let other = getLogOther(logs[i].other);
       let expandDataLocal = [];
 
-      if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2)) {
+      if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2 || logs[i].type === 6)) {
         expandDataLocal.push({
           key: t('渠道信息'),
           value: `${logs[i].channel} - ${logs[i].channel_name || '[未知]'}`,
@@ -535,6 +535,24 @@ export const useLogsData = () => {
           });
         }
       }
+      if (logs[i].type === 6) {
+        if (other?.task_id) {
+          expandDataLocal.push({
+            key: t('任务ID'),
+            value: other.task_id,
+          });
+        }
+        if (other?.reason) {
+          expandDataLocal.push({
+            key: t('失败原因'),
+            value: (
+              <div style={{ maxWidth: 600, whiteSpace: 'normal', wordBreak: 'break-word', lineHeight: 1.6 }}>
+                {other.reason}
+              </div>
+            ),
+          });
+        }
+      }
       if (other?.request_path) {
         expandDataLocal.push({
           key: t('请求路径'),
@@ -590,13 +608,13 @@ export const useLogsData = () => {
           ),
         });
       }
-      if (isAdminUser) {
+      if (isAdminUser && logs[i].type !== 6) {
         expandDataLocal.push({
           key: t('请求转换'),
           value: requestConversionDisplayValue(other?.request_conversion),
         });
       }
-      if (isAdminUser) {
+      if (isAdminUser && logs[i].type !== 6) {
         let localCountMode = '';
         if (other?.admin_info?.local_count_tokens) {
           localCountMode = t('本地计费');

+ 5 - 43
web/src/i18n/locales/en.json

@@ -302,7 +302,6 @@
     "价格重新计算中...": "Recalculating price...",
     "价格预估": "Price Estimate",
     "任务 ID": "Task ID",
-    "任务ID": "Task ID",
     "任务日志": "Task Logs",
     "任务状态": "Status",
     "任务记录": "Task Records",
@@ -544,7 +543,6 @@
     "创建": "Create",
     "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)": "Create token with auto group by default, initial token will also be set to auto (otherwise leave blank for user default group)",
     "创建失败": "Creation failed",
-    "创建成功": "Creation successful",
     "创建或选择密钥时,将 Project 设置为 io.cloud": "When creating or selecting a key, set Project to io.cloud",
     "创建新用户账户": "Create new user account",
     "创建新的令牌": "Create New Token",
@@ -787,7 +785,6 @@
     "天": "day",
     "天前": "days ago",
     "失败": "Failed",
-    "失败原因": "Failure reason",
     "失败时自动禁用通道": "Automatically disable channel on failure",
     "失败重试次数": "Failed retry times",
     "奖励说明": "Reward description",
@@ -1336,7 +1333,6 @@
     "更新失败,请检查输入信息": "Update failed, please check the input information",
     "更新容器配置": "Update Container Configuration",
     "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。": "Updating container configuration may cause the container to restart, please ensure you perform this operation at an appropriate time.",
-    "更新成功": "Update successful",
     "更新所有已启用通道余额": "Update balance for all enabled channels",
     "更新支付设置": "Update payment settings",
     "更新时间": "Update time",
@@ -1767,7 +1763,6 @@
     "确认清理不活跃的磁盘缓存?": "Confirm cleanup of inactive disk cache?",
     "确认禁用": "Confirm disable",
     "确认补单": "Confirm Order Completion",
-    "确认解绑": "Confirm Unbind",
     "确认解绑 Passkey": "Confirm Unbind Passkey",
     "确认设置并完成初始化": "Confirm settings and complete initialization",
     "确认重置 Passkey": "Confirm Passkey Reset",
@@ -1945,7 +1940,6 @@
     "自动分组auto,从第一个开始选择": "Auto grouping auto, select from the first one",
     "自动刷新": "Auto Refresh",
     "自动刷新中": "Auto refreshing",
-    "自动检测": "Auto Detect",
     "自动模式": "Auto Mode",
     "自动测试所有通道间隔时间": "Auto test interval for all channels",
     "自动禁用": "Auto disabled",
@@ -2343,46 +2337,9 @@
     "输入验证码完成设置": "Enter verification code to complete setup",
     "输出": "Output",
     "输出 {{completion}} tokens / 1M tokens * {{symbol}}{{compPrice}}) * {{ratioType}} {{ratio}}": "Output {{completion}} tokens / 1M tokens * {{symbol}}{{compPrice}} * {{ratioType}} {{ratio}}",
-    "磁盘缓存设置(磁盘换内存)": "Disk Cache Settings (Disk Swap Memory)",
-    "启用磁盘缓存后,大请求体将临时存储到磁盘而非内存,可显著降低内存占用,适用于处理包含大量图片/文件的请求。建议在 SSD 环境下使用。": "When enabled, large request bodies are temporarily stored on disk instead of memory, significantly reducing memory usage. Suitable for requests with large images/files. SSD recommended.",
-    "启用磁盘缓存": "Enable Disk Cache",
-    "将大请求体临时存储到磁盘": "Store large request bodies temporarily on disk",
-    "磁盘缓存阈值 (MB)": "Disk Cache Threshold (MB)",
-    "请求体超过此大小时使用磁盘缓存": "Use disk cache when request body exceeds this size",
-    "磁盘缓存最大总量 (MB)": "Max Disk Cache Size (MB)",
-    "可用空间: {{free}} / 总空间: {{total}}": "Free: {{free}} / Total: {{total}}",
-    "磁盘缓存占用的最大空间": "Maximum space occupied by disk cache",
-    "留空使用系统临时目录": "Leave empty to use system temp directory",
-    "例如 /var/cache/new-api": "e.g. /var/cache/new-api",
-    "性能监控": "Performance Monitor",
-    "刷新统计": "Refresh Stats",
-    "重置统计": "Reset Stats",
-    "执行 GC": "Run GC",
-    "请求体磁盘缓存": "Request Body Disk Cache",
-    "活跃文件": "Active Files",
-    "磁盘命中": "Disk Hits",
-    "请求体内存缓存": "Request Body Memory Cache",
-    "当前缓存大小": "Current Cache Size",
-    "活跃缓存数": "Active Cache Count",
-    "内存命中": "Memory Hits",
-    "缓存目录磁盘空间": "Cache Directory Disk Space",
-    "磁盘可用空间小于缓存最大总量设置": "Disk free space is less than max cache size setting",
-    "已分配内存": "Allocated Memory",
-    "总分配内存": "Total Allocated Memory",
-    "系统内存": "System Memory",
-    "GC 次数": "GC Count",
-    "Goroutine 数": "Goroutine Count",
-    "目录文件数": "Directory File Count",
-    "目录总大小": "Directory Total Size",
-    "磁盘缓存已清理": "Disk cache cleared",
-    "清理失败": "Cleanup failed",
-    "统计已重置": "Statistics reset",
-    "重置失败": "Reset failed",
-    "GC 已执行": "GC executed",
     "GC 执行失败": "GC execution failed",
     "缓存目录": "Cache Directory",
     "可用": "Available",
-    "输出价格": "Output Price",
     "输出价格:{{symbol}}{{price}} * {{completionRatio}} = {{symbol}}{{total}} / 1M tokens (补全倍率: {{completionRatio}})": "Output price: {{symbol}}{{price}} * {{completionRatio}} = {{symbol}}{{total}} / 1M tokens (Completion ratio: {{completionRatio}})",
     "输出倍率 {{completionRatio}}": "Output ratio {{completionRatio}}",
     "边栏设置": "Sidebar Settings",
@@ -2545,6 +2502,11 @@
     "销毁容器": "Destroy Container",
     "销毁容器失败": "Failed to destroy container",
     "错误": "errors",
+    "退款": "Refund",
+    "错误详情": "Error Details",
+    "异步任务退款": "Async Task Refund",
+    "任务ID": "Task ID",
+    "失败原因": "Failure Reason",
     "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "The key is the group name, and the value is another JSON object. The key is the group name, and the value is the special group ratio for users in that group. For example: {\"vip\": {\"default\": 0.5, \"test\": 1}} means that users in the vip group have a ratio of 0.5 when using tokens from the default group, and a ratio of 1 when using tokens from the test group",
     "键为原状态码,值为要复写的状态码,仅影响本地判断": "The key is the original status code, and the value is the status code to override, only affects local judgment",
     "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.",

+ 5 - 2
web/src/i18n/locales/fr.json

@@ -304,7 +304,6 @@
     "价格重新计算中...": "Recalculating price...",
     "价格预估": "Price Estimate",
     "任务 ID": "ID de la tâche",
-    "任务ID": "ID de la tâche",
     "任务日志": "Tâches",
     "任务状态": "Statut de la tâche",
     "任务记录": "Tâches",
@@ -792,7 +791,6 @@
     "天": "Jour",
     "天前": "il y a des jours",
     "失败": "Échec",
-    "失败原因": "Raison de l'échec",
     "失败时自动禁用通道": "Désactiver automatiquement le canal en cas d'échec",
     "失败重试次数": "Nombre de tentatives en cas d'échec",
     "奖励说明": "Description de la récompense",
@@ -2508,6 +2506,11 @@
     "销毁容器": "Destroy Container",
     "销毁容器失败": "Failed to destroy container",
     "错误": "Erreur",
+    "退款": "Remboursement",
+    "错误详情": "Détails de l'erreur",
+    "异步任务退款": "Remboursement de tâche asynchrone",
+    "任务ID": "ID de tâche",
+    "失败原因": "Raison de l'échec",
     "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "La clé est le nom du groupe, la valeur est un autre objet JSON, la clé est le nom du groupe, la valeur est le ratio de groupe spécial des utilisateurs de ce groupe, par exemple : {\"vip\": {\"default\": 0.5, \"test\": 1}}, ce qui signifie que les utilisateurs du groupe vip ont un ratio de 0.5 lors de l'utilisation de jetons du groupe default et un ratio de 1 lors de l'utilisation du groupe test",
     "键为原状态码,值为要复写的状态码,仅影响本地判断": "La clé est le code d'état d'origine, la valeur est le code d'état à réécrire, n'affecte que le jugement local",
     "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "La clé correspond au nom du groupe d'utilisateurs et la valeur à un objet de mappage des opérations. Les clés internes commençant par \"+:\" ajoutent le groupe indiqué (clé = nom du groupe, valeur = description), celles commençant par \"-:\" retirent le groupe indiqué, et les clés sans préfixe ajoutent directement ce groupe. Exemple : {\"vip\": {\"+:premium\": \"Groupe avancé\", \"special\": \"Groupe spécial\", \"-:default\": \"Groupe par défaut\"}} signifie que les utilisateurs du groupe vip peuvent accéder aux groupes premium et special tout en perdant l'accès au groupe default.",

+ 5 - 2
web/src/i18n/locales/ja.json

@@ -300,7 +300,6 @@
     "价格重新计算中...": "Recalculating price...",
     "价格预估": "Price Estimate",
     "任务 ID": "タスクID",
-    "任务ID": "タスクID",
     "任务日志": "タスク履歴",
     "任务状态": "タスクステータス",
     "任务记录": "タスク履歴",
@@ -783,7 +782,6 @@
     "天": "日",
     "天前": "日前",
     "失败": "失敗",
-    "失败原因": "失敗理由",
     "失败时自动禁用通道": "失敗時にチャネルを自動的に無効にする",
     "失败重试次数": "再試行回数",
     "奖励说明": "特典説明",
@@ -2491,6 +2489,11 @@
     "销毁容器": "Destroy Container",
     "销毁容器失败": "Failed to destroy container",
     "错误": "エラー",
+    "退款": "返金",
+    "错误详情": "エラー詳細",
+    "异步任务退款": "非同期タスク返金",
+    "任务ID": "タスクID",
+    "失败原因": "失敗の原因",
     "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "キーはグループ名、値は別のJSONオブジェクトです。このオブジェクトのキーには、利用するトークンが属するグループ名を指定し、値にはそのユーザーグループに適用される特別な倍率を指定します。例:{\"vip\": {\"default\": 0.5, \"test\": 1}} は、vipグループのユーザーがdefaultグループのトークンを利用する際の倍率が0.5、testグループのトークンを利用する際の倍率が1になることを示します",
     "键为原状态码,值为要复写的状态码,仅影响本地判断": "キーは元のステータスコード、値は上書きするステータスコードで、ローカルでの判断にのみ影響します",
     "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.",

+ 5 - 2
web/src/i18n/locales/ru.json

@@ -307,7 +307,6 @@
     "价格重新计算中...": "Recalculating price...",
     "价格预估": "Price Estimate",
     "任务 ID": "ID задачи",
-    "任务ID": "ID задачи",
     "任务日志": "Журнал задач",
     "任务状态": "Статус задачи",
     "任务记录": "Записи задач",
@@ -798,7 +797,6 @@
     "天": "день",
     "天前": "дней назад",
     "失败": "Неудача",
-    "失败原因": "Причина неудачи",
     "失败时自动禁用通道": "Автоматически отключать канал при неудаче",
     "失败重试次数": "Количество повторных попыток при неудаче",
     "奖励说明": "Описание награды",
@@ -2521,6 +2519,11 @@
     "销毁容器": "Destroy Container",
     "销毁容器失败": "Failed to destroy container",
     "错误": "Ошибка",
+    "退款": "Возврат",
+    "错误详情": "Детали ошибки",
+    "异步任务退款": "Возврат асинхронной задачи",
+    "任务ID": "ID задачи",
+    "失败原因": "Причина ошибки",
     "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "Ключ - это имя группы, значение - другой JSON объект, ключ - имя группы, значение - специальный групповой коэффициент для пользователей этой группы, например: {\"vip\": {\"default\": 0.5, \"test\": 1}}, означает, что пользователи группы vip при использовании токенов группы default имеют коэффициент 0.5, при использовании группы test - коэффициент 1",
     "键为原状态码,值为要复写的状态码,仅影响本地判断": "Ключ - исходный код состояния, значение - код состояния для перезаписи, влияет только на локальную проверку",
     "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Ключ — это название группы пользователей, значение — объект сопоставления операций. Внутренние ключи с префиксом \"+:\" добавляют указанные группы (ключ — название группы, значение — описание), с префиксом \"-:\" удаляют указанные группы, без префикса — сразу добавляют эту группу. Пример: {\"vip\": {\"+:premium\": \"Продвинутая группа\", \"special\": \"Особая группа\", \"-:default\": \"Группа по умолчанию\"}} означает, что пользователи группы vip могут использовать группы premium и special, одновременно теряя доступ к группе default.",

+ 4 - 2
web/src/i18n/locales/vi.json

@@ -301,7 +301,6 @@
     "价格重新计算中...": "Recalculating price...",
     "价格预估": "Price Estimate",
     "任务 ID": "ID tác vụ",
-    "任务ID": "ID tác vụ",
     "任务日志": "Nhật ký tác vụ",
     "任务状态": "Trạng thái",
     "任务记录": "Hồ sơ tác vụ",
@@ -784,7 +783,6 @@
     "天": "ngày",
     "天前": "ngày trước",
     "失败": "Thất bại",
-    "失败原因": "Lý do thất bại",
     "失败时自动禁用通道": "Tự động vô hiệu hóa kênh khi thất bại",
     "失败重试次数": "Số lần thử lại thất bại",
     "奖励说明": "Mô tả phần thưởng",
@@ -3060,10 +3058,14 @@
     "销毁容器失败": "Failed to destroy container",
     "锁定": "Khóa",
     "错误": "Lỗi",
+    "退款": "Hoàn tiền",
     "错误信息": "Thông tin lỗi",
     "错误日志": "Nhật ký lỗi",
     "错误码": "Mã lỗi",
     "错误详情": "Chi tiết lỗi",
+    "异步任务退款": "Hoàn tiền tác vụ bất đồng bộ",
+    "任务ID": "ID tác vụ",
+    "失败原因": "Nguyên nhân thất bại",
     "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "Khóa là tên nhóm và giá trị là một đối tượng JSON khác. Khóa là tên nhóm và giá trị là tỷ lệ nhóm đặc biệt cho người dùng trong nhóm đó. Ví dụ: {\"vip\": {\"default\": 0.5, \"test\": 1}} có nghĩa là người dùng trong nhóm vip có tỷ lệ 0.5 khi sử dụng mã thông báo từ nhóm default và tỷ lệ 1 khi sử dụng mã thông báo từ nhóm test.",
     "键为原状态码,值为要复写的状态码,仅影响本地判断": "Khóa là mã trạng thái gốc và giá trị là mã trạng thái cần ghi đè, chỉ ảnh hưởng đến phán đoán cục bộ",
     "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.",

+ 5 - 6
web/src/i18n/locales/zh-CN.json

@@ -298,7 +298,6 @@
     "价格重新计算中...": "价格重新计算中...",
     "价格预估": "价格预估",
     "任务 ID": "任务 ID",
-    "任务ID": "任务ID",
     "任务日志": "任务日志",
     "任务状态": "任务状态",
     "任务记录": "任务记录",
@@ -539,7 +538,6 @@
     "创建": "创建",
     "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)": "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)",
     "创建失败": "创建失败",
-    "创建成功": "创建成功",
     "创建或选择密钥时,将 Project 设置为 io.cloud": "创建或选择密钥时,将 Project 设置为 io.cloud",
     "创建新用户账户": "创建新用户账户",
     "创建新的令牌": "创建新的令牌",
@@ -782,7 +780,6 @@
     "天": "天",
     "天前": "天前",
     "失败": "失败",
-    "失败原因": "失败原因",
     "失败时自动禁用通道": "失败时自动禁用通道",
     "失败重试次数": "失败重试次数",
     "奖励说明": "奖励说明",
@@ -1326,7 +1323,6 @@
     "更新失败,请检查输入信息": "更新失败,请检查输入信息",
     "更新容器配置": "更新容器配置",
     "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。": "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。",
-    "更新成功": "更新成功",
     "更新所有已启用通道余额": "更新所有已启用通道余额",
     "更新支付设置": "更新支付设置",
     "更新时间": "更新时间",
@@ -1754,7 +1750,6 @@
     "确认清除历史日志": "确认清除历史日志",
     "确认禁用": "确认禁用",
     "确认补单": "确认补单",
-    "确认解绑": "确认解绑",
     "确认解绑 Passkey": "确认解绑 Passkey",
     "确认设置并完成初始化": "确认设置并完成初始化",
     "确认重置 Passkey": "确认重置 Passkey",
@@ -1932,7 +1927,6 @@
     "自动分组auto,从第一个开始选择": "自动分组auto,从第一个开始选择",
     "自动刷新": "自动刷新",
     "自动刷新中": "自动刷新中",
-    "自动检测": "自动检测",
     "自动模式": "自动模式",
     "自动测试所有通道间隔时间": "自动测试所有通道间隔时间",
     "自动禁用": "自动禁用",
@@ -2531,6 +2525,11 @@
     "销毁容器": "销毁容器",
     "销毁容器失败": "销毁容器失败",
     "错误": "错误",
+    "退款": "退款",
+    "错误详情": "错误详情",
+    "异步任务退款": "异步任务退款",
+    "任务ID": "任务ID",
+    "失败原因": "失败原因",
     "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1",
     "键为原状态码,值为要复写的状态码,仅影响本地判断": "键为原状态码,值为要复写的状态码,仅影响本地判断",
     "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限",