|
|
@@ -128,8 +128,9 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr
|
|
|
}
|
|
|
|
|
|
// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
|
|
|
-// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 →
|
|
|
-// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。
|
|
|
+// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
|
|
|
+// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
|
|
|
+// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
|
|
|
// 控制器负责 defer Refund 和成功后 Settle。
|
|
|
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
|
|
|
info.InitChannelMeta(c)
|
|
|
@@ -159,10 +160,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
|
|
info.PublicTaskID = model.GenerateTaskID()
|
|
|
}
|
|
|
|
|
|
- // 4. 价格计算
|
|
|
+ // 4. 价格计算:基础模型价格
|
|
|
info.OriginModelName = modelName
|
|
|
info.PriceData = helper.ModelPriceHelperPerCall(c, info)
|
|
|
|
|
|
+ // 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
|
|
|
+ // 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
|
|
|
+ // ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
|
|
|
+ if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
|
|
|
+ for k, v := range estimatedRatios {
|
|
|
+ info.PriceData.AddOtherRatio(k, v)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 6. 将 OtherRatios 应用到基础额度
|
|
|
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
|
|
for _, ra := range info.PriceData.OtherRatios {
|
|
|
if ra != 1.0 {
|
|
|
@@ -171,7 +182,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
|
|
|
+ // 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
|
|
|
if info.Billing == nil && !info.PriceData.FreeModel {
|
|
|
info.ForcePreConsume = true
|
|
|
if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
|
|
|
@@ -179,13 +190,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 6. 构建请求体
|
|
|
+ // 8. 构建请求体
|
|
|
requestBody, err := adaptor.BuildRequestBody(c, info)
|
|
|
if err != nil {
|
|
|
return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
|
|
|
- // 7. 发送请求
|
|
|
+ // 9. 发送请求
|
|
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
|
|
if err != nil {
|
|
|
return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
|
@@ -195,20 +206,59 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
|
|
return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
|
|
}
|
|
|
|
|
|
- // 8. 解析响应
|
|
|
+ // 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
|
|
|
+ otherRatios := info.PriceData.OtherRatios
|
|
|
+ if otherRatios == nil {
|
|
|
+ otherRatios = map[string]float64{}
|
|
|
+ }
|
|
|
+ ratiosJSON, _ := common.Marshal(otherRatios)
|
|
|
+ c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
|
|
|
+
|
|
|
+ // 11. 解析响应
|
|
|
upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
|
|
if taskErr != nil {
|
|
|
return nil, taskErr
|
|
|
}
|
|
|
|
|
|
+ // 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
|
|
|
+ finalQuota := info.PriceData.Quota
|
|
|
+ if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
|
|
|
+ // 基于调整后的 ratios 重新计算 quota
|
|
|
+ finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
|
|
|
+ info.PriceData.OtherRatios = adjustedRatios
|
|
|
+ info.PriceData.Quota = finalQuota
|
|
|
+ }
|
|
|
+
|
|
|
return &TaskSubmitResult{
|
|
|
UpstreamTaskID: upstreamTaskID,
|
|
|
TaskData: taskData,
|
|
|
Platform: platform,
|
|
|
ModelName: modelName,
|
|
|
+ Quota: finalQuota,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
+// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
|
|
|
+// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
|
|
|
+func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
|
|
|
+ // 从 PriceData 获取不含 OtherRatios 的基础价格
|
|
|
+ baseQuota := info.PriceData.Quota
|
|
|
+ // 先除掉原有的 OtherRatios 恢复基础额度
|
|
|
+ for _, ra := range info.PriceData.OtherRatios {
|
|
|
+ if ra != 1.0 && ra > 0 {
|
|
|
+ baseQuota = int(float64(baseQuota) / ra)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // 应用新的 ratios
|
|
|
+ result := float64(baseQuota)
|
|
|
+ for _, ra := range ratios {
|
|
|
+ if ra != 1.0 {
|
|
|
+ result *= ra
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return int(result)
|
|
|
+}
|
|
|
+
|
|
|
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
|
|
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
|
|
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|