midjourney.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "log"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "time"
  13. )
  14. func UpdateMidjourneyTask() {
  15. //revocer
  16. defer func() {
  17. if err := recover(); err != nil {
  18. log.Printf("UpdateMidjourneyTask: %v", err)
  19. }
  20. }()
  21. imageModel := "midjourney"
  22. for {
  23. time.Sleep(time.Duration(15) * time.Second)
  24. tasks := model.GetAllUnFinishTasks()
  25. if len(tasks) != 0 {
  26. //log.Printf("UpdateMidjourneyTask: %v", time.Now())
  27. ids := make([]string, 0)
  28. for _, task := range tasks {
  29. ids = append(ids, task.MjId)
  30. }
  31. requestUrl := "http://107.173.171.147:8080/mj/task/list-by-condition"
  32. requestBody := map[string]interface{}{
  33. "ids": ids,
  34. }
  35. jsonStr, err := json.Marshal(requestBody)
  36. if err != nil {
  37. log.Printf("UpdateMidjourneyTask: %v", err)
  38. continue
  39. }
  40. req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr))
  41. if err != nil {
  42. log.Printf("UpdateMidjourneyTask: %v", err)
  43. continue
  44. }
  45. req.Header.Set("Content-Type", "application/json")
  46. req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu")
  47. resp, err := httpClient.Do(req)
  48. if err != nil {
  49. log.Printf("UpdateMidjourneyTask: %v", err)
  50. continue
  51. }
  52. defer resp.Body.Close()
  53. var response []Midjourney
  54. err = json.NewDecoder(resp.Body).Decode(&response)
  55. if err != nil {
  56. log.Printf("UpdateMidjourneyTask: %v", err)
  57. continue
  58. }
  59. for _, responseItem := range response {
  60. var midjourneyTask *model.Midjourney
  61. for _, mj := range tasks {
  62. mj.MjId = responseItem.MjId
  63. midjourneyTask = model.GetMjByuId(mj.Id)
  64. }
  65. if midjourneyTask != nil {
  66. midjourneyTask.Code = 1
  67. midjourneyTask.Progress = responseItem.Progress
  68. midjourneyTask.PromptEn = responseItem.PromptEn
  69. midjourneyTask.State = responseItem.State
  70. midjourneyTask.SubmitTime = responseItem.SubmitTime
  71. midjourneyTask.StartTime = responseItem.StartTime
  72. midjourneyTask.FinishTime = responseItem.FinishTime
  73. midjourneyTask.ImageUrl = responseItem.ImageUrl
  74. midjourneyTask.Status = responseItem.Status
  75. midjourneyTask.FailReason = responseItem.FailReason
  76. if midjourneyTask.Progress != "100%" && responseItem.FailReason != "" {
  77. log.Println(midjourneyTask.MjId + " 构建失败," + midjourneyTask.FailReason)
  78. midjourneyTask.Progress = "100%"
  79. err = model.CacheUpdateUserQuota(midjourneyTask.UserId)
  80. if err != nil {
  81. log.Println("error update user quota cache: " + err.Error())
  82. } else {
  83. modelRatio := common.GetModelRatio(imageModel)
  84. groupRatio := common.GetGroupRatio("default")
  85. ratio := modelRatio * groupRatio
  86. quota := int(ratio * 1 * 1000)
  87. if quota != 0 {
  88. err := model.IncreaseUserQuota(midjourneyTask.UserId, quota)
  89. if err != nil {
  90. log.Println("fail to increase user quota")
  91. }
  92. logContent := fmt.Sprintf("%s 构图失败,补偿 %s", midjourneyTask.MjId, common.LogQuota(quota))
  93. model.RecordLog(midjourneyTask.UserId, 1, logContent)
  94. }
  95. }
  96. }
  97. err = midjourneyTask.Update()
  98. if err != nil {
  99. log.Printf("UpdateMidjourneyTaskFail: %v", err)
  100. }
  101. log.Printf("UpdateMidjourneyTask: %v", midjourneyTask)
  102. }
  103. }
  104. }
  105. }
  106. }
  107. func GetAllMidjourney(c *gin.Context) {
  108. p, _ := strconv.Atoi(c.Query("p"))
  109. if p < 0 {
  110. p = 0
  111. }
  112. logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
  113. if logs == nil {
  114. logs = make([]*model.Midjourney, 0)
  115. }
  116. c.JSON(200, gin.H{
  117. "success": true,
  118. "message": "",
  119. "data": logs,
  120. })
  121. }
  122. func GetUserMidjourney(c *gin.Context) {
  123. p, _ := strconv.Atoi(c.Query("p"))
  124. if p < 0 {
  125. p = 0
  126. }
  127. userId := c.GetInt("id")
  128. log.Printf("userId = %d \n", userId)
  129. logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
  130. if logs == nil {
  131. logs = make([]*model.Midjourney, 0)
  132. }
  133. c.JSON(200, gin.H{
  134. "success": true,
  135. "message": "",
  136. "data": logs,
  137. })
  138. }