ソースを参照

feat(task): add adaptor billing interface and async settlement framework

Add three billing lifecycle methods to the TaskAdaptor interface:
- EstimateBilling: compute OtherRatios from user request before pricing
- AdjustBillingOnSubmit: adjust ratios from upstream submit response
- AdjustBillingOnComplete: determine final quota at task terminal state

Introduce BaseBilling as embeddable no-op default for adaptors without
custom billing. Move Sora/Ali OtherRatios logic from shared validation
into per-adaptor EstimateBilling implementations.

Add TaskBillingContext to persist pricing params (model_price, group_ratio,
other_ratios) in task private data for async polling settlement.

Extract RecalculateTaskQuota as a general-purpose delta settlement
function and unify polling billing via settleTaskBillingOnComplete
(adaptor-first, then token-based fallback).
CaIon 3 週間 前
コミット
d6e11fd2e1

+ 7 - 0
controller/relay.go

@@ -509,6 +509,13 @@ func RelayTask(c *gin.Context) {
 		task.PrivateData.BillingSource = relayInfo.BillingSource
 		task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
 		task.PrivateData.TokenId = relayInfo.TokenId
+		task.PrivateData.BillingContext = &model.TaskBillingContext{
+			ModelPrice:  relayInfo.PriceData.ModelPrice,
+			GroupRatio:  relayInfo.PriceData.GroupRatioInfo.GroupRatio,
+			ModelRatio:  relayInfo.PriceData.ModelRatio,
+			OtherRatios: relayInfo.PriceData.OtherRatios,
+			ModelName:   result.ModelName,
+		}
 		task.Quota = result.Quota
 		task.Data = result.TaskData
 		task.Action = relayInfo.Action

+ 1 - 2
logger/logger.go

@@ -2,7 +2,6 @@ package logger
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"io"
 	"log"
@@ -151,7 +150,7 @@ func FormatQuota(quota int) string {
 
 // LogJson 仅供测试使用 only for test
 func LogJson(ctx context.Context, msg string, obj any) {
-	jsonStr, err := json.Marshal(obj)
+	jsonStr, err := common.Marshal(obj)
 	if err != nil {
 		LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
 		return

+ 13 - 3
model/task.go

@@ -100,9 +100,19 @@ type TaskPrivateData struct {
 	UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID
 	ResultURL      string `json:"result_url,omitempty"`       // 任务成功后的结果 URL(视频地址等)
 	// 计费上下文:用于异步退款/差额结算(轮询阶段读取)
-	BillingSource  string `json:"billing_source,omitempty"`  // "wallet" 或 "subscription"
-	SubscriptionId int    `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
-	TokenId        int    `json:"token_id,omitempty"`        // 令牌 ID,用于令牌额度退款
+	BillingSource  string              `json:"billing_source,omitempty"`  // "wallet" 或 "subscription"
+	SubscriptionId int                 `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
+	TokenId        int                 `json:"token_id,omitempty"`        // 令牌 ID,用于令牌额度退款
+	BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算)
+}
+
+// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
+type TaskBillingContext struct {
+	ModelPrice  float64            `json:"model_price,omitempty"`  // 模型单价
+	GroupRatio  float64            `json:"group_ratio,omitempty"`  // 分组倍率
+	ModelRatio  float64            `json:"model_ratio,omitempty"`  // 模型倍率
+	OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
+	ModelName   string             `json:"model_name,omitempty"`   // 模型名称
 }
 
 // GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)

+ 28 - 2
relay/channel/adapter.go

@@ -36,6 +36,32 @@ type TaskAdaptor interface {
 
 	ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
 
+	// ── Billing ──────────────────────────────────────────────────────
+
+	// EstimateBilling returns OtherRatios for pre-charge based on user request.
+	// Called after ValidateRequestAndSetAction, before price calculation.
+	// Adaptors should extract duration, resolution, etc. from the parsed request
+	// and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}).
+	// Return nil to use the base model price without extra ratios.
+	EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64
+
+	// AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream
+	// submit response. Called after a successful DoResponse.
+	// If the upstream returned actual parameters that differ from the estimate
+	// (e.g. actual seconds), return updated ratios so the caller can recalculate
+	// the quota and settle the delta with the pre-charge.
+	// Return nil if no adjustment is needed.
+	AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64
+
+	// AdjustBillingOnComplete returns the actual quota when a task reaches a
+	// terminal state (success/failure) during polling.
+	// Called by the polling loop after ParseTaskResult.
+	// Return a positive value to trigger delta settlement (supplement / refund).
+	// Return 0 to keep the pre-charged amount unchanged.
+	AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
+
+	// ── Request / Response ───────────────────────────────────────────
+
 	BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
 	BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 	BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
@@ -46,9 +72,9 @@ type TaskAdaptor interface {
 	GetModelList() []string
 	GetChannelName() string
 
-	// FetchTask
-	FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
+	// ── Polling ──────────────────────────────────────────────────────
 
+	FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
 	ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
 }
 

+ 36 - 21
relay/channel/task/ali/adaptor.go

@@ -13,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/samber/lo"
@@ -108,10 +109,10 @@ type AliMetadata struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
-	aliReq      *AliVideoRequest
 }
 
 func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	// 阿里通义万相支持 JSON 格式,不使用 multipart
-	var taskReq relaycommon.TaskSubmitReq
-	if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
-		return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
-	}
-	aliReq, err := a.convertToAliRequest(info, taskReq)
-	if err != nil {
-		return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
-	}
-	a.aliReq = aliReq
-	logger.LogJson(c, "ali video request body", aliReq)
+	// ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
 	return relaycommon.ValidateMultipartDirect(c, info)
 }
 
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 }
 
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
-	bodyBytes, err := common.Marshal(a.aliReq)
+	taskReq, err := relaycommon.GetTaskRequest(c)
 	if err != nil {
-		return nil, errors.Wrap(err, "marshal_ali_request_failed")
+		return nil, errors.Wrap(err, "get_task_request_failed")
+	}
+
+	aliReq, err := a.convertToAliRequest(info, taskReq)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert_to_ali_request_failed")
 	}
