|
|
@@ -2,7 +2,6 @@ package relay
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
- "encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
@@ -15,29 +14,33 @@ import (
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
|
"github.com/QuantumNous/new-api/model"
|
|
|
"github.com/QuantumNous/new-api/relay/channel"
|
|
|
+ "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
|
|
+ "github.com/QuantumNous/new-api/relay/helper"
|
|
|
"github.com/QuantumNous/new-api/service"
|
|
|
- "github.com/QuantumNous/new-api/setting/ratio_setting"
|
|
|
-
|
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
|
|
|
|
-/*
|
|
|
-Task 任务通过平台、Action 区分任务
|
|
|
-*/
|
|
|
-func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
|
|
- info.InitChannelMeta(c)
|
|
|
- // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
|
|
|
- if info.TaskRelayInfo == nil {
|
|
|
- info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
|
|
|
- }
|
|
|
+type TaskSubmitResult struct {
|
|
|
+ UpstreamTaskID string
|
|
|
+ TaskData []byte
|
|
|
+ Platform constant.TaskPlatform
|
|
|
+ ModelName string
|
|
|
+ Quota int
|
|
|
+ //PerCallPrice types.PriceData
|
|
|
+}
|
|
|
+
|
|
|
+// ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
|
|
|
+// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过
|
|
|
+// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。
|
|
|
+// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
|
|
|
+func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
|
|
+ // 检测 remix action
|
|
|
path := c.Request.URL.Path
|
|
|
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
|
|
|
info.Action = constant.TaskActionRemix
|
|
|
}
|
|
|
-
|
|
|
- // 提取 remix 任务的 video_id
|
|
|
if info.Action == constant.TaskActionRemix {
|
|
|
videoID := c.Param("video_id")
|
|
|
if strings.TrimSpace(videoID) == "" {
|
|
|
@@ -46,241 +49,164 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
|
|
info.OriginTaskID = videoID
|
|
|
}
|
|
|
|
|
|
- platform := constant.TaskPlatform(c.GetString("platform"))
|
|
|
+ if info.OriginTaskID == "" {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
|
|
|
- // 获取原始任务信息
|
|
|
- if info.OriginTaskID != "" {
|
|
|
- originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
|
|
- if err != nil {
|
|
|
- taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
|
|
- return
|
|
|
+ // 查找原始任务
|
|
|
+ originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
|
|
+ if err != nil {
|
|
|
+ return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
|
|
+ }
|
|
|
+ if !exist {
|
|
|
+ return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
|
|
+ }
|
|
|
+
|
|
|
+ // 从原始任务推导模型名称
|
|
|
+ if info.OriginModelName == "" {
|
|
|
+ if originTask.Properties.OriginModelName != "" {
|
|
|
+ info.OriginModelName = originTask.Properties.OriginModelName
|
|
|
+ } else if originTask.Properties.UpstreamModelName != "" {
|
|
|
+ info.OriginModelName = originTask.Properties.UpstreamModelName
|
|
|
+ } else {
|
|
|
+ var taskData map[string]interface{}
|
|
|
+ _ = common.Unmarshal(originTask.Data, &taskData)
|
|
|
+ if m, ok := taskData["model"].(string); ok && m != "" {
|
|
|
+ info.OriginModelName = m
|
|
|
+ }
|
|
|
}
|
|
|
- if !exist {
|
|
|
- taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
|
|
- return
|
|
|
+ }
|
|
|
+
|
|
|
+ // 锁定到原始任务的渠道(如果与当前选中的不同)
|
|
|
+ if originTask.ChannelId != info.ChannelId {
|
|
|
+ ch, err := model.GetChannelById(originTask.ChannelId, true)
|
|
|
+ if err != nil {
|
|
|
+ return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
|
|
}
|
|
|
- if info.OriginModelName == "" {
|
|
|
- if originTask.Properties.OriginModelName != "" {
|
|
|
- info.OriginModelName = originTask.Properties.OriginModelName
|
|
|
- } else if originTask.Properties.UpstreamModelName != "" {
|
|
|
- info.OriginModelName = originTask.Properties.UpstreamModelName
|
|
|
- } else {
|
|
|
- var taskData map[string]interface{}
|
|
|
- _ = json.Unmarshal(originTask.Data, &taskData)
|
|
|
- if m, ok := taskData["model"].(string); ok && m != "" {
|
|
|
- info.OriginModelName = m
|
|
|
- platform = originTask.Platform
|
|
|
- }
|
|
|
- }
|
|
|
+ if ch.Status != common.ChannelStatusEnabled {
|
|
|
+ return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
|
|
}
|
|
|
- if originTask.ChannelId != info.ChannelId {
|
|
|
- channel, err := model.GetChannelById(originTask.ChannelId, true)
|
|
|
- if err != nil {
|
|
|
- taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
|
|
- return
|
|
|
- }
|
|
|
- if channel.Status != common.ChannelStatusEnabled {
|
|
|
- taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
|
|
- return
|
|
|
- }
|
|
|
- key, _, newAPIError := channel.GetNextEnabledKey()
|
|
|
- if newAPIError != nil {
|
|
|
- taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
|
|
- return
|
|
|
- }
|
|
|
- common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
|
|
- common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
|
|
- common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
|
|
- common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
|
|
-
|
|
|
- info.ChannelBaseUrl = channel.GetBaseURL()
|
|
|
- info.ChannelId = originTask.ChannelId
|
|
|
- info.ChannelType = channel.Type
|
|
|
- info.ApiKey = key
|
|
|
- platform = originTask.Platform
|
|
|
+ key, _, newAPIError := ch.GetNextEnabledKey()
|
|
|
+ if newAPIError != nil {
|
|
|
+ return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
|
|
}
|
|
|
+ common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
|
|
+ common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
|
|
|
+ common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
|
|
|
+ common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
|
|
|
|
|
- // 使用原始任务的参数
|
|
|
- if info.Action == constant.TaskActionRemix {
|
|
|
- var taskData map[string]interface{}
|
|
|
- _ = json.Unmarshal(originTask.Data, &taskData)
|
|
|
- secondsStr, _ := taskData["seconds"].(string)
|
|
|
- seconds, _ := strconv.Atoi(secondsStr)
|
|
|
- if seconds <= 0 {
|
|
|
- seconds = 4
|
|
|
- }
|
|
|
- sizeStr, _ := taskData["size"].(string)
|
|
|
- if info.PriceData.OtherRatios == nil {
|
|
|
- info.PriceData.OtherRatios = map[string]float64{}
|
|
|
- }
|
|
|
- info.PriceData.OtherRatios["seconds"] = float64(seconds)
|
|
|
- info.PriceData.OtherRatios["size"] = 1
|
|
|
- if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
|
|
|
- info.PriceData.OtherRatios["size"] = 1.666667
|
|
|
- }
|
|
|
+ info.ChannelBaseUrl = ch.GetBaseURL()
|
|
|
+ info.ChannelId = originTask.ChannelId
|
|
|
+ info.ChannelType = ch.Type
|
|
|
+ info.ApiKey = key
|
|
|
+ }
|
|
|
+
|
|
|
+ // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道
|
|
|
+ c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId))
|
|
|
+
|
|
|
+ // 提取 remix 参数(时长、分辨率 → OtherRatios)
|
|
|
+ if info.Action == constant.TaskActionRemix {
|
|
|
+ var taskData map[string]interface{}
|
|
|
+ _ = common.Unmarshal(originTask.Data, &taskData)
|
|
|
+ secondsStr, _ := taskData["seconds"].(string)
|
|
|
+ seconds, _ := strconv.Atoi(secondsStr)
|
|
|
+ if seconds <= 0 {
|
|
|
+ seconds = 4
|
|
|
+ }
|
|
|
+ sizeStr, _ := taskData["size"].(string)
|
|
|
+ if info.PriceData.OtherRatios == nil {
|
|
|
+ info.PriceData.OtherRatios = map[string]float64{}
|
|
|
+ }
|
|
|
+ info.PriceData.OtherRatios["seconds"] = float64(seconds)
|
|
|
+ info.PriceData.OtherRatios["size"] = 1
|
|
|
+ if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
|
|
|
+ info.PriceData.OtherRatios["size"] = 1.666667
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
|
|
|
+// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 →
|
|
|
+// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。
|
|
|
+// 控制器负责 defer Refund 和成功后 Settle。
|
|
|
+func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
|
|
|
+ info.InitChannelMeta(c)
|
|
|
+
|
|
|
+ // 1. 确定 platform → 创建适配器 → 验证请求
|
|
|
+ platform := constant.TaskPlatform(c.GetString("platform"))
|
|
|
if platform == "" {
|
|
|
platform = GetTaskPlatform(c)
|
|
|
}
|
|
|
-
|
|
|
- info.InitChannelMeta(c)
|
|
|
adaptor := GetTaskAdaptor(platform)
|
|
|
if adaptor == nil {
|
|
|
- return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
|
|
+ return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
|
|
}
|
|
|
adaptor.Init(info)
|
|
|
- // get & validate taskRequest 获取并验证文本请求
|
|
|
- taskErr = adaptor.ValidateRequestAndSetAction(c, info)
|
|
|
- if taskErr != nil {
|
|
|
- return
|
|
|
+ if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
|
|
|
+ return nil, taskErr
|
|
|
}
|
|
|
|
|
|
+ // 2. 确定模型名称
|
|
|
modelName := info.OriginModelName
|
|
|
if modelName == "" {
|
|
|
modelName = service.CoverTaskActionToModelName(platform, info.Action)
|
|
|
}
|
|
|
- modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
|
|
- if !success {
|
|
|
- defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
|
|
|
- if !ok {
|
|
|
- modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
|
|
|
- } else {
|
|
|
- modelPrice = defaultPrice
|
|
|
- }
|
|
|
- }
|
|
|
|
|
|
- // 处理 auto 分组:从 context 获取实际选中的分组
|
|
|
- // 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
|
|
|
- if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
|
|
|
- if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
|
|
|
- info.UsingGroup = groupStr
|
|
|
- }
|
|
|
+ // 3. 预生成公开 task ID(仅首次)
|
|
|
+ if info.PublicTaskID == "" {
|
|
|
+ info.PublicTaskID = model.GenerateTaskID()
|
|
|
}
|
|
|
|
|
|
- // 预扣
|
|
|
- groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
|
|
|
- var ratio float64
|
|
|
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
|
|
|
- if hasUserGroupRatio {
|
|
|
- ratio = modelPrice * userGroupRatio
|
|
|
- } else {
|
|
|
- ratio = modelPrice * groupRatio
|
|
|
- }
|
|
|
- // FIXME: 临时修补,支持任务仅按次计费
|
|
|
+ // 4. 价格计算
|
|
|
+ info.OriginModelName = modelName
|
|
|
+ info.PriceData = helper.ModelPriceHelperPerCall(c, info)
|
|
|
+
|
|
|
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
|
|
- if len(info.PriceData.OtherRatios) > 0 {
|
|
|
- for _, ra := range info.PriceData.OtherRatios {
|
|
|
- if 1.0 != ra {
|
|
|
- ratio *= ra
|
|
|
- }
|
|
|
+ for _, ra := range info.PriceData.OtherRatios {
|
|
|
+ if ra != 1.0 {
|
|
|
+ info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
|
|
|
- userQuota, err := model.GetUserQuota(info.UserId, false)
|
|
|
- if err != nil {
|
|
|
- taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
|
- return
|
|
|
- }
|
|
|
- quota := int(ratio * common.QuotaPerUnit)
|
|
|
- if userQuota-quota < 0 {
|
|
|
- taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
|
|
|
- return
|
|
|
+
|
|
|
+ // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
|
|
|
+ if info.Billing == nil && !info.PriceData.FreeModel {
|
|
|
+ info.ForcePreConsume = true
|
|
|
+ if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
|
|
|
+ return nil, service.TaskErrorFromAPIError(apiErr)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- // build body
|
|
|
+ // 6. 构建请求体
|
|
|
requestBody, err := adaptor.BuildRequestBody(c, info)
|
|
|
if err != nil {
|
|
|
- taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
|
|
- return
|
|
|
+ return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
- // do request
|
|
|
+
|
|
|
+ // 7. 发送请求
|
|
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
|
|
if err != nil {
|
|
|
- taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
|
- return
|
|
|
+ return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
- // handle response
|
|
|
if resp != nil && resp.StatusCode != http.StatusOK {
|
|
|
responseBody, _ := io.ReadAll(resp.Body)
|
|
|
- taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
|
|
- return
|
|
|
+ return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
|
|
}
|
|
|
|
|
|
- defer func() {
|
|
|
- // release quota
|
|
|
- if info.ConsumeQuota && taskErr == nil {
|
|
|
-
|
|
|
- err := service.PostConsumeQuota(info, quota, 0, true)
|
|
|
- if err != nil {
|
|
|
- common.SysLog("error consuming token remain quota: " + err.Error())
|
|
|
- }
|
|
|
- if quota != 0 {
|
|
|
- tokenName := c.GetString("token_name")
|
|
|
- //gRatio := groupRatio
|
|
|
- //if hasUserGroupRatio {
|
|
|
- // gRatio = userGroupRatio
|
|
|
- //}
|
|
|
- logContent := fmt.Sprintf("操作 %s", info.Action)
|
|
|
- // FIXME: 临时修补,支持任务仅按次计费
|
|
|
- if common.StringsContains(constant.TaskPricePatches, modelName) {
|
|
|
- logContent = fmt.Sprintf("%s,按次计费", logContent)
|
|
|
- } else {
|
|
|
- if len(info.PriceData.OtherRatios) > 0 {
|
|
|
- var contents []string
|
|
|
- for key, ra := range info.PriceData.OtherRatios {
|
|
|
- if 1.0 != ra {
|
|
|
- contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
|
|
- }
|
|
|
- }
|
|
|
- if len(contents) > 0 {
|
|
|
- logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- other := make(map[string]interface{})
|
|
|
- if c != nil && c.Request != nil && c.Request.URL != nil {
|
|
|
- other["request_path"] = c.Request.URL.Path
|
|
|
- }
|
|
|
- other["model_price"] = modelPrice
|
|
|
- other["group_ratio"] = groupRatio
|
|
|
- if hasUserGroupRatio {
|
|
|
- other["user_group_ratio"] = userGroupRatio
|
|
|
- }
|
|
|
- model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
|
|
- ChannelId: info.ChannelId,
|
|
|
- ModelName: modelName,
|
|
|
- TokenName: tokenName,
|
|
|
- Quota: quota,
|
|
|
- Content: logContent,
|
|
|
- TokenId: info.TokenId,
|
|
|
- Group: info.UsingGroup,
|
|
|
- Other: other,
|
|
|
- })
|
|
|
- model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
|
|
|
- model.UpdateChannelUsedQuota(info.ChannelId, quota)
|
|
|
- }
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
- taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
|
|
+ // 8. 解析响应
|
|
|
+ upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
|
|
if taskErr != nil {
|
|
|
- return
|
|
|
+ return nil, taskErr
|
|
|
}
|
|
|
- info.ConsumeQuota = true
|
|
|
- // insert task
|
|
|
- task := model.InitTask(platform, info)
|
|
|
- task.TaskID = taskID
|
|
|
- task.Quota = quota
|
|
|
- task.Data = taskData
|
|
|
- task.Action = info.Action
|
|
|
- err = task.Insert()
|
|
|
- if err != nil {
|
|
|
- taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
|
|
- return
|
|
|
- }
|
|
|
- return nil
|
|
|
+
|
|
|
+ return &TaskSubmitResult{
|
|
|
+ UpstreamTaskID: upstreamTaskID,
|
|
|
+ TaskData: taskData,
|
|
|
+ Platform: platform,
|
|
|
+ ModelName: modelName,
|
|
|
+ }, nil
|
|
|
}
|
|
|
|
|
|
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
|
|
@@ -336,7 +262,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta
|
|
|
} else {
|
|
|
tasks = make([]any, 0)
|
|
|
}
|
|
|
- respBody, err = json.Marshal(dto.TaskResponse[[]any]{
|
|
|
+ respBody, err = common.Marshal(dto.TaskResponse[[]any]{
|
|
|
Code: "success",
|
|
|
Data: tasks,
|
|
|
})
|
|
|
@@ -357,7 +283,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- respBody, err = json.Marshal(dto.TaskResponse[any]{
|
|
|
+ respBody, err = common.Marshal(dto.TaskResponse[any]{
|
|
|
Code: "success",
|
|
|
Data: TaskModel2Dto(originTask),
|
|
|
})
|
|
|
@@ -381,97 +307,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- func() {
|
|
|
- channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
|
|
|
- if err2 != nil {
|
|
|
- return
|
|
|
- }
|
|
|
- if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
|
|
|
- return
|
|
|
- }
|
|
|
- baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
|
|
- if channelModel.GetBaseURL() != "" {
|
|
|
- baseURL = channelModel.GetBaseURL()
|
|
|
- }
|
|
|
- proxy := channelModel.GetSetting().Proxy
|
|
|
- adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
|
|
- if adaptor == nil {
|
|
|
- return
|
|
|
- }
|
|
|
- resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
|
|
- "task_id": originTask.TaskID,
|
|
|
- "action": originTask.Action,
|
|
|
- }, proxy)
|
|
|
- if err2 != nil || resp == nil {
|
|
|
- return
|
|
|
- }
|
|
|
- defer resp.Body.Close()
|
|
|
- body, err2 := io.ReadAll(resp.Body)
|
|
|
- if err2 != nil {
|
|
|
- return
|
|
|
- }
|
|
|
- ti, err2 := adaptor.ParseTaskResult(body)
|
|
|
- if err2 == nil && ti != nil {
|
|
|
- if ti.Status != "" {
|
|
|
- originTask.Status = model.TaskStatus(ti.Status)
|
|
|
- }
|
|
|
- if ti.Progress != "" {
|
|
|
- originTask.Progress = ti.Progress
|
|
|
- }
|
|
|
- if ti.Url != "" {
|
|
|
- if strings.HasPrefix(ti.Url, "data:") {
|
|
|
- } else {
|
|
|
- originTask.FailReason = ti.Url
|
|
|
- }
|
|
|
- }
|
|
|
- _ = originTask.Update()
|
|
|
- var raw map[string]any
|
|
|
- _ = json.Unmarshal(body, &raw)
|
|
|
- format := "mp4"
|
|
|
- if respObj, ok := raw["response"].(map[string]any); ok {
|
|
|
- if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
|
|
|
- if v0, ok := vids[0].(map[string]any); ok {
|
|
|
- if mt, ok := v0["mimeType"].(string); ok && mt != "" {
|
|
|
- if strings.Contains(mt, "mp4") {
|
|
|
- format = "mp4"
|
|
|
- } else {
|
|
|
- format = mt
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- status := "processing"
|
|
|
- switch originTask.Status {
|
|
|
- case model.TaskStatusSuccess:
|
|
|
- status = "succeeded"
|
|
|
- case model.TaskStatusFailure:
|
|
|
- status = "failed"
|
|
|
- case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
|
|
- status = "queued"
|
|
|
- }
|
|
|
- if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
|
|
|
- out := map[string]any{
|
|
|
- "error": nil,
|
|
|
- "format": format,
|
|
|
- "metadata": nil,
|
|
|
- "status": status,
|
|
|
- "task_id": originTask.TaskID,
|
|
|
- "url": originTask.FailReason,
|
|
|
- }
|
|
|
- respBody, _ = json.Marshal(dto.TaskResponse[any]{
|
|
|
- Code: "success",
|
|
|
- Data: out,
|
|
|
- })
|
|
|
- }
|
|
|
- }
|
|
|
- }()
|
|
|
+ isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
|
|
|
|
|
|
- if len(respBody) != 0 {
|
|
|
+ // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
|
|
|
+ if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
|
|
|
+ respBody = realtimeResp
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
|
|
|
+ // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
|
|
|
+ if isOpenAIVideoAPI {
|
|
|
adaptor := GetTaskAdaptor(originTask.Platform)
|
|
|
if adaptor == nil {
|
|
|
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
|
|
|
@@ -486,10 +331,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
|
|
respBody = openAIVideoData
|
|
|
return
|
|
|
}
|
|
|
- taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
|
|
|
+ taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
|
|
|
return
|
|
|
}
|
|
|
- respBody, err = json.Marshal(dto.TaskResponse[any]{
|
|
|
+
|
|
|
+ // 通用 TaskDto 格式
|
|
|
+ respBody, err = common.Marshal(dto.TaskResponse[any]{
|
|
|
Code: "success",
|
|
|
Data: TaskModel2Dto(originTask),
|
|
|
})
|
|
|
@@ -499,16 +346,145 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
|
|
|
+// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
|
|
|
+// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
|
|
|
+func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
|
|
|
+ channelModel, err := model.GetChannelById(task.ChannelId, true)
|
|
|
+ if err != nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
|
|
+ if channelModel.GetBaseURL() != "" {
|
|
|
+ baseURL = channelModel.GetBaseURL()
|
|
|
+ }
|
|
|
+ proxy := channelModel.GetSetting().Proxy
|
|
|
+ adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
|
|
+ if adaptor == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
|
|
+ "task_id": task.GetUpstreamTaskID(),
|
|
|
+ "action": task.Action,
|
|
|
+ }, proxy)
|
|
|
+ if err != nil || resp == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+ body, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ ti, err := adaptor.ParseTaskResult(body)
|
|
|
+ if err != nil || ti == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // 将上游最新状态更新到 task
|
|
|
+ if ti.Status != "" {
|
|
|
+ task.Status = model.TaskStatus(ti.Status)
|
|
|
+ }
|
|
|
+ if ti.Progress != "" {
|
|
|
+ task.Progress = ti.Progress
|
|
|
+ }
|
|
|
+ if strings.HasPrefix(ti.Url, "data:") {
|
|
|
+ // data: URI — kept in Data, not ResultURL
|
|
|
+ } else if ti.Url != "" {
|
|
|
+ task.PrivateData.ResultURL = ti.Url
|
|
|
+ } else if task.Status == model.TaskStatusSuccess {
|
|
|
+ // No URL from adaptor — construct proxy URL using public task ID
|
|
|
+ task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
|
|
|
+ }
|
|
|
+ _ = task.Update()
|
|
|
+
|
|
|
+ // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
|
|
|
+ if isOpenAIVideoAPI {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // 非 OpenAI Video API: 构建自定义格式响应
|
|
|
+ format := detectVideoFormat(body)
|
|
|
+ out := map[string]any{
|
|
|
+ "error": nil,
|
|
|
+ "format": format,
|
|
|
+ "metadata": nil,
|
|
|
+ "status": mapTaskStatusToSimple(task.Status),
|
|
|
+ "task_id": task.TaskID,
|
|
|
+ "url": task.GetResultURL(),
|
|
|
+ }
|
|
|
+ respBody, _ := common.Marshal(dto.TaskResponse[any]{
|
|
|
+ Code: "success",
|
|
|
+ Data: out,
|
|
|
+ })
|
|
|
+ return respBody
|
|
|
+}
|
|
|
+
|
|
|
+// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
|
|
|
+func detectVideoFormat(rawBody []byte) string {
|
|
|
+ var raw map[string]any
|
|
|
+ if err := common.Unmarshal(rawBody, &raw); err != nil {
|
|
|
+ return "mp4"
|
|
|
+ }
|
|
|
+ respObj, ok := raw["response"].(map[string]any)
|
|
|
+ if !ok {
|
|
|
+ return "mp4"
|
|
|
+ }
|
|
|
+ vids, ok := respObj["videos"].([]any)
|
|
|
+ if !ok || len(vids) == 0 {
|
|
|
+ return "mp4"
|
|
|
+ }
|
|
|
+ v0, ok := vids[0].(map[string]any)
|
|
|
+ if !ok {
|
|
|
+ return "mp4"
|
|
|
+ }
|
|
|
+ mt, ok := v0["mimeType"].(string)
|
|
|
+ if !ok || mt == "" || strings.Contains(mt, "mp4") {
|
|
|
+ return "mp4"
|
|
|
+ }
|
|
|
+ return mt
|
|
|
+}
|
|
|
+
|
|
|
+// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
|
|
|
+func mapTaskStatusToSimple(status model.TaskStatus) string {
|
|
|
+ switch status {
|
|
|
+ case model.TaskStatusSuccess:
|
|
|
+ return "succeeded"
|
|
|
+ case model.TaskStatusFailure:
|
|
|
+ return "failed"
|
|
|
+ case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
|
|
+ return "queued"
|
|
|
+ default:
|
|
|
+ return "processing"
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
|
|
return &dto.TaskDto{
|
|
|
+ ID: task.ID,
|
|
|
+ CreatedAt: task.CreatedAt,
|
|
|
+ UpdatedAt: task.UpdatedAt,
|
|
|
TaskID: task.TaskID,
|
|
|
+ Platform: string(task.Platform),
|
|
|
+ UserId: task.UserId,
|
|
|
+ Group: task.Group,
|
|
|
+ ChannelId: task.ChannelId,
|
|
|
+ Quota: task.Quota,
|
|
|
Action: task.Action,
|
|
|
Status: string(task.Status),
|
|
|
FailReason: task.FailReason,
|
|
|
+ ResultURL: task.GetResultURL(),
|
|
|
SubmitTime: task.SubmitTime,
|
|
|
StartTime: task.StartTime,
|
|
|
FinishTime: task.FinishTime,
|
|
|
Progress: task.Progress,
|
|
|
+ Properties: task.Properties,
|
|
|
+ Username: task.Username,
|
|
|
Data: task.Data,
|
|
|
}
|
|
|
}
|