|
@@ -21,23 +21,6 @@ import (
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-var DefaultModelPrice = map[string]float64{
|
|
|
|
|
- "mj_imagine": 0.1,
|
|
|
|
|
- "mj_variation": 0.1,
|
|
|
|
|
- "mj_reroll": 0.1,
|
|
|
|
|
- "mj_blend": 0.1,
|
|
|
|
|
- "mj_inpaint": 0.1,
|
|
|
|
|
- "mj_zoom": 0.1,
|
|
|
|
|
- "mj_shorten": 0.1,
|
|
|
|
|
- "mj_high_variation": 0.1,
|
|
|
|
|
- "mj_low_variation": 0.1,
|
|
|
|
|
- "mj_pan": 0.1,
|
|
|
|
|
- "mj_inpaint_pre": 0,
|
|
|
|
|
- "mj_describe": 0.05,
|
|
|
|
|
- "mj_upscale": 0.05,
|
|
|
|
|
- "swap_face": 0.05,
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
func RelayMidjourneyImage(c *gin.Context) {
|
|
func RelayMidjourneyImage(c *gin.Context) {
|
|
|
taskId := c.Param("id")
|
|
taskId := c.Param("id")
|
|
|
midjourneyTask := model.GetByOnlyMJId(taskId)
|
|
midjourneyTask := model.GetByOnlyMJId(taskId)
|
|
@@ -221,10 +204,9 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
|
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
|
|
- imageModel := "midjourney"
|
|
|
|
|
|
|
|
|
|
tokenId := c.GetInt("token_id")
|
|
tokenId := c.GetInt("token_id")
|
|
|
- channelType := c.GetInt("channel")
|
|
|
|
|
|
|
+ //channelType := c.GetInt("channel")
|
|
|
userId := c.GetInt("id")
|
|
userId := c.GetInt("id")
|
|
|
group := c.GetString("group")
|
|
group := c.GetString("group")
|
|
|
channelId := c.GetInt("channel_id")
|
|
channelId := c.GetInt("channel_id")
|
|
@@ -236,7 +218,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
|
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
|
|
- mjErr := coverPlusActionToNormalAction(&midjRequest)
|
|
|
|
|
|
|
+ mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
|
|
if mjErr != nil {
|
|
if mjErr != nil {
|
|
|
return mjErr
|
|
return mjErr
|
|
|
}
|
|
}
|
|
@@ -270,11 +252,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
if midjRequest.Content == "" {
|
|
if midjRequest.Content == "" {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
|
|
}
|
|
}
|
|
|
- params := convertSimpleChangeParams(midjRequest.Content)
|
|
|
|
|
|
|
+ params := service.ConvertSimpleChangeParams(midjRequest.Content)
|
|
|
if params == nil {
|
|
if params == nil {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
|
|
|
}
|
|
}
|
|
|
- mjId = params.ID
|
|
|
|
|
|
|
+ mjId = params.TaskId
|
|
|
midjRequest.Action = params.Action
|
|
midjRequest.Action = params.Action
|
|
|
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
|
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
|
|
if midjRequest.MaskBase64 == "" {
|
|
if midjRequest.MaskBase64 == "" {
|
|
@@ -294,18 +276,21 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
|
|
}
|
|
}
|
|
|
|
|
+ if channel.Status != common.ChannelStatusEnabled {
|
|
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
|
|
|
|
+ }
|
|
|
c.Set("base_url", channel.GetBaseURL())
|
|
c.Set("base_url", channel.GetBaseURL())
|
|
|
c.Set("channel_id", originTask.ChannelId)
|
|
c.Set("channel_id", originTask.ChannelId)
|
|
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
|
|
}
|
|
}
|
|
|
midjRequest.Prompt = originTask.Prompt
|
|
midjRequest.Prompt = originTask.Prompt
|
|
|
|
|
|
|
|
- if channelType == common.ChannelTypeMidjourneyPlus {
|
|
|
|
|
- // plus
|
|
|
|
|
- } else {
|
|
|
|
|
- // 普通版渠道
|
|
|
|
|
-
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ //if channelType == common.ChannelTypeMidjourneyPlus {
|
|
|
|
|
+ // // plus
|
|
|
|
|
+ //} else {
|
|
|
|
|
+ // // 普通版渠道
|
|
|
|
|
+ //
|
|
|
|
|
+ //}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if midjRequest.Action == constant.MjActionInPaintPre {
|
|
if midjRequest.Action == constant.MjActionInPaintPre {
|
|
@@ -313,54 +298,52 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// map model name
|
|
// map model name
|
|
|
- modelMapping := c.GetString("model_mapping")
|
|
|
|
|
- isModelMapped := false
|
|
|
|
|
- if modelMapping != "" {
|
|
|
|
|
- modelMap := make(map[string]string)
|
|
|
|
|
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
|
|
|
- return &dto.MidjourneyResponse{
|
|
|
|
|
- Code: 4,
|
|
|
|
|
- Description: "unmarshal_model_mapping_failed",
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- if modelMap[imageModel] != "" {
|
|
|
|
|
- imageModel = modelMap[imageModel]
|
|
|
|
|
- isModelMapped = true
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- baseURL := common.ChannelBaseURLs[channelType]
|
|
|
|
|
|
|
+ //modelMapping := c.GetString("model_mapping")
|
|
|
|
|
+ //isModelMapped := false
|
|
|
|
|
+ //if modelMapping != "" {
|
|
|
|
|
+ // modelMap := make(map[string]string)
|
|
|
|
|
+ // err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
|
|
|
+ // if err != nil {
|
|
|
|
|
+ // //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
|
|
|
+ // return &dto.MidjourneyResponse{
|
|
|
|
|
+ // Code: 4,
|
|
|
|
|
+ // Description: "unmarshal_model_mapping_failed",
|
|
|
|
|
+ // }
|
|
|
|
|
+ // }
|
|
|
|
|
+ // if modelMap[imageModel] != "" {
|
|
|
|
|
+ // imageModel = modelMap[imageModel]
|
|
|
|
|
+ // isModelMapped = true
|
|
|
|
|
+ // }
|
|
|
|
|
+ //}
|
|
|
|
|
+
|
|
|
|
|
+ //baseURL := common.ChannelBaseURLs[channelType]
|
|
|
requestURL := c.Request.URL.String()
|
|
requestURL := c.Request.URL.String()
|
|
|
|
|
|
|
|
- if c.GetString("base_url") != "" {
|
|
|
|
|
- baseURL = c.GetString("base_url")
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ baseURL := c.GetString("base_url")
|
|
|
|
|
|
|
|
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
|
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
|
|
|
|
|
|
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
|
|
|
|
|
|
var requestBody io.Reader
|
|
var requestBody io.Reader
|
|
|
- if isModelMapped {
|
|
|
|
|
- jsonStr, err := json.Marshal(midjRequest)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return &dto.MidjourneyResponse{
|
|
|
|
|
- Code: 4,
|
|
|
|
|
- Description: "marshal_text_request_failed",
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- requestBody = bytes.NewBuffer(jsonStr)
|
|
|
|
|
- } else {
|
|
|
|
|
- requestBody = c.Request.Body
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- mjAction := "mj_" + strings.ToLower(midjRequest.Action)
|
|
|
|
|
- modelPrice := common.GetModelPrice(mjAction, true)
|
|
|
|
|
|
|
+ //if isModelMapped {
|
|
|
|
|
+ // jsonStr, err := json.Marshal(midjRequest)
|
|
|
|
|
+ // if err != nil {
|
|
|
|
|
+ // return &dto.MidjourneyResponse{
|
|
|
|
|
+ // Code: 4,
|
|
|
|
|
+ // Description: "marshal_text_request_failed",
|
|
|
|
|
+ // }
|
|
|
|
|
+ // }
|
|
|
|
|
+ // requestBody = bytes.NewBuffer(jsonStr)
|
|
|
|
|
+ //} else {
|
|
|
|
|
+ //}
|
|
|
|
|
+ requestBody = c.Request.Body
|
|
|
|
|
+
|
|
|
|
|
+ modelName := service.CoverActionToModelName(midjRequest.Action)
|
|
|
|
|
+ modelPrice := common.GetModelPrice(modelName, true)
|
|
|
// 如果没有配置价格,则使用默认价格
|
|
// 如果没有配置价格,则使用默认价格
|
|
|
if modelPrice == -1 {
|
|
if modelPrice == -1 {
|
|
|
- defaultPrice, ok := DefaultModelPrice[mjAction]
|
|
|
|
|
|
|
+ defaultPrice, ok := common.DefaultModelPrice[modelName]
|
|
|
if !ok {
|
|
if !ok {
|
|
|
modelPrice = 0.1
|
|
modelPrice = 0.1
|
|
|
} else {
|
|
} else {
|
|
@@ -433,7 +416,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
if quota != 0 {
|
|
if quota != 0 {
|
|
|
tokenName := c.GetString("token_name")
|
|
tokenName := c.GetString("token_name")
|
|
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
|
|
- model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
|
|
|
|
|
|
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
|
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
|
channelId := c.GetInt("channel_id")
|
|
channelId := c.GetInt("channel_id")
|
|
|
model.UpdateChannelUsedQuota(channelId, quota)
|
|
model.UpdateChannelUsedQuota(channelId, quota)
|
|
@@ -558,85 +541,3 @@ type taskChangeParams struct {
|
|
|
Action string
|
|
Action string
|
|
|
Index int
|
|
Index int
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
-func convertSimpleChangeParams(content string) *taskChangeParams {
|
|
|
|
|
- split := strings.Split(content, " ")
|
|
|
|
|
- if len(split) != 2 {
|
|
|
|
|
- return nil
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- action := strings.ToLower(split[1])
|
|
|
|
|
- changeParams := &taskChangeParams{}
|
|
|
|
|
- changeParams.ID = split[0]
|
|
|
|
|
-
|
|
|
|
|
- if action[0] == 'u' {
|
|
|
|
|
- changeParams.Action = "UPSCALE"
|
|
|
|
|
- } else if action[0] == 'v' {
|
|
|
|
|
- changeParams.Action = "VARIATION"
|
|
|
|
|
- } else if action == "r" {
|
|
|
|
|
- changeParams.Action = "REROLL"
|
|
|
|
|
- return changeParams
|
|
|
|
|
- } else {
|
|
|
|
|
- return nil
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- index, err := strconv.Atoi(action[1:2])
|
|
|
|
|
- if err != nil || index < 1 || index > 4 {
|
|
|
|
|
- return nil
|
|
|
|
|
- }
|
|
|
|
|
- changeParams.Index = index
|
|
|
|
|
- return changeParams
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
|
|
|
|
|
- // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
|
|
|
|
- customId := midjRequest.CustomId
|
|
|
|
|
- if customId == "" {
|
|
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
|
|
|
|
|
- }
|
|
|
|
|
- splits := strings.Split(customId, "::")
|
|
|
|
|
- var action string
|
|
|
|
|
- if splits[1] == "JOB" {
|
|
|
|
|
- action = splits[2]
|
|
|
|
|
- } else {
|
|
|
|
|
- action = splits[1]
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if action == "" {
|
|
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
|
|
|
|
- }
|
|
|
|
|
- if strings.Contains(action, "upsample") {
|
|
|
|
|
- index, err := strconv.Atoi(splits[3])
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
|
|
|
|
- }
|
|
|
|
|
- midjRequest.Index = index
|
|
|
|
|
- midjRequest.Action = constant.MjActionUpscale
|
|
|
|
|
- } else if strings.Contains(action, "variation") {
|
|
|
|
|
- midjRequest.Index = 1
|
|
|
|
|
- if action == "variation" {
|
|
|
|
|
- index, err := strconv.Atoi(splits[3])
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
|
|
|
|
- }
|
|
|
|
|
- midjRequest.Index = index
|
|
|
|
|
- midjRequest.Action = constant.MjActionVariation
|
|
|
|
|
- } else if action == "low_variation" {
|
|
|
|
|
- midjRequest.Action = constant.MjActionLowVariation
|
|
|
|
|
- } else if action == "high_variation" {
|
|
|
|
|
- midjRequest.Action = constant.MjActionHighVariation
|
|
|
|
|
- }
|
|
|
|
|
- } else if strings.Contains(action, "pan") {
|
|
|
|
|
- midjRequest.Action = constant.MjActionPan
|
|
|
|
|
- midjRequest.Index = 1
|
|
|
|
|
- } else if action == "Outpaint" || action == "CustomZoom" {
|
|
|
|
|
- midjRequest.Action = constant.MjActionZoom
|
|
|
|
|
- midjRequest.Index = 1
|
|
|
|
|
- } else if action == "Inpaint" {
|
|
|
|
|
- midjRequest.Action = constant.MjActionInPaintPre
|
|
|
|
|
- midjRequest.Index = 1
|
|
|
|
|
- } else {
|
|
|
|
|
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
|
|
|
|
- }
|
|
|
|
|
- return nil
|
|
|
|
|
-}
|
|
|