+	logger.LogJson(c, "ali video request body", aliReq)
 
+	bodyBytes, err := common.Marshal(aliReq)
+	if err != nil {
+		return nil, errors.Wrap(err, "marshal_ali_request_failed")
+	}
 	return bytes.NewReader(bodyBytes), nil
 }
 
@@ -335,19 +336,33 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
 		return nil, errors.New("can't change model with metadata")
 	}
 
-	info.PriceData.OtherRatios = map[string]float64{
-		"seconds": float64(aliReq.Parameters.Duration),
+	return aliReq, nil
+}
+
+// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
+// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
+	taskReq, err := relaycommon.GetTaskRequest(c)
+	if err != nil {
+		return nil
+	}
+
+	aliReq, err := a.convertToAliRequest(info, taskReq)
+	if err != nil {
+		return nil
 	}
 
+	otherRatios := map[string]float64{
+		"seconds": float64(aliReq.Parameters.Duration),
+	}
 	ratios, err := ProcessAliOtherRatios(aliReq)
 	if err != nil {
-		return nil, err
+		return otherRatios
 	}
-	for s, f := range ratios {
-		info.PriceData.OtherRatios[s] = f
+	for k, v := range ratios {
+		otherRatios[k] = v
 	}
-
-	return aliReq, nil
+	return otherRatios
 }
 
 // DoRequest delegates to common helper

+ 1 - 0
relay/channel/task/doubao/adaptor.go

@@ -89,6 +89,7 @@ type responseTask struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string

+ 1 - 0
relay/channel/task/gemini/adaptor.go

@@ -85,6 +85,7 @@ type operationResponse struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string

+ 2 - 0
relay/channel/task/hailuo/adaptor.go

@@ -17,12 +17,14 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
 
 // https://platform.minimaxi.com/docs/api-reference/video-generation-intro
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string

+ 1 - 0
relay/channel/task/jimeng/adaptor.go

