task_video.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. package controller
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/constant"
  10. "one-api/model"
  11. "one-api/relay"
  12. "one-api/relay/channel"
  13. )
  14. func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  15. for channelId, taskIds := range taskChannelM {
  16. if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
  17. common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  18. }
  19. }
  20. return nil
  21. }
  22. func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  23. common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  24. if len(taskIds) == 0 {
  25. return nil
  26. }
  27. cacheGetChannel, err := model.CacheGetChannel(channelId)
  28. if err != nil {
  29. errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
  30. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  31. "status": "FAILURE",
  32. "progress": "100%",
  33. })
  34. if errUpdate != nil {
  35. common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  36. }
  37. return fmt.Errorf("CacheGetChannel failed: %w", err)
  38. }
  39. adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
  40. if adaptor == nil {
  41. return fmt.Errorf("video adaptor not found")
  42. }
  43. for _, taskId := range taskIds {
  44. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  45. common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  46. }
  47. }
  48. return nil
  49. }
  50. func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
  51. baseURL := common.ChannelBaseURLs[channel.Type]
  52. if channel.GetBaseURL() != "" {
  53. baseURL = channel.GetBaseURL()
  54. }
  55. task := taskM[taskId]
  56. if task == nil {
  57. common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  58. return fmt.Errorf("task %s not found", taskId)
  59. }
  60. resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
  61. "task_id": taskId,
  62. "action": task.Action,
  63. })
  64. if err != nil {
  65. return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
  66. }
  67. if resp.StatusCode != http.StatusOK {
  68. return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
  69. }
  70. defer resp.Body.Close()
  71. responseBody, err := io.ReadAll(resp.Body)
  72. if err != nil {
  73. return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
  74. }
  75. var responseItem map[string]interface{}
  76. err = json.Unmarshal(responseBody, &responseItem)
  77. if err != nil {
  78. common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
  79. return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
  80. }
  81. code, _ := responseItem["code"].(float64)
  82. if code != 0 {
  83. return fmt.Errorf("video task fetch failed for task %s", taskId)
  84. }
  85. data, ok := responseItem["data"].(map[string]interface{})
  86. if !ok {
  87. common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
  88. return fmt.Errorf("video task data format error for task %s", taskId)
  89. }
  90. if status, ok := data["task_status"].(string); ok {
  91. switch status {
  92. case "submitted", "queued":
  93. task.Status = model.TaskStatusSubmitted
  94. case "processing":
  95. task.Status = model.TaskStatusInProgress
  96. case "succeed":
  97. task.Status = model.TaskStatusSuccess
  98. task.Progress = "100%"
  99. if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
  100. task.FailReason = url
  101. } else {
  102. common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
  103. }
  104. case "failed":
  105. task.Status = model.TaskStatusFailure
  106. task.Progress = "100%"
  107. if reason, ok := data["fail_reason"].(string); ok {
  108. task.FailReason = reason
  109. }
  110. }
  111. }
  112. // If task failed, refund quota
  113. if task.Status == model.TaskStatusFailure {
  114. common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  115. quota := task.Quota
  116. if quota != 0 {
  117. if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
  118. common.LogError(ctx, "Failed to increase user quota: "+err.Error())
  119. }
  120. logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
  121. model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
  122. }
  123. }
  124. task.Data = responseBody
  125. if err := task.Update(); err != nil {
  126. common.SysError("UpdateVideoTask task error: " + err.Error())
  127. }
  128. return nil
  129. }