| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564 |
- package relay
- import (
- "bytes"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strconv"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/relay/channel"
- "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- relayconstant "github.com/QuantumNous/new-api/relay/constant"
- "github.com/QuantumNous/new-api/relay/helper"
- "github.com/QuantumNous/new-api/service"
- "github.com/gin-gonic/gin"
- )
- type TaskSubmitResult struct {
- UpstreamTaskID string
- TaskData []byte
- Platform constant.TaskPlatform
- Quota int
- //PerCallPrice types.PriceData
- }
- // ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
- // 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
- // (通过 info.LockedChannel,重试时复用同一渠道并轮换 key),
- // 以及提取 OtherRatios(时长、分辨率)。
- // 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
- func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
- // 检测 remix action
- path := c.Request.URL.Path
- if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
- info.Action = constant.TaskActionRemix
- }
- // 提取 remix 任务的 video_id
- if info.Action == constant.TaskActionRemix {
- videoID := c.Param("video_id")
- if strings.TrimSpace(videoID) == "" {
- return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
- }
- info.OriginTaskID = videoID
- }
- if info.OriginTaskID == "" {
- return nil
- }
- // 查找原始任务
- originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
- if err != nil {
- return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
- }
- if !exist {
- return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
- }
- // 从原始任务推导模型名称
- if info.OriginModelName == "" {
- if originTask.Properties.OriginModelName != "" {
- info.OriginModelName = originTask.Properties.OriginModelName
- } else if originTask.Properties.UpstreamModelName != "" {
- info.OriginModelName = originTask.Properties.UpstreamModelName
- } else {
- var taskData map[string]interface{}
- _ = common.Unmarshal(originTask.Data, &taskData)
- if m, ok := taskData["model"].(string); ok && m != "" {
- info.OriginModelName = m
- }
- }
- }
- // 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key)
- ch, err := model.GetChannelById(originTask.ChannelId, true)
- if err != nil {
- return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
- }
- if ch.Status != common.ChannelStatusEnabled {
- return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
- }
- info.LockedChannel = ch
- if originTask.ChannelId != info.ChannelId {
- key, _, newAPIError := ch.GetNextEnabledKey()
- if newAPIError != nil {
- return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
- }
- common.SetContextKey(c, constant.ContextKeyChannelKey, key)
- common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
- common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
- common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
- info.ChannelBaseUrl = ch.GetBaseURL()
- info.ChannelId = originTask.ChannelId
- info.ChannelType = ch.Type
- info.ApiKey = key
- }
- // 提取 remix 参数(时长、分辨率 → OtherRatios)
- if info.Action == constant.TaskActionRemix {
- if originTask.PrivateData.BillingContext != nil {
- // 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在)
- for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
- info.PriceData.AddOtherRatio(s, f)
- }
- } else {
- // 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在)
- var taskData map[string]interface{}
- _ = common.Unmarshal(originTask.Data, &taskData)
- secondsStr, _ := taskData["seconds"].(string)
- seconds, _ := strconv.Atoi(secondsStr)
- if seconds <= 0 {
- seconds = 4
- }
- sizeStr, _ := taskData["size"].(string)
- if info.PriceData.OtherRatios == nil {
- info.PriceData.OtherRatios = map[string]float64{}
- }
- info.PriceData.OtherRatios["seconds"] = float64(seconds)
- info.PriceData.OtherRatios["size"] = 1
- if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
- info.PriceData.OtherRatios["size"] = 1.666667
- }
- }
- }
- return nil
- }
- // RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
- // 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
- // 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
- // 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
- // 控制器负责 defer Refund 和成功后 Settle。
- func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
- info.InitChannelMeta(c)
- // 1. 确定 platform → 创建适配器 → 验证请求
- platform := constant.TaskPlatform(c.GetString("platform"))
- if platform == "" {
- platform = GetTaskPlatform(c)
- }
- adaptor := GetTaskAdaptor(platform)
- if adaptor == nil {
- return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
- }
- adaptor.Init(info)
- if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
- return nil, taskErr
- }
- // 2. 确定模型名称
- modelName := info.OriginModelName
- if modelName == "" {
- modelName = service.CoverTaskActionToModelName(platform, info.Action)
- }
- // 2.5 应用渠道的模型映射(与同步任务对齐)
- info.OriginModelName = modelName
- info.UpstreamModelName = modelName
- if err := helper.ModelMappedHelper(c, info, nil); err != nil {
- return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
- }
- // 3. 预生成公开 task ID(仅首次)
- if info.PublicTaskID == "" {
- info.PublicTaskID = model.GenerateTaskID()
- }
- // 4. 价格计算:基础模型价格
- info.OriginModelName = modelName
- priceData, err := helper.ModelPriceHelperPerCall(c, info)
- if err != nil {
- return nil, service.TaskErrorWrapper(err, "model_price_error", http.StatusBadRequest)
- }
- info.PriceData = priceData
- // 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
- // 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
- // ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
- if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
- for k, v := range estimatedRatios {
- info.PriceData.AddOtherRatio(k, v)
- }
- }
- // 6. 将 OtherRatios 应用到基础额度
- if !common.StringsContains(constant.TaskPricePatches, modelName) {
- for _, ra := range info.PriceData.OtherRatios {
- if ra != 1.0 {
- info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
- }
- }
- }
- // 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
- if info.Billing == nil && !info.PriceData.FreeModel {
- info.ForcePreConsume = true
- if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
- return nil, service.TaskErrorFromAPIError(apiErr)
- }
- }
- // 8. 构建请求体
- requestBody, err := adaptor.BuildRequestBody(c, info)
- if err != nil {
- return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
- }
- // 9. 发送请求
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
- if resp != nil && resp.StatusCode != http.StatusOK {
- responseBody, _ := io.ReadAll(resp.Body)
- return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
- }
- // 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
- otherRatios := info.PriceData.OtherRatios
- if otherRatios == nil {
- otherRatios = map[string]float64{}
- }
- ratiosJSON, _ := common.Marshal(otherRatios)
- c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
- // 11. 解析响应
- upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
- if taskErr != nil {
- return nil, taskErr
- }
- // 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
- finalQuota := info.PriceData.Quota
- if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
- // 基于调整后的 ratios 重新计算 quota
- finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
- info.PriceData.OtherRatios = adjustedRatios
- info.PriceData.Quota = finalQuota
- }
- return &TaskSubmitResult{
- UpstreamTaskID: upstreamTaskID,
- TaskData: taskData,
- Platform: platform,
- Quota: finalQuota,
- }, nil
- }
- // recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
- // 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
- func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
- // 从 PriceData 获取不含 OtherRatios 的基础价格
- baseQuota := info.PriceData.Quota
- // 先除掉原有的 OtherRatios 恢复基础额度
- for _, ra := range info.PriceData.OtherRatios {
- if ra != 1.0 && ra > 0 {
- baseQuota = int(float64(baseQuota) / ra)
- }
- }
- // 应用新的 ratios
- result := float64(baseQuota)
- for _, ra := range ratios {
- if ra != 1.0 {
- result *= ra
- }
- }
- return int(result)
- }
- var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
- relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
- relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
- relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
- }
- func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
- respBuilder, ok := fetchRespBuilders[relayMode]
- if !ok {
- taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
- }
- respBody, taskErr := respBuilder(c)
- if taskErr != nil {
- return taskErr
- }
- if len(respBody) == 0 {
- respBody = []byte("{\"code\":\"success\",\"data\":null}")
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
- return
- }
- return
- }
- func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- userId := c.GetInt("id")
- var condition = struct {
- IDs []any `json:"ids"`
- Action string `json:"action"`
- }{}
- err := c.BindJSON(&condition)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
- return
- }
- var tasks []any
- if len(condition.IDs) > 0 {
- taskModels, err := model.GetByTaskIds(userId, condition.IDs)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
- return
- }
- for _, task := range taskModels {
- tasks = append(tasks, TaskModel2Dto(task))
- }
- } else {
- tasks = make([]any, 0)
- }
- respBody, err = common.Marshal(dto.TaskResponse[[]any]{
- Code: "success",
- Data: tasks,
- })
- return
- }
- func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("id")
- userId := c.GetInt("id")
- originTask, exist, err := model.GetByTaskId(userId, taskId)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- respBody, err = common.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
- return
- }
- func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("task_id")
- if taskId == "" {
- taskId = c.GetString("task_id")
- }
- userId := c.GetInt("id")
- originTask, exist, err := model.GetByTaskId(userId, taskId)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
- // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
- if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
- respBody = realtimeResp
- return
- }
- // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
- if isOpenAIVideoAPI {
- adaptor := GetTaskAdaptor(originTask.Platform)
- if adaptor == nil {
- taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
- return
- }
- if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
- openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
- return
- }
- respBody = openAIVideoData
- return
- }
- taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
- return
- }
- // 通用 TaskDto 格式
- respBody, err = common.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
- }
- return
- }
- // tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
- // 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
- // 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
- func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
- channelModel, err := model.GetChannelById(task.ChannelId, true)
- if err != nil {
- return nil
- }
- if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
- return nil
- }
- baseURL := constant.ChannelBaseURLs[channelModel.Type]
- if channelModel.GetBaseURL() != "" {
- baseURL = channelModel.GetBaseURL()
- }
- proxy := channelModel.GetSetting().Proxy
- adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
- if adaptor == nil {
- return nil
- }
- resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
- "task_id": task.GetUpstreamTaskID(),
- "action": task.Action,
- }, proxy)
- if err != nil || resp == nil {
- return nil
- }
- defer resp.Body.Close()
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil
- }
- ti, err := adaptor.ParseTaskResult(body)
- if err != nil || ti == nil {
- return nil
- }
- snap := task.Snapshot()
- // 将上游最新状态更新到 task
- if ti.Status != "" {
- task.Status = model.TaskStatus(ti.Status)
- }
- if ti.Progress != "" {
- task.Progress = ti.Progress
- }
- if strings.HasPrefix(ti.Url, "data:") {
- // data: URI — kept in Data, not ResultURL
- } else if ti.Url != "" {
- task.PrivateData.ResultURL = ti.Url
- } else if task.Status == model.TaskStatusSuccess {
- // No URL from adaptor — construct proxy URL using public task ID
- task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
- }
- if !snap.Equal(task.Snapshot()) {
- _, _ = task.UpdateWithStatus(snap.Status)
- }
- // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
- if isOpenAIVideoAPI {
- return nil
- }
- // 非 OpenAI Video API: 构建自定义格式响应
- format := detectVideoFormat(body)
- out := map[string]any{
- "error": nil,
- "format": format,
- "metadata": nil,
- "status": mapTaskStatusToSimple(task.Status),
- "task_id": task.TaskID,
- "url": task.GetResultURL(),
- }
- respBody, _ := common.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: out,
- })
- return respBody
- }
- // detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
- func detectVideoFormat(rawBody []byte) string {
- var raw map[string]any
- if err := common.Unmarshal(rawBody, &raw); err != nil {
- return "mp4"
- }
- respObj, ok := raw["response"].(map[string]any)
- if !ok {
- return "mp4"
- }
- vids, ok := respObj["videos"].([]any)
- if !ok || len(vids) == 0 {
- return "mp4"
- }
- v0, ok := vids[0].(map[string]any)
- if !ok {
- return "mp4"
- }
- mt, ok := v0["mimeType"].(string)
- if !ok || mt == "" || strings.Contains(mt, "mp4") {
- return "mp4"
- }
- return mt
- }
- // mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
- func mapTaskStatusToSimple(status model.TaskStatus) string {
- switch status {
- case model.TaskStatusSuccess:
- return "succeeded"
- case model.TaskStatusFailure:
- return "failed"
- case model.TaskStatusQueued, model.TaskStatusSubmitted:
- return "queued"
- default:
- return "processing"
- }
- }
- func TaskModel2Dto(task *model.Task) *dto.TaskDto {
- return &dto.TaskDto{
- ID: task.ID,
- CreatedAt: task.CreatedAt,
- UpdatedAt: task.UpdatedAt,
- TaskID: task.TaskID,
- Platform: string(task.Platform),
- UserId: task.UserId,
- Group: task.Group,
- ChannelId: task.ChannelId,
- Quota: task.Quota,
- Action: task.Action,
- Status: string(task.Status),
- FailReason: task.FailReason,
- ResultURL: task.GetResultURL(),
- SubmitTime: task.SubmitTime,
- StartTime: task.StartTime,
- FinishTime: task.FinishTime,
- Progress: task.Progress,
- Properties: task.Properties,
- Username: task.Username,
- Data: task.Data,
- }
- }
|