@@ -77,6 +77,7 @@ const (
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	accessKey   string
 	secretKey   string

+ 1 - 0
relay/channel/task/kling/adaptor.go

@@ -97,6 +97,7 @@ type responsePayload struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string

+ 41 - 3
relay/channel/task/sora/adaptor.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strconv"
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
@@ -11,6 +12,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -56,6 +58,7 @@ type responseTask struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -68,15 +71,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func validateRemixRequest(c *gin.Context) *dto.TaskError {
-	var req struct {
-		Prompt string `json:"prompt"`
-	}
+	var req relaycommon.TaskSubmitReq
 	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
 		return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
 	}
 	if strings.TrimSpace(req.Prompt) == "" {
 		return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
 	}
+	// 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
+	c.Set("task_request", req)
 	return nil
 }
 
@@ -87,6 +90,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 	return relaycommon.ValidateMultipartDirect(c, info)
 }
 
+// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
+	// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
+	if info.Action == constant.TaskActionRemix {
+		return nil
+	}
+
+	req, err := relaycommon.GetTaskRequest(c)
+	if err != nil {
+		return nil
+	}
+
+	seconds, _ := strconv.Atoi(req.Seconds)
+	if seconds == 0 {
+		seconds = req.Duration
+	}
+	if seconds <= 0 {
+		seconds = 4
+	}
+
+	size := req.Size
+	if size == "" {
+		size = "720x1280"
+	}
+
+	ratios := map[string]float64{
+		"seconds": float64(seconds),
+		"size":    1,
+	}
+	if size == "1792x1024" || size == "1024x1792" {
+		ratios["size"] = 1.666667
+	}
+	return ratios
+}
+
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if info.Action == constant.TaskActionRemix {
 		return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil

+ 3 - 4
relay/channel/task/suno/adaptor.go

@@ -13,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -20,6 +21,7 @@ import (
 )
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 }
 
@@ -79,10 +81,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
 	sunoRequest, ok := c.Get("task_request")
 	if !ok {
-		err := common.UnmarshalBodyReusable(c, &sunoRequest)
-		if err != nil {
-			return nil, err
-		}
+		return nil, fmt.Errorf("task_request not found in context")
 	}
 	data, err := common.Marshal(sunoRequest)
 	if err != nil {

+ 25 - 0
relay/channel/task/taskcommon/helpers.go

@@ -5,7 +5,10 @@ import (
 	"fmt"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/setting/system_setting"
+	"github.com/gin-gonic/gin"
 )
 
 // UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
@@ -68,3 +71,25 @@ const (
 	ProgressInProgress = "30%"
 	ProgressComplete   = "100%"
 )
+
+// ---------------------------------------------------------------------------
+// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
+// Adaptors that do not need custom billing can embed this struct directly.
+// ---------------------------------------------------------------------------
+
+type BaseBilling struct{}
+
+// EstimateBilling returns nil (no extra ratios; use base model price).
+func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
+	return nil
+}
+
+// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
+func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
+	return nil
+}
+
+// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
+func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+	return 0
+}

+ 23 - 18
relay/channel/task/vertex/adaptor.go

@@ -62,6 +62,7 @@ type operationResponse struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 	return nil
 }
 
+// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
+	sampleCount := 1
+	v, ok := c.Get("task_request")
+	if ok {
+		req := v.(relaycommon.TaskSubmitReq)
+		if req.Metadata != nil {
+			if sc, exists := req.Metadata["sampleCount"]; exists {
+				if i, ok := sc.(int); ok && i > 0 {
+					sampleCount = i
+				}
+				if f, ok := sc.(float64); ok && int(f) > 0 {
+					sampleCount = int(f)
+				}
+			}
+		}
+	}
+	return map[string]float64{
+		"sampleCount": float64(sampleCount),
+	}
+}
+
 // BuildRequestBody converts request into Vertex specific format.
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
 	v, ok := c.Get("task_request")
@@ -166,24 +189,6 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		return nil, fmt.Errorf("sampleCount must be greater than 0")
 	}
 
-	// if req.Duration > 0 {
-	// 	body.Parameters["durationSeconds"] = req.Duration
-	// } else if req.Seconds != "" {
-	// 	seconds, err := strconv.Atoi(req.Seconds)
-	// 	if err != nil {
-	// 		return nil, errors.Wrap(err, "convert seconds to int failed")
-	// 	}
-	// 	body.Parameters["durationSeconds"] = seconds
-	// }
-
-	info.PriceData.OtherRatios = map[string]float64{
-		"sampleCount": float64(body.Parameters["sampleCount"].(int)),
-	}
-
-	// if v, ok := body.Parameters["durationSeconds"]; ok {
-	// 	info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
-	// }
-
 	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err

