|
|
@@ -138,6 +138,31 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
+ //taskId := c.Param("id")
|
|
|
+ //userId := c.GetInt("id")
|
|
|
+ //originTask := model.GetByMJId(userId, taskId)
|
|
|
+ //if originTask == nil {
|
|
|
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
|
|
|
+ //}
|
|
|
+ //channel, err := model.GetChannelById(originTask.ChannelId, false)
|
|
|
+ //if err != nil {
|
|
|
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
|
|
+ //}
|
|
|
+ //if channel.Status != common.ChannelStatusEnabled {
|
|
|
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
|
|
+ //}
|
|
|
+ //c.Set("channel_id", originTask.ChannelId)
|
|
|
+ //requestURL := c.Request.URL.String()
|
|
|
+ //fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
|
|
+ //req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
|
|
|
+ //if err != nil {
|
|
|
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
|
|
|
+ //}
|
|
|
+ log.Println("RelayMidjourneyTaskImageSeed")
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
|
|
userId := c.GetInt("id")
|
|
|
var err error
|
|
|
@@ -259,11 +284,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
mjId = params.TaskId
|
|
|
midjRequest.Action = params.Action
|
|
|
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
|
|
- if midjRequest.MaskBase64 == "" {
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
|
|
- }
|
|
|
+ //if midjRequest.MaskBase64 == "" {
|
|
|
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
|
|
+ //}
|
|
|
mjId = midjRequest.TaskId
|
|
|
- midjRequest.Action = constant.MjActionInPaint
|
|
|
+ midjRequest.Action = constant.MjActionModal
|
|
|
}
|
|
|
|
|
|
originTask := model.GetByMJId(userId, mjId)
|
|
|
@@ -293,29 +318,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
//}
|
|
|
}
|
|
|
|
|
|
- if midjRequest.Action == constant.MjActionInPaintPre {
|
|
|
+ if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
|
|
|
consumeQuota = false
|
|
|
}
|
|
|
|
|
|
- // map model name
|
|
|
- //modelMapping := c.GetString("model_mapping")
|
|
|
- //isModelMapped := false
|
|
|
- //if modelMapping != "" {
|
|
|
- // modelMap := make(map[string]string)
|
|
|
- // err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
|
- // if err != nil {
|
|
|
- // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
|
- // return &dto.MidjourneyResponse{
|
|
|
- // Code: 4,
|
|
|
- // Description: "unmarshal_model_mapping_failed",
|
|
|
- // }
|
|
|
- // }
|
|
|
- // if modelMap[imageModel] != "" {
|
|
|
- // imageModel = modelMap[imageModel]
|
|
|
- // isModelMapped = true
|
|
|
- // }
|
|
|
- //}
|
|
|
-
|
|
|
//baseURL := common.ChannelBaseURLs[channelType]
|
|
|
requestURL := c.Request.URL.String()
|
|
|
|
|
|
@@ -325,20 +331,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
|
|
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
|
|
|
|
- var requestBody io.Reader
|
|
|
- //if isModelMapped {
|
|
|
- // jsonStr, err := json.Marshal(midjRequest)
|
|
|
- // if err != nil {
|
|
|
- // return &dto.MidjourneyResponse{
|
|
|
- // Code: 4,
|
|
|
- // Description: "marshal_text_request_failed",
|
|
|
- // }
|
|
|
- // }
|
|
|
- // requestBody = bytes.NewBuffer(jsonStr)
|
|
|
- //} else {
|
|
|
- //}
|
|
|
- requestBody = c.Request.Body
|
|
|
-
|
|
|
modelName := service.CoverActionToModelName(midjRequest.Action)
|
|
|
modelPrice := common.GetModelPrice(modelName, true)
|
|
|
// 如果没有配置价格,则使用默认价格
|
|
|
@@ -368,40 +360,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
|
- if err != nil {
|
|
|
- return &dto.MidjourneyResponse{
|
|
|
- Code: 4,
|
|
|
- Description: "create_request_failed",
|
|
|
- }
|
|
|
- }
|
|
|
- //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
|
|
|
- timeout := time.Second * 30
|
|
|
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
|
- // 使用带有超时的 context 创建新的请求
|
|
|
- req = req.WithContext(ctx)
|
|
|
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
|
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
|
- req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
|
|
|
- // print request header
|
|
|
- //log.Printf("request header: %s", req.Header)
|
|
|
- //log.Printf("request body: %s", midjRequest.Prompt)
|
|
|
-
|
|
|
- defer cancel()
|
|
|
- resp, err := service.GetHttpClient().Do(req)
|
|
|
- if err != nil {
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed")
|
|
|
- }
|
|
|
-
|
|
|
- err = req.Body.Close()
|
|
|
- if err != nil {
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
|
|
|
- }
|
|
|
- err = c.Request.Body.Close()
|
|
|
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
|
|
|
if err != nil {
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
|
|
|
+ return &midjResponseWithStatus.Response
|
|
|
}
|
|
|
- var midjResponse dto.MidjourneyResponse
|
|
|
+ midjResponse := &midjResponseWithStatus.Response
|
|
|
|
|
|
defer func(ctx context.Context) {
|
|
|
if consumeQuota {
|
|
|
@@ -424,25 +387,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
|
}(c.Request.Context())
|
|
|
|
|
|
- responseBody, err := io.ReadAll(resp.Body)
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed")
|
|
|
- }
|
|
|
- err = resp.Body.Close()
|
|
|
- if err != nil {
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed")
|
|
|
- }
|
|
|
- if resp.StatusCode != 200 {
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status")
|
|
|
- }
|
|
|
- err = json.Unmarshal(responseBody, &midjResponse)
|
|
|
- log.Printf("responseBody: %s", string(responseBody))
|
|
|
- log.Printf("midjResponse: %v", midjResponse)
|
|
|
- if err != nil {
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed")
|
|
|
- }
|
|
|
-
|
|
|
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
|
|
//1-提交成功
|
|
|
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
|
|
@@ -494,7 +438,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
|
}
|
|
|
//修改返回值
|
|
|
- if midjRequest.Action != constant.MjActionInPaintPre {
|
|
|
+ if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
|
|
|
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
|
|
responseBody = []byte(newBody)
|
|
|
}
|
|
|
@@ -514,21 +458,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
responseBody = []byte(newBody)
|
|
|
}
|
|
|
|
|
|
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
|
+ //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
|
+ bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
|
|
|
|
|
- for k, v := range resp.Header {
|
|
|
- c.Writer.Header().Set(k, v[0])
|
|
|
- }
|
|
|
- c.Writer.WriteHeader(resp.StatusCode)
|
|
|
+ //for k, v := range resp.Header {
|
|
|
+ // c.Writer.Header().Set(k, v[0])
|
|
|
+ //}
|
|
|
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
|
|
|
|
|
- _, err = io.Copy(c.Writer, resp.Body)
|
|
|
+ _, err = io.Copy(c.Writer, bodyReader)
|
|
|
if err != nil {
|
|
|
return &dto.MidjourneyResponse{
|
|
|
Code: 4,
|
|
|
Description: "copy_response_body_failed",
|
|
|
}
|
|
|
}
|
|
|
- err = resp.Body.Close()
|
|
|
+ err = bodyReader.Close()
|
|
|
if err != nil {
|
|
|
return &dto.MidjourneyResponse{
|
|
|
Code: 4,
|