midjourney.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "log"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "time"
  13. )
  14. func UpdateMidjourneyTask() {
  15. imageModel := "midjourney"
  16. for {
  17. time.Sleep(time.Duration(15) * time.Second)
  18. tasks := model.GetAllUnFinishTasks()
  19. if len(tasks) != 0 {
  20. //log.Printf("UpdateMidjourneyTask: %v", time.Now())
  21. ids := make([]string, 0)
  22. for _, task := range tasks {
  23. ids = append(ids, task.MjId)
  24. }
  25. requestUrl := "http://107.173.171.147:8080/mj/task/list-by-condition"
  26. requestBody := map[string]interface{}{
  27. "ids": ids,
  28. }
  29. jsonStr, err := json.Marshal(requestBody)
  30. if err != nil {
  31. log.Printf("UpdateMidjourneyTask: %v", err)
  32. }
  33. req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr))
  34. if err != nil {
  35. log.Printf("UpdateMidjourneyTask: %v", err)
  36. }
  37. req.Header.Set("Content-Type", "application/json")
  38. req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu")
  39. resp, err := httpClient.Do(req)
  40. if err != nil {
  41. log.Printf("UpdateMidjourneyTask: %v", err)
  42. }
  43. defer resp.Body.Close()
  44. var response []Midjourney
  45. err = json.NewDecoder(resp.Body).Decode(&response)
  46. if err != nil {
  47. log.Printf("UpdateMidjourneyTask: %v", err)
  48. }
  49. for _, responseItem := range response {
  50. var midjourneyTask *model.Midjourney
  51. for _, mj := range tasks {
  52. mj.MjId = responseItem.MjId
  53. midjourneyTask = model.GetMjByuId(mj.Id)
  54. }
  55. if midjourneyTask != nil {
  56. midjourneyTask.Code = 1
  57. midjourneyTask.Progress = responseItem.Progress
  58. midjourneyTask.PromptEn = responseItem.PromptEn
  59. midjourneyTask.State = responseItem.State
  60. midjourneyTask.SubmitTime = responseItem.SubmitTime
  61. midjourneyTask.StartTime = responseItem.StartTime
  62. midjourneyTask.FinishTime = responseItem.FinishTime
  63. midjourneyTask.ImageUrl = responseItem.ImageUrl
  64. midjourneyTask.Status = responseItem.Status
  65. midjourneyTask.FailReason = responseItem.FailReason
  66. if midjourneyTask.Progress != "100%" && responseItem.FailReason != "" {
  67. log.Println(midjourneyTask.MjId + " 构建失败," + midjourneyTask.FailReason)
  68. midjourneyTask.Progress = "100%"
  69. err = model.CacheUpdateUserQuota(midjourneyTask.UserId)
  70. if err != nil {
  71. log.Println("error update user quota cache: " + err.Error())
  72. } else {
  73. modelRatio := common.GetModelRatio(imageModel)
  74. groupRatio := common.GetGroupRatio("default")
  75. ratio := modelRatio * groupRatio
  76. quota := int(ratio * 1 * 1000)
  77. if quota != 0 {
  78. err := model.IncreaseUserQuota(midjourneyTask.UserId, quota)
  79. if err != nil {
  80. log.Println("fail to increase user quota")
  81. }
  82. logContent := fmt.Sprintf("%s 构图失败,补偿 %s", midjourneyTask.MjId, common.LogQuota(quota))
  83. model.RecordLog(midjourneyTask.UserId, 1, logContent)
  84. }
  85. }
  86. }
  87. err = midjourneyTask.Update()
  88. if err != nil {
  89. log.Printf("UpdateMidjourneyTaskFail: %v", err)
  90. }
  91. log.Printf("UpdateMidjourneyTask: %v", midjourneyTask)
  92. }
  93. }
  94. }
  95. }
  96. }
  97. func GetAllMidjourney(c *gin.Context) {
  98. p, _ := strconv.Atoi(c.Query("p"))
  99. if p < 0 {
  100. p = 0
  101. }
  102. logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
  103. if logs == nil {
  104. logs = make([]*model.Midjourney, 0)
  105. }
  106. c.JSON(200, gin.H{
  107. "success": true,
  108. "message": "",
  109. "data": logs,
  110. })
  111. }
  112. func GetUserMidjourney(c *gin.Context) {
  113. p, _ := strconv.Atoi(c.Query("p"))
  114. if p < 0 {
  115. p = 0
  116. }
  117. userId := c.GetInt("id")
  118. log.Printf("userId = %d \n", userId)
  119. logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
  120. if logs == nil {
  121. logs = make([]*model.Midjourney, 0)
  122. }
  123. c.JSON(200, gin.H{
  124. "success": true,
  125. "message": "",
  126. "data": logs,
  127. })
  128. }