+ 1 - 0
relay/channel/task/vidu/adaptor.go

@@ -73,6 +73,7 @@ type creation struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	baseURL     string
 }

+ 2 - 8
relay/common/relay_utils.go

@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
 		if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
 			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
 		}
-		info.PriceData.OtherRatios = map[string]float64{
-			"seconds": float64(seconds),
-			"size":    1,
-		}
-		if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
-			info.PriceData.OtherRatios["size"] = 1.666667
-		}
+		// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
 	}
 
-	info.Action = action
+	storeTaskRequest(c, info, action, req)
 
 	return nil
 }

+ 57 - 7
relay/relay_task.go

@@ -128,8 +128,9 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr
 }
 
 // RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
-// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 →
-// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。
+// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
+// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
+// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
 // 控制器负责 defer Refund 和成功后 Settle。
 func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
 	info.InitChannelMeta(c)
@@ -159,10 +160,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
 		info.PublicTaskID = model.GenerateTaskID()
 	}
 
-	// 4. 价格计算
+	// 4. 价格计算:基础模型价格
 	info.OriginModelName = modelName
 	info.PriceData = helper.ModelPriceHelperPerCall(c, info)
 
+	// 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
+	//    必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
+	//    ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
+	if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
+		for k, v := range estimatedRatios {
+			info.PriceData.AddOtherRatio(k, v)
+		}
+	}
+
+	// 6. 将 OtherRatios 应用到基础额度
 	if !common.StringsContains(constant.TaskPricePatches, modelName) {
 		for _, ra := range info.PriceData.OtherRatios {
 			if ra != 1.0 {
@@ -171,7 +182,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
 		}
 	}
 
-	// 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
+	// 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
 	if info.Billing == nil && !info.PriceData.FreeModel {
 		info.ForcePreConsume = true
 		if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
@@ -179,13 +190,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
 		}
 	}
 
-	// 6. 构建请求体
+	// 8. 构建请求体
 	requestBody, err := adaptor.BuildRequestBody(c, info)
 	if err != nil {
 		return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
 	}
 
-	// 7. 发送请求
+	// 9. 发送请求
 	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 		return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
@@ -195,20 +206,59 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
 		return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
 	}
 
-	// 8. 解析响应
+	// 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
+	otherRatios := info.PriceData.OtherRatios
+	if otherRatios == nil {
+		otherRatios = map[string]float64{}
+	}
+	ratiosJSON, _ := common.Marshal(otherRatios)
+	c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
+
+	// 11. 解析响应
 	upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
 	if taskErr != nil {
 		return nil, taskErr
 	}
 
+	// 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
+	finalQuota := info.PriceData.Quota
+	if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
+		// 基于调整后的 ratios 重新计算 quota
+		finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
+		info.PriceData.OtherRatios = adjustedRatios
+		info.PriceData.Quota = finalQuota
+	}
+
 	return &TaskSubmitResult{
 		UpstreamTaskID: upstreamTaskID,
 		TaskData:       taskData,
 		Platform:       platform,
 		ModelName:      modelName,
+		Quota:          finalQuota,
 	}, nil
 }
 
+// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
+// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
+func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
+	// 从 PriceData 获取不含 OtherRatios 的基础价格
+	baseQuota := info.PriceData.Quota
+	// 先除掉原有的 OtherRatios 恢复基础额度
+	for _, ra := range info.PriceData.OtherRatios {
+		if ra != 1.0 && ra > 0 {
+			baseQuota = int(float64(baseQuota) / ra)
+		}
+	}
+	// 应用新的 ratios
+	result := float64(baseQuota)
+	for _, ra := range ratios {
+		if ra != 1.0 {
+			result *= ra
+		}
+	}
+	return int(result)
+}
+
 var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
 	relayconstant.RelayModeSunoFetchByID:  sunoFetchByIDRespBodyBuilder,
 	relayconstant.RelayModeSunoFetch:      sunoFetchRespBodyBuilder,

+ 54 - 44
service/task_billing.go

