|
@@ -139,27 +139,38 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
|
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
- //taskId := c.Param("id")
|
|
|
|
|
- //userId := c.GetInt("id")
|
|
|
|
|
- //originTask := model.GetByMJId(userId, taskId)
|
|
|
|
|
- //if originTask == nil {
|
|
|
|
|
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
|
|
|
|
|
- //}
|
|
|
|
|
- //channel, err := model.GetChannelById(originTask.ChannelId, false)
|
|
|
|
|
- //if err != nil {
|
|
|
|
|
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
|
|
|
|
- //}
|
|
|
|
|
- //if channel.Status != common.ChannelStatusEnabled {
|
|
|
|
|
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
|
|
|
|
- //}
|
|
|
|
|
- //c.Set("channel_id", originTask.ChannelId)
|
|
|
|
|
- //requestURL := c.Request.URL.String()
|
|
|
|
|
- //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
|
|
|
|
- //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
|
|
|
|
|
- //if err != nil {
|
|
|
|
|
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
|
|
|
|
|
- //}
|
|
|
|
|
- log.Println("RelayMidjourneyTaskImageSeed")
|
|
|
|
|
|
|
+ taskId := c.Param("id")
|
|
|
|
|
+ userId := c.GetInt("id")
|
|
|
|
|
+ originTask := model.GetByMJId(userId, taskId)
|
|
|
|
|
+ if originTask == nil {
|
|
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
|
|
|
|
|
+ }
|
|
|
|
|
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
|
|
|
|
+ }
|
|
|
|
|
+ if channel.Status != common.ChannelStatusEnabled {
|
|
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
|
|
|
|
+ }
|
|
|
|
|
+ c.Set("channel_id", originTask.ChannelId)
|
|
|
|
|
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
|
|
|
+
|
|
|
|
|
+ requestURL := c.Request.URL.String()
|
|
|
|
|
+ fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
|
|
|
|
+ midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return &midjResponseWithStatus.Response
|
|
|
|
|
+ }
|
|
|
|
|
+ midjResponse := &midjResponseWithStatus.Response
|
|
|
|
|
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
|
|
|
|
+ respBody, err := json.Marshal(midjResponse)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
|
|
|
|
+ }
|
|
|
|
|
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
|
|
|
|
+ }
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -297,7 +308,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
|
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
|
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
|
|
- channel, err := model.GetChannelById(originTask.ChannelId, false)
|
|
|
|
|
|
|
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
|
|
}
|
|
}
|
|
@@ -306,6 +317,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
}
|
|
|
c.Set("base_url", channel.GetBaseURL())
|
|
c.Set("base_url", channel.GetBaseURL())
|
|
|
c.Set("channel_id", originTask.ChannelId)
|
|
c.Set("channel_id", originTask.ChannelId)
|
|
|
|
|
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
|
|
}
|
|
}
|
|
|
midjRequest.Prompt = originTask.Prompt
|
|
midjRequest.Prompt = originTask.Prompt
|