midjourney.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "log"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "time"
  13. )
  14. func UpdateMidjourneyTask() {
  15. //revocer
  16. defer func() {
  17. if err := recover(); err != nil {
  18. log.Printf("UpdateMidjourneyTask: %v", err)
  19. }
  20. }()
  21. imageModel := "midjourney"
  22. for {
  23. time.Sleep(time.Duration(15) * time.Second)
  24. tasks := model.GetAllUnFinishTasks()
  25. if len(tasks) != 0 {
  26. //log.Printf("UpdateMidjourneyTask: %v", time.Now())
  27. ids := make([]string, 0)
  28. for _, task := range tasks {
  29. ids = append(ids, task.MjId)
  30. }
  31. requestUrl := "http://107.173.171.147:8080/mj/task/list-by-condition"
  32. requestBody := map[string]interface{}{
  33. "ids": ids,
  34. }
  35. jsonStr, err := json.Marshal(requestBody)
  36. if err != nil {
  37. log.Printf("UpdateMidjourneyTask: %v", err)
  38. }
  39. req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr))
  40. if err != nil {
  41. log.Printf("UpdateMidjourneyTask: %v", err)
  42. }
  43. req.Header.Set("Content-Type", "application/json")
  44. req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu")
  45. resp, err := httpClient.Do(req)
  46. if err != nil {
  47. log.Printf("UpdateMidjourneyTask: %v", err)
  48. }
  49. defer resp.Body.Close()
  50. var response []Midjourney
  51. err = json.NewDecoder(resp.Body).Decode(&response)
  52. if err != nil {
  53. log.Printf("UpdateMidjourneyTask: %v", err)
  54. }
  55. for _, responseItem := range response {
  56. var midjourneyTask *model.Midjourney
  57. for _, mj := range tasks {
  58. mj.MjId = responseItem.MjId
  59. midjourneyTask = model.GetMjByuId(mj.Id)
  60. }
  61. if midjourneyTask != nil {
  62. midjourneyTask.Code = 1
  63. midjourneyTask.Progress = responseItem.Progress
  64. midjourneyTask.PromptEn = responseItem.PromptEn
  65. midjourneyTask.State = responseItem.State
  66. midjourneyTask.SubmitTime = responseItem.SubmitTime
  67. midjourneyTask.StartTime = responseItem.StartTime
  68. midjourneyTask.FinishTime = responseItem.FinishTime
  69. midjourneyTask.ImageUrl = responseItem.ImageUrl
  70. midjourneyTask.Status = responseItem.Status
  71. midjourneyTask.FailReason = responseItem.FailReason
  72. if midjourneyTask.Progress != "100%" && responseItem.FailReason != "" {
  73. log.Println(midjourneyTask.MjId + " 构建失败," + midjourneyTask.FailReason)
  74. midjourneyTask.Progress = "100%"
  75. err = model.CacheUpdateUserQuota(midjourneyTask.UserId)
  76. if err != nil {
  77. log.Println("error update user quota cache: " + err.Error())
  78. } else {
  79. modelRatio := common.GetModelRatio(imageModel)
  80. groupRatio := common.GetGroupRatio("default")
  81. ratio := modelRatio * groupRatio
  82. quota := int(ratio * 1 * 1000)
  83. if quota != 0 {
  84. err := model.IncreaseUserQuota(midjourneyTask.UserId, quota)
  85. if err != nil {
  86. log.Println("fail to increase user quota")
  87. }
  88. logContent := fmt.Sprintf("%s 构图失败,补偿 %s", midjourneyTask.MjId, common.LogQuota(quota))
  89. model.RecordLog(midjourneyTask.UserId, 1, logContent)
  90. }
  91. }
  92. }
  93. err = midjourneyTask.Update()
  94. if err != nil {
  95. log.Printf("UpdateMidjourneyTaskFail: %v", err)
  96. }
  97. log.Printf("UpdateMidjourneyTask: %v", midjourneyTask)
  98. }
  99. }
  100. }
  101. }
  102. }
  103. func GetAllMidjourney(c *gin.Context) {
  104. p, _ := strconv.Atoi(c.Query("p"))
  105. if p < 0 {
  106. p = 0
  107. }
  108. logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
  109. if logs == nil {
  110. logs = make([]*model.Midjourney, 0)
  111. }
  112. c.JSON(200, gin.H{
  113. "success": true,
  114. "message": "",
  115. "data": logs,
  116. })
  117. }
  118. func GetUserMidjourney(c *gin.Context) {
  119. p, _ := strconv.Atoi(c.Query("p"))
  120. if p < 0 {
  121. p = 0
  122. }
  123. userId := c.GetInt("id")
  124. log.Printf("userId = %d \n", userId)
  125. logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
  126. if logs == nil {
  127. logs = make([]*model.Midjourney, 0)
  128. }
  129. c.JSON(200, gin.H{
  130. "success": true,
  131. "message": "",
  132. "data": logs,
  133. })
  134. }