task_polling.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "sort"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/logger"
  15. "github.com/QuantumNous/new-api/model"
  16. "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
  17. relaycommon "github.com/QuantumNous/new-api/relay/common"
  18. "github.com/samber/lo"
  19. )
  20. // TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
  21. type TaskPollingAdaptor interface {
  22. Init(info *relaycommon.RelayInfo)
  23. FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
  24. ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
  25. // AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
  26. // 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
  27. AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
  28. }
// GetTaskAdaptorFunc is injected by the main package and returns the task
// adaptor for the given platform. Injecting it breaks the
// service -> relay -> relay/channel -> service import cycle.
// Callers in this file treat a nil result as "adaptor not found".
var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
  32. // TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
  33. func TaskPollingLoop() {
  34. for {
  35. time.Sleep(time.Duration(15) * time.Second)
  36. common.SysLog("任务进度轮询开始")
  37. ctx := context.TODO()
  38. allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
  39. platformTask := make(map[constant.TaskPlatform][]*model.Task)
  40. for _, t := range allTasks {
  41. platformTask[t.Platform] = append(platformTask[t.Platform], t)
  42. }
  43. for platform, tasks := range platformTask {
  44. if len(tasks) == 0 {
  45. continue
  46. }
  47. taskChannelM := make(map[int][]string)
  48. taskM := make(map[string]*model.Task)
  49. nullTaskIds := make([]int64, 0)
  50. for _, task := range tasks {
  51. upstreamID := task.GetUpstreamTaskID()
  52. if upstreamID == "" {
  53. // 统计失败的未完成任务
  54. nullTaskIds = append(nullTaskIds, task.ID)
  55. continue
  56. }
  57. taskM[upstreamID] = task
  58. taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
  59. }
  60. if len(nullTaskIds) > 0 {
  61. err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
  62. "status": "FAILURE",
  63. "progress": "100%",
  64. })
  65. if err != nil {
  66. logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
  67. } else {
  68. logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
  69. }
  70. }
  71. if len(taskChannelM) == 0 {
  72. continue
  73. }
  74. DispatchPlatformUpdate(platform, taskChannelM, taskM)
  75. }
  76. common.SysLog("任务进度轮询完成")
  77. }
  78. }
  79. // DispatchPlatformUpdate 按平台分发轮询更新
  80. func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
  81. switch platform {
  82. case constant.TaskPlatformMidjourney:
  83. // MJ 轮询由其自身处理,这里预留入口
  84. case constant.TaskPlatformSuno:
  85. _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
  86. default:
  87. if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
  88. common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
  89. }
  90. }
  91. }
  92. // UpdateSunoTasks 按渠道更新所有 Suno 任务
  93. func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  94. for channelId, taskIds := range taskChannelM {
  95. err := updateSunoTasks(ctx, channelId, taskIds, taskM)
  96. if err != nil {
  97. logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
  98. }
  99. }
  100. return nil
  101. }
  102. func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  103. logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
  104. if len(taskIds) == 0 {
  105. return nil
  106. }
  107. ch, err := model.CacheGetChannel(channelId)
  108. if err != nil {
  109. common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
  110. // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
  111. var failedIDs []int64
  112. for _, upstreamID := range taskIds {
  113. if t, ok := taskM[upstreamID]; ok {
  114. failedIDs = append(failedIDs, t.ID)
  115. }
  116. }
  117. err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
  118. "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
  119. "status": "FAILURE",
  120. "progress": "100%",
  121. })
  122. if err != nil {
  123. common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
  124. }
  125. return err
  126. }
  127. adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
  128. if adaptor == nil {
  129. return errors.New("adaptor not found")
  130. }
  131. proxy := ch.GetSetting().Proxy
  132. resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
  133. "ids": taskIds,
  134. }, proxy)
  135. if err != nil {
  136. common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
  137. return err
  138. }
  139. if resp.StatusCode != http.StatusOK {
  140. logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
  141. return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
  142. }
  143. defer resp.Body.Close()
  144. responseBody, err := io.ReadAll(resp.Body)
  145. if err != nil {
  146. common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
  147. return err
  148. }
  149. var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
  150. err = common.Unmarshal(responseBody, &responseItems)
  151. if err != nil {
  152. logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
  153. return err
  154. }
  155. if !responseItems.IsSuccess() {
  156. common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
  157. return err
  158. }
  159. for _, responseItem := range responseItems.Data {
  160. task := taskM[responseItem.TaskID]
  161. if !taskNeedsUpdate(task, responseItem) {
  162. continue
  163. }
  164. task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
  165. task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
  166. task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
  167. task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
  168. task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
  169. if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
  170. logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
  171. task.Progress = "100%"
  172. RefundTaskQuota(ctx, task, task.FailReason)
  173. }
  174. if responseItem.Status == model.TaskStatusSuccess {
  175. task.Progress = "100%"
  176. }
  177. task.Data = responseItem.Data
  178. err = task.Update()
  179. if err != nil {
  180. common.SysLog("UpdateSunoTask task error: " + err.Error())
  181. }
  182. }
  183. return nil
  184. }
  185. // taskNeedsUpdate 检查 Suno 任务是否需要更新
  186. func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
  187. if oldTask.SubmitTime != newTask.SubmitTime {
  188. return true
  189. }
  190. if oldTask.StartTime != newTask.StartTime {
  191. return true
  192. }
  193. if oldTask.FinishTime != newTask.FinishTime {
  194. return true
  195. }
  196. if string(oldTask.Status) != newTask.Status {
  197. return true
  198. }
  199. if oldTask.FailReason != newTask.FailReason {
  200. return true
  201. }
  202. if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
  203. return true
  204. }
  205. oldData, _ := common.Marshal(oldTask.Data)
  206. newData, _ := common.Marshal(newTask.Data)
  207. sort.Slice(oldData, func(i, j int) bool {
  208. return oldData[i] < oldData[j]
  209. })
  210. sort.Slice(newData, func(i, j int) bool {
  211. return newData[i] < newData[j]
  212. })
  213. if string(oldData) != string(newData) {
  214. return true
  215. }
  216. return false
  217. }
  218. // UpdateVideoTasks 按渠道更新所有视频任务
  219. func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  220. for channelId, taskIds := range taskChannelM {
  221. if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
  222. logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  223. }
  224. }
  225. return nil
  226. }
  227. func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  228. logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  229. if len(taskIds) == 0 {
  230. return nil
  231. }
  232. cacheGetChannel, err := model.CacheGetChannel(channelId)
  233. if err != nil {
  234. // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
  235. var failedIDs []int64
  236. for _, upstreamID := range taskIds {
  237. if t, ok := taskM[upstreamID]; ok {
  238. failedIDs = append(failedIDs, t.ID)
  239. }
  240. }
  241. errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
  242. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  243. "status": "FAILURE",
  244. "progress": "100%",
  245. })
  246. if errUpdate != nil {
  247. common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  248. }
  249. return fmt.Errorf("CacheGetChannel failed: %w", err)
  250. }
  251. adaptor := GetTaskAdaptorFunc(platform)
  252. if adaptor == nil {
  253. return fmt.Errorf("video adaptor not found")
  254. }
  255. info := &relaycommon.RelayInfo{}
  256. info.ChannelMeta = &relaycommon.ChannelMeta{
  257. ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
  258. }
  259. info.ApiKey = cacheGetChannel.Key
  260. adaptor.Init(info)
  261. for _, taskId := range taskIds {
  262. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  263. logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  264. }
  265. }
  266. return nil
  267. }
  268. func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
  269. baseURL := constant.ChannelBaseURLs[ch.Type]
  270. if ch.GetBaseURL() != "" {
  271. baseURL = ch.GetBaseURL()
  272. }
  273. proxy := ch.GetSetting().Proxy
  274. task := taskM[taskId]
  275. if task == nil {
  276. logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  277. return fmt.Errorf("task %s not found", taskId)
  278. }
  279. key := ch.Key
  280. privateData := task.PrivateData
  281. if privateData.Key != "" {
  282. key = privateData.Key
  283. }
  284. resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
  285. "task_id": task.GetUpstreamTaskID(),
  286. "action": task.Action,
  287. }, proxy)
  288. if err != nil {
  289. return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  290. }
  291. defer resp.Body.Close()
  292. responseBody, err := io.ReadAll(resp.Body)
  293. if err != nil {
  294. return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  295. }
  296. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
  297. taskResult := &relaycommon.TaskInfo{}
  298. // try parse as New API response format
  299. var responseItems dto.TaskResponse[model.Task]
  300. if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
  301. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
  302. t := responseItems.Data
  303. taskResult.TaskID = t.TaskID
  304. taskResult.Status = string(t.Status)
  305. taskResult.Url = t.GetResultURL()
  306. taskResult.Progress = t.Progress
  307. taskResult.Reason = t.FailReason
  308. task.Data = t.Data
  309. } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
  310. return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  311. } else {
  312. task.Data = redactVideoResponseBody(responseBody)
  313. }
  314. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
  315. now := time.Now().Unix()
  316. if taskResult.Status == "" {
  317. taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
  318. }
  319. // 记录原本的状态,防止重复退款
  320. shouldRefund := false
  321. quota := task.Quota
  322. preStatus := task.Status
  323. task.Status = model.TaskStatus(taskResult.Status)
  324. switch taskResult.Status {
  325. case model.TaskStatusSubmitted:
  326. task.Progress = taskcommon.ProgressSubmitted
  327. case model.TaskStatusQueued:
  328. task.Progress = taskcommon.ProgressQueued
  329. case model.TaskStatusInProgress:
  330. task.Progress = taskcommon.ProgressInProgress
  331. if task.StartTime == 0 {
  332. task.StartTime = now
  333. }
  334. case model.TaskStatusSuccess:
  335. task.Progress = taskcommon.ProgressComplete
  336. if task.FinishTime == 0 {
  337. task.FinishTime = now
  338. }
  339. if strings.HasPrefix(taskResult.Url, "data:") {
  340. // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
  341. } else if taskResult.Url != "" {
  342. // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
  343. task.PrivateData.ResultURL = taskResult.Url
  344. } else {
  345. // No URL from adaptor — construct proxy URL using public task ID
  346. task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
  347. }
  348. // 完成时计费调整:优先由 adaptor 计算,回退到 token 重算
  349. settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
  350. case model.TaskStatusFailure:
  351. logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
  352. task.Status = model.TaskStatusFailure
  353. task.Progress = taskcommon.ProgressComplete
  354. if task.FinishTime == 0 {
  355. task.FinishTime = now
  356. }
  357. task.FailReason = taskResult.Reason
  358. logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  359. taskResult.Progress = taskcommon.ProgressComplete
  360. if quota != 0 {
  361. if preStatus != model.TaskStatusFailure {
  362. shouldRefund = true
  363. } else {
  364. logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
  365. }
  366. }
  367. default:
  368. return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
  369. }
  370. if taskResult.Progress != "" {
  371. task.Progress = taskResult.Progress
  372. }
  373. if err := task.Update(); err != nil {
  374. common.SysLog("UpdateVideoTask task error: " + err.Error())
  375. shouldRefund = false
  376. }
  377. if shouldRefund {
  378. RefundTaskQuota(ctx, task, task.FailReason)
  379. }
  380. return nil
  381. }
  382. func redactVideoResponseBody(body []byte) []byte {
  383. var m map[string]any
  384. if err := common.Unmarshal(body, &m); err != nil {
  385. return body
  386. }
  387. resp, _ := m["response"].(map[string]any)
  388. if resp != nil {
  389. delete(resp, "bytesBase64Encoded")
  390. if v, ok := resp["video"].(string); ok {
  391. resp["video"] = truncateBase64(v)
  392. }
  393. if vs, ok := resp["videos"].([]any); ok {
  394. for i := range vs {
  395. if vm, ok := vs[i].(map[string]any); ok {
  396. delete(vm, "bytesBase64Encoded")
  397. }
  398. }
  399. }
  400. }
  401. b, err := common.Marshal(m)
  402. if err != nil {
  403. return body
  404. }
  405. return b
  406. }
  407. func truncateBase64(s string) string {
  408. const maxKeep = 256
  409. if len(s) <= maxKeep {
  410. return s
  411. }
  412. return s[:maxKeep] + "..."
  413. }
  414. // settleTaskBillingOnComplete 任务完成时的统一计费调整。
  415. // 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度
  416. //
  417. // 2. taskResult.TotalTokens > 0 → 按 token 重算
  418. // 3. 都不满足 → 保持预扣额度不变
  419. func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
  420. // 1. 优先让 adaptor 决定最终额度
  421. if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
  422. RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
  423. return
  424. }
  425. // 2. 回退到 token 重算
  426. if taskResult.TotalTokens > 0 {
  427. RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
  428. return
  429. }
  430. // 3. 无调整,保持预扣额度
  431. }