Browse Source

feat: support image-seed (close #86)

CaIon 2 years ago
parent
commit
bc5a54df59
1 changed files with 34 additions and 22 deletions
  1. 34 22
      relay/relay-mj.go

+ 34 - 22
relay/relay-mj.go

@@ -139,27 +139,38 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
 }
 
 func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
-	//taskId := c.Param("id")
-	//userId := c.GetInt("id")
-	//originTask := model.GetByMJId(userId, taskId)
-	//if originTask == nil {
-	//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
-	//}
-	//channel, err := model.GetChannelById(originTask.ChannelId, false)
-	//if err != nil {
-	//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
-	//}
-	//if channel.Status != common.ChannelStatusEnabled {
-	//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
-	//}
-	//c.Set("channel_id", originTask.ChannelId)
-	//requestURL := c.Request.URL.String()
-	//fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
-	//req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
-	//if err != nil {
-	//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
-	//}
-	log.Println("RelayMidjourneyTaskImageSeed")
+	taskId := c.Param("id")
+	userId := c.GetInt("id")
+	originTask := model.GetByMJId(userId, taskId)
+	if originTask == nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+	}
+	channel, err := model.GetChannelById(originTask.ChannelId, true)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+	}
+	if channel.Status != common.ChannelStatusEnabled {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+	}
+	c.Set("channel_id", originTask.ChannelId)
+	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+	requestURL := c.Request.URL.String()
+	fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
+	midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
+	if err != nil {
+		return &midjResponseWithStatus.Response
+	}
+	midjResponse := &midjResponseWithStatus.Response
+	c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
+	respBody, err := json.Marshal(midjResponse)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+	}
+	_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+	}
 	return nil
 }
 
@@ -297,7 +308,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
 			return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
 		} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
-			channel, err := model.GetChannelById(originTask.ChannelId, false)
+			channel, err := model.GetChannelById(originTask.ChannelId, true)
 			if err != nil {
 				return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
 			}
@@ -306,6 +317,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			}
 			c.Set("base_url", channel.GetBaseURL())
 			c.Set("channel_id", originTask.ChannelId)
+			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 			log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
 		}
 		midjRequest.Prompt = originTask.Prompt