midjourney.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "io"
  9. "log"
  10. "net/http"
  11. "one-api/common"
  12. "one-api/model"
  13. "strconv"
  14. "strings"
  15. "time"
  16. )
  17. func UpdateMidjourneyTask() {
  18. //revocer
  19. imageModel := "midjourney"
  20. defer func() {
  21. if err := recover(); err != nil {
  22. log.Printf("UpdateMidjourneyTask panic: %v", err)
  23. }
  24. }()
  25. for {
  26. time.Sleep(time.Duration(15) * time.Second)
  27. tasks := model.GetAllUnFinishTasks()
  28. if len(tasks) != 0 {
  29. log.Printf("检测到未完成的任务数有: %v", len(tasks))
  30. for _, task := range tasks {
  31. log.Printf("未完成的任务信息: %v", task)
  32. midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
  33. if err != nil {
  34. log.Printf("UpdateMidjourneyTask: %v", err)
  35. task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
  36. task.Status = "FAILURE"
  37. task.Progress = "100%"
  38. err := task.Update()
  39. if err != nil {
  40. log.Printf("UpdateMidjourneyTask error: %v", err)
  41. }
  42. continue
  43. }
  44. requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
  45. log.Printf("requestUrl: %s", requestUrl)
  46. req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
  47. if err != nil {
  48. log.Printf("UpdateMidjourneyTask error: %v", err)
  49. continue
  50. }
  51. // 设置超时时间
  52. timeout := time.Second * 5
  53. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  54. // 使用带有超时的 context 创建新的请求
  55. req = req.WithContext(ctx)
  56. req.Header.Set("Content-Type", "application/json")
  57. //req.Header.Set("Authorization", "Bearer midjourney-proxy")
  58. req.Header.Set("mj-api-secret", midjourneyChannel.Key)
  59. resp, err := httpClient.Do(req)
  60. if err != nil {
  61. log.Printf("UpdateMidjourneyTask error: %v", err)
  62. continue
  63. }
  64. responseBody, err := io.ReadAll(resp.Body)
  65. resp.Body.Close()
  66. log.Printf("responseBody: %s", string(responseBody))
  67. var responseItem Midjourney
  68. // err = json.NewDecoder(resp.Body).Decode(&responseItem)
  69. err = json.Unmarshal(responseBody, &responseItem)
  70. if err != nil {
  71. if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
  72. var responseWithoutStatus MidjourneyWithoutStatus
  73. var responseStatus MidjourneyStatus
  74. err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
  75. err2 := json.Unmarshal(responseBody, &responseStatus)
  76. if err1 == nil && err2 == nil {
  77. jsonData, err3 := json.Marshal(responseWithoutStatus)
  78. if err3 != nil {
  79. log.Printf("UpdateMidjourneyTask error1: %v", err3)
  80. continue
  81. }
  82. err4 := json.Unmarshal(jsonData, &responseStatus)
  83. if err4 != nil {
  84. log.Printf("UpdateMidjourneyTask error2: %v", err4)
  85. continue
  86. }
  87. responseItem.Status = strconv.Itoa(responseStatus.Status)
  88. } else {
  89. log.Printf("UpdateMidjourneyTask error3: %v", err)
  90. continue
  91. }
  92. } else {
  93. log.Printf("UpdateMidjourneyTask error4: %v", err)
  94. continue
  95. }
  96. }
  97. task.Code = 1
  98. task.Progress = responseItem.Progress
  99. task.PromptEn = responseItem.PromptEn
  100. task.State = responseItem.State
  101. task.SubmitTime = responseItem.SubmitTime
  102. task.StartTime = responseItem.StartTime
  103. task.FinishTime = responseItem.FinishTime
  104. task.ImageUrl = responseItem.ImageUrl
  105. task.Status = responseItem.Status
  106. task.FailReason = responseItem.FailReason
  107. if task.Progress != "100%" && responseItem.FailReason != "" {
  108. log.Println(task.MjId + " 构建失败," + task.FailReason)
  109. task.Progress = "100%"
  110. err = model.CacheUpdateUserQuota(task.UserId)
  111. if err != nil {
  112. log.Println("error update user quota cache: " + err.Error())
  113. } else {
  114. modelRatio := common.GetModelRatio(imageModel)
  115. groupRatio := common.GetGroupRatio("default")
  116. ratio := modelRatio * groupRatio
  117. quota := int(ratio * 1 * 1000)
  118. if quota != 0 {
  119. err := model.IncreaseUserQuota(task.UserId, quota)
  120. if err != nil {
  121. log.Println("fail to increase user quota")
  122. }
  123. logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
  124. model.RecordLog(task.UserId, 1, logContent)
  125. }
  126. }
  127. }
  128. err = task.Update()
  129. if err != nil {
  130. log.Printf("UpdateMidjourneyTask error5: %v", err)
  131. }
  132. log.Printf("UpdateMidjourneyTask success: %v", task)
  133. cancel()
  134. }
  135. }
  136. }
  137. }
  138. func GetAllMidjourney(c *gin.Context) {
  139. p, _ := strconv.Atoi(c.Query("p"))
  140. if p < 0 {
  141. p = 0
  142. }
  143. // 解析其他查询参数
  144. queryParams := model.TaskQueryParams{
  145. ChannelID: c.Query("channel_id"),
  146. MjID: c.Query("mj_id"),
  147. StartTimestamp: c.Query("start_timestamp"),
  148. EndTimestamp: c.Query("end_timestamp"),
  149. }
  150. logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
  151. if logs == nil {
  152. logs = make([]*model.Midjourney, 0)
  153. }
  154. c.JSON(200, gin.H{
  155. "success": true,
  156. "message": "",
  157. "data": logs,
  158. })
  159. }
  160. func GetUserMidjourney(c *gin.Context) {
  161. p, _ := strconv.Atoi(c.Query("p"))
  162. if p < 0 {
  163. p = 0
  164. }
  165. userId := c.GetInt("id")
  166. log.Printf("userId = %d \n", userId)
  167. queryParams := model.TaskQueryParams{
  168. MjID: c.Query("mj_id"),
  169. StartTimestamp: c.Query("start_timestamp"),
  170. EndTimestamp: c.Query("end_timestamp"),
  171. }
  172. logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
  173. if logs == nil {
  174. logs = make([]*model.Midjourney, 0)
  175. }
  176. if !strings.Contains(common.ServerAddress, "localhost") {
  177. for i, midjourney := range logs {
  178. midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
  179. logs[i] = midjourney
  180. }
  181. }
  182. c.JSON(200, gin.H{
  183. "success": true,
  184. "message": "",
  185. "data": logs,
  186. })
  187. }