task_polling.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "sort"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/logger"
  15. "github.com/QuantumNous/new-api/model"
  16. "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
  17. relaycommon "github.com/QuantumNous/new-api/relay/common"
  18. "github.com/samber/lo"
  19. )
  20. // TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
  21. type TaskPollingAdaptor interface {
  22. Init(info *relaycommon.RelayInfo)
  23. FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
  24. ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
  25. }
// GetTaskAdaptorFunc is injected by the main package and returns the task
// adaptor for the given platform.
// It breaks the service -> relay -> relay/channel -> service import cycle.
// Callers in this file treat a nil return value as "adaptor not found".
var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
  29. // TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
  30. func TaskPollingLoop() {
  31. for {
  32. time.Sleep(time.Duration(15) * time.Second)
  33. common.SysLog("任务进度轮询开始")
  34. ctx := context.TODO()
  35. allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
  36. platformTask := make(map[constant.TaskPlatform][]*model.Task)
  37. for _, t := range allTasks {
  38. platformTask[t.Platform] = append(platformTask[t.Platform], t)
  39. }
  40. for platform, tasks := range platformTask {
  41. if len(tasks) == 0 {
  42. continue
  43. }
  44. taskChannelM := make(map[int][]string)
  45. taskM := make(map[string]*model.Task)
  46. nullTaskIds := make([]int64, 0)
  47. for _, task := range tasks {
  48. upstreamID := task.GetUpstreamTaskID()
  49. if upstreamID == "" {
  50. // 统计失败的未完成任务
  51. nullTaskIds = append(nullTaskIds, task.ID)
  52. continue
  53. }
  54. taskM[upstreamID] = task
  55. taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
  56. }
  57. if len(nullTaskIds) > 0 {
  58. err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
  59. "status": "FAILURE",
  60. "progress": "100%",
  61. })
  62. if err != nil {
  63. logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
  64. } else {
  65. logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
  66. }
  67. }
  68. if len(taskChannelM) == 0 {
  69. continue
  70. }
  71. DispatchPlatformUpdate(platform, taskChannelM, taskM)
  72. }
  73. common.SysLog("任务进度轮询完成")
  74. }
  75. }
  76. // DispatchPlatformUpdate 按平台分发轮询更新
  77. func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
  78. switch platform {
  79. case constant.TaskPlatformMidjourney:
  80. // MJ 轮询由其自身处理,这里预留入口
  81. case constant.TaskPlatformSuno:
  82. _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
  83. default:
  84. if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
  85. common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
  86. }
  87. }
  88. }
  89. // UpdateSunoTasks 按渠道更新所有 Suno 任务
  90. func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  91. for channelId, taskIds := range taskChannelM {
  92. err := updateSunoTasks(ctx, channelId, taskIds, taskM)
  93. if err != nil {
  94. logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
  95. }
  96. }
  97. return nil
  98. }
  99. func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  100. logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
  101. if len(taskIds) == 0 {
  102. return nil
  103. }
  104. ch, err := model.CacheGetChannel(channelId)
  105. if err != nil {
  106. common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
  107. // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
  108. var failedIDs []int64
  109. for _, upstreamID := range taskIds {
  110. if t, ok := taskM[upstreamID]; ok {
  111. failedIDs = append(failedIDs, t.ID)
  112. }
  113. }
  114. err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
  115. "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
  116. "status": "FAILURE",
  117. "progress": "100%",
  118. })
  119. if err != nil {
  120. common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
  121. }
  122. return err
  123. }
  124. adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
  125. if adaptor == nil {
  126. return errors.New("adaptor not found")
  127. }
  128. proxy := ch.GetSetting().Proxy
  129. resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
  130. "ids": taskIds,
  131. }, proxy)
  132. if err != nil {
  133. common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
  134. return err
  135. }
  136. if resp.StatusCode != http.StatusOK {
  137. logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
  138. return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
  139. }
  140. defer resp.Body.Close()
  141. responseBody, err := io.ReadAll(resp.Body)
  142. if err != nil {
  143. common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
  144. return err
  145. }
  146. var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
  147. err = common.Unmarshal(responseBody, &responseItems)
  148. if err != nil {
  149. logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
  150. return err
  151. }
  152. if !responseItems.IsSuccess() {
  153. common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
  154. return err
  155. }
  156. for _, responseItem := range responseItems.Data {
  157. task := taskM[responseItem.TaskID]
  158. if !taskNeedsUpdate(task, responseItem) {
  159. continue
  160. }
  161. task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
  162. task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
  163. task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
  164. task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
  165. task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
  166. if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
  167. logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
  168. task.Progress = "100%"
  169. RefundTaskQuota(ctx, task, task.FailReason)
  170. }
  171. if responseItem.Status == model.TaskStatusSuccess {
  172. task.Progress = "100%"
  173. }
  174. task.Data = responseItem.Data
  175. err = task.Update()
  176. if err != nil {
  177. common.SysLog("UpdateSunoTask task error: " + err.Error())
  178. }
  179. }
  180. return nil
  181. }
  182. // taskNeedsUpdate 检查 Suno 任务是否需要更新
  183. func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
  184. if oldTask.SubmitTime != newTask.SubmitTime {
  185. return true
  186. }
  187. if oldTask.StartTime != newTask.StartTime {
  188. return true
  189. }
  190. if oldTask.FinishTime != newTask.FinishTime {
  191. return true
  192. }
  193. if string(oldTask.Status) != newTask.Status {
  194. return true
  195. }
  196. if oldTask.FailReason != newTask.FailReason {
  197. return true
  198. }
  199. if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
  200. return true
  201. }
  202. oldData, _ := common.Marshal(oldTask.Data)
  203. newData, _ := common.Marshal(newTask.Data)
  204. sort.Slice(oldData, func(i, j int) bool {
  205. return oldData[i] < oldData[j]
  206. })
  207. sort.Slice(newData, func(i, j int) bool {
  208. return newData[i] < newData[j]
  209. })
  210. if string(oldData) != string(newData) {
  211. return true
  212. }
  213. return false
  214. }
  215. // UpdateVideoTasks 按渠道更新所有视频任务
  216. func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  217. for channelId, taskIds := range taskChannelM {
  218. if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
  219. logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  220. }
  221. }
  222. return nil
  223. }
  224. func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  225. logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  226. if len(taskIds) == 0 {
  227. return nil
  228. }
  229. cacheGetChannel, err := model.CacheGetChannel(channelId)
  230. if err != nil {
  231. // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
  232. var failedIDs []int64
  233. for _, upstreamID := range taskIds {
  234. if t, ok := taskM[upstreamID]; ok {
  235. failedIDs = append(failedIDs, t.ID)
  236. }
  237. }
  238. errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
  239. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  240. "status": "FAILURE",
  241. "progress": "100%",
  242. })
  243. if errUpdate != nil {
  244. common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  245. }
  246. return fmt.Errorf("CacheGetChannel failed: %w", err)
  247. }
  248. adaptor := GetTaskAdaptorFunc(platform)
  249. if adaptor == nil {
  250. return fmt.Errorf("video adaptor not found")
  251. }
  252. info := &relaycommon.RelayInfo{}
  253. info.ChannelMeta = &relaycommon.ChannelMeta{
  254. ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
  255. }
  256. info.ApiKey = cacheGetChannel.Key
  257. adaptor.Init(info)
  258. for _, taskId := range taskIds {
  259. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  260. logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  261. }
  262. }
  263. return nil
  264. }
  265. func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
  266. baseURL := constant.ChannelBaseURLs[ch.Type]
  267. if ch.GetBaseURL() != "" {
  268. baseURL = ch.GetBaseURL()
  269. }
  270. proxy := ch.GetSetting().Proxy
  271. task := taskM[taskId]
  272. if task == nil {
  273. logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  274. return fmt.Errorf("task %s not found", taskId)
  275. }
  276. key := ch.Key
  277. privateData := task.PrivateData
  278. if privateData.Key != "" {
  279. key = privateData.Key
  280. }
  281. resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
  282. "task_id": task.GetUpstreamTaskID(),
  283. "action": task.Action,
  284. }, proxy)
  285. if err != nil {
  286. return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  287. }
  288. defer resp.Body.Close()
  289. responseBody, err := io.ReadAll(resp.Body)
  290. if err != nil {
  291. return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  292. }
  293. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
  294. taskResult := &relaycommon.TaskInfo{}
  295. // try parse as New API response format
  296. var responseItems dto.TaskResponse[model.Task]
  297. if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
  298. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
  299. t := responseItems.Data
  300. taskResult.TaskID = t.TaskID
  301. taskResult.Status = string(t.Status)
  302. taskResult.Url = t.GetResultURL()
  303. taskResult.Progress = t.Progress
  304. taskResult.Reason = t.FailReason
  305. task.Data = t.Data
  306. } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
  307. return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  308. } else {
  309. task.Data = redactVideoResponseBody(responseBody)
  310. }
  311. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
  312. now := time.Now().Unix()
  313. if taskResult.Status == "" {
  314. taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
  315. }
  316. // 记录原本的状态,防止重复退款
  317. shouldRefund := false
  318. quota := task.Quota
  319. preStatus := task.Status
  320. task.Status = model.TaskStatus(taskResult.Status)
  321. switch taskResult.Status {
  322. case model.TaskStatusSubmitted:
  323. task.Progress = taskcommon.ProgressSubmitted
  324. case model.TaskStatusQueued:
  325. task.Progress = taskcommon.ProgressQueued
  326. case model.TaskStatusInProgress:
  327. task.Progress = taskcommon.ProgressInProgress
  328. if task.StartTime == 0 {
  329. task.StartTime = now
  330. }
  331. case model.TaskStatusSuccess:
  332. task.Progress = taskcommon.ProgressComplete
  333. if task.FinishTime == 0 {
  334. task.FinishTime = now
  335. }
  336. if strings.HasPrefix(taskResult.Url, "data:") {
  337. // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
  338. } else if taskResult.Url != "" {
  339. // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
  340. task.PrivateData.ResultURL = taskResult.Url
  341. } else {
  342. // No URL from adaptor — construct proxy URL using public task ID
  343. task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
  344. }
  345. // 如果返回了 total_tokens,根据模型倍率重新计费
  346. if taskResult.TotalTokens > 0 {
  347. RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
  348. }
  349. case model.TaskStatusFailure:
  350. logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
  351. task.Status = model.TaskStatusFailure
  352. task.Progress = taskcommon.ProgressComplete
  353. if task.FinishTime == 0 {
  354. task.FinishTime = now
  355. }
  356. task.FailReason = taskResult.Reason
  357. logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  358. taskResult.Progress = taskcommon.ProgressComplete
  359. if quota != 0 {
  360. if preStatus != model.TaskStatusFailure {
  361. shouldRefund = true
  362. } else {
  363. logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
  364. }
  365. }
  366. default:
  367. return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
  368. }
  369. if taskResult.Progress != "" {
  370. task.Progress = taskResult.Progress
  371. }
  372. if err := task.Update(); err != nil {
  373. common.SysLog("UpdateVideoTask task error: " + err.Error())
  374. shouldRefund = false
  375. }
  376. if shouldRefund {
  377. RefundTaskQuota(ctx, task, task.FailReason)
  378. }
  379. return nil
  380. }
  381. func redactVideoResponseBody(body []byte) []byte {
  382. var m map[string]any
  383. if err := common.Unmarshal(body, &m); err != nil {
  384. return body
  385. }
  386. resp, _ := m["response"].(map[string]any)
  387. if resp != nil {
  388. delete(resp, "bytesBase64Encoded")
  389. if v, ok := resp["video"].(string); ok {
  390. resp["video"] = truncateBase64(v)
  391. }
  392. if vs, ok := resp["videos"].([]any); ok {
  393. for i := range vs {
  394. if vm, ok := vs[i].(map[string]any); ok {
  395. delete(vm, "bytesBase64Encoded")
  396. }
  397. }
  398. }
  399. }
  400. b, err := common.Marshal(m)
  401. if err != nil {
  402. return body
  403. }
  404. return b
  405. }
  406. func truncateBase64(s string) string {
  407. const maxKeep = 256
  408. if len(s) <= maxKeep {
  409. return s
  410. }
  411. return s[:maxKeep] + "..."
  412. }