midjourney.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "log"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "time"
  13. )
  14. func UpdateMidjourneyTask() {
  15. //revocer
  16. imageModel := "midjourney"
  17. for {
  18. defer func() {
  19. if err := recover(); err != nil {
  20. log.Printf("UpdateMidjourneyTask panic: %v", err)
  21. }
  22. }()
  23. time.Sleep(time.Duration(15) * time.Second)
  24. tasks := model.GetAllUnFinishTasks()
  25. if len(tasks) != 0 {
  26. for _, task := range tasks {
  27. midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
  28. if err != nil {
  29. log.Printf("UpdateMidjourneyTask: %v", err)
  30. task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
  31. task.Status = "FAILURE"
  32. task.Progress = "100%"
  33. err := task.Update()
  34. if err != nil {
  35. log.Printf("UpdateMidjourneyTask error: %v", err)
  36. }
  37. continue
  38. }
  39. requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
  40. req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
  41. if err != nil {
  42. log.Printf("UpdateMidjourneyTask error: %v", err)
  43. continue
  44. }
  45. req.Header.Set("Content-Type", "application/json")
  46. req.Header.Set("mj-api-secret", midjourneyChannel.Key)
  47. resp, err := httpClient.Do(req)
  48. if err != nil {
  49. log.Printf("UpdateMidjourneyTask error: %v", err)
  50. continue
  51. }
  52. defer resp.Body.Close()
  53. var responseItem Midjourney
  54. err = json.NewDecoder(resp.Body).Decode(&responseItem)
  55. if err != nil {
  56. log.Printf("UpdateMidjourneyTask error: %v", err)
  57. continue
  58. }
  59. task.Code = 1
  60. task.Progress = responseItem.Progress
  61. task.PromptEn = responseItem.PromptEn
  62. task.State = responseItem.State
  63. task.SubmitTime = responseItem.SubmitTime
  64. task.StartTime = responseItem.StartTime
  65. task.FinishTime = responseItem.FinishTime
  66. task.ImageUrl = responseItem.ImageUrl
  67. task.Status = responseItem.Status
  68. task.FailReason = responseItem.FailReason
  69. if task.Progress != "100%" && responseItem.FailReason != "" {
  70. log.Println(task.MjId + " 构建失败," + task.FailReason)
  71. task.Progress = "100%"
  72. err = model.CacheUpdateUserQuota(task.UserId)
  73. if err != nil {
  74. log.Println("error update user quota cache: " + err.Error())
  75. } else {
  76. modelRatio := common.GetModelRatio(imageModel)
  77. groupRatio := common.GetGroupRatio("default")
  78. ratio := modelRatio * groupRatio
  79. quota := int(ratio * 1 * 1000)
  80. if quota != 0 {
  81. err := model.IncreaseUserQuota(task.UserId, quota)
  82. if err != nil {
  83. log.Println("fail to increase user quota")
  84. }
  85. logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
  86. model.RecordLog(task.UserId, 1, logContent)
  87. }
  88. }
  89. }
  90. err = task.Update()
  91. if err != nil {
  92. log.Printf("UpdateMidjourneyTask error: %v", err)
  93. }
  94. log.Printf("UpdateMidjourneyTask success: %v", task)
  95. }
  96. }
  97. }
  98. }
  99. func GetAllMidjourney(c *gin.Context) {
  100. p, _ := strconv.Atoi(c.Query("p"))
  101. if p < 0 {
  102. p = 0
  103. }
  104. logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
  105. if logs == nil {
  106. logs = make([]*model.Midjourney, 0)
  107. }
  108. c.JSON(200, gin.H{
  109. "success": true,
  110. "message": "",
  111. "data": logs,
  112. })
  113. }
  114. func GetUserMidjourney(c *gin.Context) {
  115. p, _ := strconv.Atoi(c.Query("p"))
  116. if p < 0 {
  117. p = 0
  118. }
  119. userId := c.GetInt("id")
  120. log.Printf("userId = %d \n", userId)
  121. logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
  122. if logs == nil {
  123. logs = make([]*model.Midjourney, 0)
  124. }
  125. c.JSON(200, gin.H{
  126. "success": true,
  127. "message": "",
  128. "data": logs,
  129. })
  130. }