@@ -130,6 +130,58 @@ func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
 	model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 }
 
+// RecalculateTaskQuota 通用的异步差额结算。
+// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
+// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
+func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
+	if actualQuota <= 0 {
+		return
+	}
+	preConsumedQuota := task.Quota
+	quotaDelta := actualQuota - preConsumedQuota
+
+	if quotaDelta == 0 {
+		logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
+			task.TaskID, logger.LogQuota(actualQuota), reason))
+		return
+	}
+
+	logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
+		task.TaskID,
+		logger.LogQuota(quotaDelta),
+		logger.LogQuota(actualQuota),
+		logger.LogQuota(preConsumedQuota),
+		reason,
+	))
+
+	// 调整资金来源
+	if err := taskAdjustFunding(task, quotaDelta); err != nil {
+		logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
+		return
+	}
+
+	// 调整令牌额度
+	taskAdjustTokenQuota(ctx, task, quotaDelta)
+
+	// 更新统计(仅补扣时更新,退还不影响已用统计)
+	if quotaDelta > 0 {
+		model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
+		model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
+	}
+	task.Quota = actualQuota
+
+	var action string
+	if quotaDelta > 0 {
+		action = "补扣费"
+	} else {
+		action = "退还"
+	}
+	logContent := fmt.Sprintf("异步任务成功%s,预扣费 %s,实际扣费 %s,原因:%s",
+		action,
+		logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), reason)
+	model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+}
+
 // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
 // 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
 // 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
@@ -180,48 +232,6 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo
 	// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
 	actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
 
-	// 计算差额(正数=需要补扣,负数=需要退还)
-	preConsumedQuota := task.Quota
-	quotaDelta := actualQuota - preConsumedQuota
-
-	if quotaDelta == 0 {
-		logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
-			task.TaskID, logger.LogQuota(actualQuota), totalTokens))
-		return
-	}
-
-	logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)",
-		task.TaskID,
-		logger.LogQuota(quotaDelta),
-		logger.LogQuota(actualQuota),
-		logger.LogQuota(preConsumedQuota),
-		totalTokens,
-	))
-
-	// 调整资金来源
-	if err := taskAdjustFunding(task, quotaDelta); err != nil {
-		logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
-		return
-	}
-
-	// 调整令牌额度
-	taskAdjustTokenQuota(ctx, task, quotaDelta)
-
-	// 更新统计(仅补扣时更新,退还不影响已用统计)
-	if quotaDelta > 0 {
-		model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
-		model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
-	}
-	task.Quota = actualQuota
-
-	var action string
-	if quotaDelta > 0 {
-		action = "补扣费"
-	} else {
-		action = "退还"
-	}
-	logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s",
-		action, modelRatio, finalGroupRatio, totalTokens,
-		logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota))
-	model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+	reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
+	RecalculateTaskQuota(ctx, task, actualQuota, reason)
 }

+ 24 - 4
service/task_polling.go

@@ -26,6 +26,9 @@ type TaskPollingAdaptor interface {
 	Init(info *relaycommon.RelayInfo)
 	FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
 	ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
+	// AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
+	// 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
+	AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
 }
 
 // GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
@@ -372,10 +375,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *
 			task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
 		}
 
-		// 如果返回了 total_tokens,根据模型倍率重新计费
-		if taskResult.TotalTokens > 0 {
-			RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
-		}
+		// 完成时计费调整:优先由 adaptor 计算,回退到 token 重算
+		settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
 	case model.TaskStatusFailure:
 		logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
 		task.Status = model.TaskStatusFailure
@@ -444,3 +445,22 @@ func truncateBase64(s string) string {
 	}
 	return s[:maxKeep] + "..."
 }
+
+// settleTaskBillingOnComplete 任务完成时的统一计费调整。
+// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度
+//
+//  2. taskResult.TotalTokens > 0 → 按 token 重算
+//  3. 都不满足 → 保持预扣额度不变
+func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
+	// 1. 优先让 adaptor 决定最终额度
+	if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
+		RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
+		return
+	}
+	// 2. 回退到 token 重算
+	if taskResult.TotalTokens > 0 {
+		RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
+		return
+	}
+	// 3. 无调整,保持预扣额度
+}