| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539 |
- package service
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "net/http"
- "sort"
- "strings"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/samber/lo"
- )
- // TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
- type TaskPollingAdaptor interface {
- Init(info *relaycommon.RelayInfo)
- FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
- ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
- // AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
- // 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
- AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
- }
- // GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
- // 打破 service -> relay -> relay/channel -> service 的循环依赖。
- var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
- // sweepTimedOutTasks 在主轮询之前独立清理超时任务。
- // 每次最多处理 100 条,剩余的下个周期继续处理。
- // 使用 per-task CAS (UpdateWithStatus) 防止覆盖被正常轮询已推进的任务。
- func sweepTimedOutTasks(ctx context.Context) {
- if constant.TaskTimeoutMinutes <= 0 {
- return
- }
- cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60
- tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100)
- if len(tasks) == 0 {
- return
- }
- const legacyTaskCutoff int64 = 1740182400 // 2026-02-22 00:00:00 UTC
- reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes)
- legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)"
- now := time.Now().Unix()
- timedOutCount := 0
- for _, task := range tasks {
- isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff
- oldStatus := task.Status
- task.Status = model.TaskStatusFailure
- task.Progress = "100%"
- task.FinishTime = now
- if isLegacy {
- task.FailReason = legacyReason
- } else {
- task.FailReason = reason
- }
- won, err := task.UpdateWithStatus(oldStatus)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err))
- continue
- }
- if !won {
- logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID))
- continue
- }
- timedOutCount++
- if !isLegacy && task.Quota != 0 {
- RefundTaskQuota(ctx, task, reason)
- }
- }
- if timedOutCount > 0 {
- logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount))
- }
- }
- // TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
- func TaskPollingLoop() {
- for {
- time.Sleep(time.Duration(15) * time.Second)
- common.SysLog("任务进度轮询开始")
- ctx := context.TODO()
- sweepTimedOutTasks(ctx)
- allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
- platformTask := make(map[constant.TaskPlatform][]*model.Task)
- for _, t := range allTasks {
- platformTask[t.Platform] = append(platformTask[t.Platform], t)
- }
- for platform, tasks := range platformTask {
- if len(tasks) == 0 {
- continue
- }
- taskChannelM := make(map[int][]string)
- taskM := make(map[string]*model.Task)
- nullTaskIds := make([]int64, 0)
- for _, task := range tasks {
- upstreamID := task.GetUpstreamTaskID()
- if upstreamID == "" {
- // 统计失败的未完成任务
- nullTaskIds = append(nullTaskIds, task.ID)
- continue
- }
- taskM[upstreamID] = task
- taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
- }
- if len(nullTaskIds) > 0 {
- err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
- } else {
- logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
- }
- }
- if len(taskChannelM) == 0 {
- continue
- }
- DispatchPlatformUpdate(platform, taskChannelM, taskM)
- }
- common.SysLog("任务进度轮询完成")
- }
- }
- // DispatchPlatformUpdate 按平台分发轮询更新
- func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
- switch platform {
- case constant.TaskPlatformMidjourney:
- // MJ 轮询由其自身处理,这里预留入口
- case constant.TaskPlatformSuno:
- _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
- default:
- if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
- common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
- }
- }
- }
- // UpdateSunoTasks 按渠道更新所有 Suno 任务
- func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
- for channelId, taskIds := range taskChannelM {
- err := updateSunoTasks(ctx, channelId, taskIds, taskM)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
- }
- }
- return nil
- }
- func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- return nil
- }
- ch, err := model.CacheGetChannel(channelId)
- if err != nil {
- common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
- // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
- var failedIDs []int64
- for _, upstreamID := range taskIds {
- if t, ok := taskM[upstreamID]; ok {
- failedIDs = append(failedIDs, t.ID)
- }
- }
- err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
- "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
- }
- return err
- }
- adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
- if adaptor == nil {
- return errors.New("adaptor not found")
- }
- proxy := ch.GetSetting().Proxy
- resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
- "ids": taskIds,
- }, proxy)
- if err != nil {
- common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
- return err
- }
- if resp.StatusCode != http.StatusOK {
- logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
- return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
- }
- defer resp.Body.Close()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
- return err
- }
- var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
- err = common.Unmarshal(responseBody, &responseItems)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
- return err
- }
- if !responseItems.IsSuccess() {
- common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
- return err
- }
- for _, responseItem := range responseItems.Data {
- task := taskM[responseItem.TaskID]
- if !taskNeedsUpdate(task, responseItem) {
- continue
- }
- task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
- task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
- task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
- task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
- task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
- if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
- logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
- task.Progress = "100%"
- RefundTaskQuota(ctx, task, task.FailReason)
- }
- if responseItem.Status == model.TaskStatusSuccess {
- task.Progress = "100%"
- }
- task.Data = responseItem.Data
- err = task.Update()
- if err != nil {
- common.SysLog("UpdateSunoTask task error: " + err.Error())
- }
- }
- return nil
- }
- // taskNeedsUpdate 检查 Suno 任务是否需要更新
- func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
- if oldTask.SubmitTime != newTask.SubmitTime {
- return true
- }
- if oldTask.StartTime != newTask.StartTime {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
- if string(oldTask.Status) != newTask.Status {
- return true
- }
- if oldTask.FailReason != newTask.FailReason {
- return true
- }
- if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
- return true
- }
- oldData, _ := common.Marshal(oldTask.Data)
- newData, _ := common.Marshal(newTask.Data)
- sort.Slice(oldData, func(i, j int) bool {
- return oldData[i] < oldData[j]
- })
- sort.Slice(newData, func(i, j int) bool {
- return newData[i] < newData[j]
- })
- if string(oldData) != string(newData) {
- return true
- }
- return false
- }
- // UpdateVideoTasks 按渠道更新所有视频任务
- func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
- for channelId, taskIds := range taskChannelM {
- if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
- }
- }
- return nil
- }
- func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- return nil
- }
- cacheGetChannel, err := model.CacheGetChannel(channelId)
- if err != nil {
- // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
- var failedIDs []int64
- for _, upstreamID := range taskIds {
- if t, ok := taskM[upstreamID]; ok {
- failedIDs = append(failedIDs, t.ID)
- }
- }
- errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
- "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if errUpdate != nil {
- common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
- }
- return fmt.Errorf("CacheGetChannel failed: %w", err)
- }
- adaptor := GetTaskAdaptorFunc(platform)
- if adaptor == nil {
- return fmt.Errorf("video adaptor not found")
- }
- info := &relaycommon.RelayInfo{}
- info.ChannelMeta = &relaycommon.ChannelMeta{
- ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
- }
- info.ApiKey = cacheGetChannel.Key
- adaptor.Init(info)
- for _, taskId := range taskIds {
- if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
- }
- }
- return nil
- }
- func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
- baseURL := constant.ChannelBaseURLs[ch.Type]
- if ch.GetBaseURL() != "" {
- baseURL = ch.GetBaseURL()
- }
- proxy := ch.GetSetting().Proxy
- task := taskM[taskId]
- if task == nil {
- logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
- return fmt.Errorf("task %s not found", taskId)
- }
- key := ch.Key
- privateData := task.PrivateData
- if privateData.Key != "" {
- key = privateData.Key
- }
- resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
- "task_id": task.GetUpstreamTaskID(),
- "action": task.Action,
- }, proxy)
- if err != nil {
- return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
- }
- defer resp.Body.Close()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
- }
- logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
- snap := task.Snapshot()
- taskResult := &relaycommon.TaskInfo{}
- // try parse as New API response format
- var responseItems dto.TaskResponse[model.Task]
- if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
- logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
- t := responseItems.Data
- taskResult.TaskID = t.TaskID
- taskResult.Status = string(t.Status)
- taskResult.Url = t.GetResultURL()
- taskResult.Progress = t.Progress
- taskResult.Reason = t.FailReason
- task.Data = t.Data
- } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
- return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
- } else {
- task.Data = redactVideoResponseBody(responseBody)
- }
- logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
- now := time.Now().Unix()
- if taskResult.Status == "" {
- taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
- }
- shouldRefund := false
- shouldSettle := false
- quota := task.Quota
- task.Status = model.TaskStatus(taskResult.Status)
- switch taskResult.Status {
- case model.TaskStatusSubmitted:
- task.Progress = taskcommon.ProgressSubmitted
- case model.TaskStatusQueued:
- task.Progress = taskcommon.ProgressQueued
- case model.TaskStatusInProgress:
- task.Progress = taskcommon.ProgressInProgress
- if task.StartTime == 0 {
- task.StartTime = now
- }
- case model.TaskStatusSuccess:
- task.Progress = taskcommon.ProgressComplete
- if task.FinishTime == 0 {
- task.FinishTime = now
- }
- if strings.HasPrefix(taskResult.Url, "data:") {
- // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
- } else if taskResult.Url != "" {
- // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
- task.PrivateData.ResultURL = taskResult.Url
- } else {
- // No URL from adaptor — construct proxy URL using public task ID
- task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
- }
- shouldSettle = true
- case model.TaskStatusFailure:
- logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
- task.Status = model.TaskStatusFailure
- task.Progress = taskcommon.ProgressComplete
- if task.FinishTime == 0 {
- task.FinishTime = now
- }
- task.FailReason = taskResult.Reason
- logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
- taskResult.Progress = taskcommon.ProgressComplete
- if quota != 0 {
- shouldRefund = true
- }
- default:
- return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID)
- }
- if taskResult.Progress != "" {
- task.Progress = taskResult.Progress
- }
- isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure
- if isDone && snap.Status != task.Status {
- won, err := task.UpdateWithStatus(snap.Status)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error()))
- shouldRefund = false
- shouldSettle = false
- } else if !won {
- logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID))
- shouldRefund = false
- shouldSettle = false
- }
- } else if !snap.Equal(task.Snapshot()) {
- if _, err := task.UpdateWithStatus(snap.Status); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error()))
- }
- } else {
- // No changes, skip update
- logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID))
- }
- if shouldSettle {
- settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
- }
- if shouldRefund {
- RefundTaskQuota(ctx, task, task.FailReason)
- }
- return nil
- }
- func redactVideoResponseBody(body []byte) []byte {
- var m map[string]any
- if err := common.Unmarshal(body, &m); err != nil {
- return body
- }
- resp, _ := m["response"].(map[string]any)
- if resp != nil {
- delete(resp, "bytesBase64Encoded")
- if v, ok := resp["video"].(string); ok {
- resp["video"] = truncateBase64(v)
- }
- if vs, ok := resp["videos"].([]any); ok {
- for i := range vs {
- if vm, ok := vs[i].(map[string]any); ok {
- delete(vm, "bytesBase64Encoded")
- }
- }
- }
- }
- b, err := common.Marshal(m)
- if err != nil {
- return body
- }
- return b
- }
// truncateBase64 caps s at 256 bytes. Longer input is cut to its first 256
// bytes with "..." appended; input of 256 bytes or fewer is returned as is.
func truncateBase64(s string) string {
	const maxKeep = 256
	if len(s) > maxKeep {
		return s[:maxKeep] + "..."
	}
	return s
}
- // settleTaskBillingOnComplete 任务完成时的统一计费调整。
- // 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度
- //
- // 2. taskResult.TotalTokens > 0 → 按 token 重算
- // 3. 都不满足 → 保持预扣额度不变
- func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
- // 0. 按次计费的任务不做差额结算
- if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
- logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
- return
- }
- // 1. 优先让 adaptor 决定最终额度
- if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
- RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
- return
- }
- // 2. 回退到 token 重算
- if taskResult.TotalTokens > 0 {
- RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
- return
- }
- // 3. 无调整,保持预扣额度
- }
|