|
|
@@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
+ startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
|
|
+ tokenId := c.GetInt("token_id")
|
|
|
+ userId := c.GetInt("id")
|
|
|
+ group := c.GetString("group")
|
|
|
+ channelId := c.GetInt("channel_id")
|
|
|
+ var swapFaceRequest dto.SwapFaceRequest
|
|
|
+ err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
|
|
+ if err != nil {
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
|
|
+ }
|
|
|
+ if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
|
|
+ }
|
|
|
+ modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
|
|
+ modelPrice := common.GetModelPrice(modelName, true)
|
|
|
+ // 如果没有配置价格,则使用默认价格
|
|
|
+ if modelPrice == -1 {
|
|
|
+ defaultPrice, ok := common.DefaultModelPrice[modelName]
|
|
|
+ if !ok {
|
|
|
+ modelPrice = 0.1
|
|
|
+ } else {
|
|
|
+ modelPrice = defaultPrice
|
|
|
+ }
|
|
|
+ }
|
|
|
+ groupRatio := common.GetGroupRatio(group)
|
|
|
+ ratio := modelPrice * groupRatio
|
|
|
+ userQuota, err := model.CacheGetUserQuota(userId)
|
|
|
+ if err != nil {
|
|
|
+ return &dto.MidjourneyResponse{
|
|
|
+ Code: 4,
|
|
|
+ Description: err.Error(),
|
|
|
+ }
|
|
|
+ }
|
|
|
+ quota := int(ratio * common.QuotaPerUnit)
|
|
|
+
|
|
|
+ if userQuota-quota < 0 {
|
|
|
+ return &dto.MidjourneyResponse{
|
|
|
+ Code: 4,
|
|
|
+ Description: "quota_not_enough",
|
|
|
+ }
|
|
|
+ }
|
|
|
+ requestURL := c.Request.URL.String()
|
|
|
+ baseURL := c.GetString("base_url")
|
|
|
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
|
+ mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL)
|
|
|
+ if err != nil {
|
|
|
+ return &mjResp.Response
|
|
|
+ }
|
|
|
+ defer func(ctx context.Context) {
|
|
|
+ if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
|
|
+ err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error consuming token remain quota: " + err.Error())
|
|
|
+ }
|
|
|
+ err = model.CacheUpdateUserQuota(userId)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error update user quota cache: " + err.Error())
|
|
|
+ }
|
|
|
+ if quota != 0 {
|
|
|
+ tokenName := c.GetString("token_name")
|
|
|
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
|
|
|
+ model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
|
|
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
|
+ channelId := c.GetInt("channel_id")
|
|
|
+ model.UpdateChannelUsedQuota(channelId, quota)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }(c.Request.Context())
|
|
|
+ midjResponse := &mjResp.Response
|
|
|
+ midjourneyTask := &model.Midjourney{
|
|
|
+ UserId: userId,
|
|
|
+ Code: midjResponse.Code,
|
|
|
+ Action: constant.MjActionSwapFace,
|
|
|
+ MjId: midjResponse.Result,
|
|
|
+ Prompt: "swap_face",
|
|
|
+ PromptEn: "",
|
|
|
+ Description: midjResponse.Description,
|
|
|
+ State: "",
|
|
|
+ SubmitTime: startTime,
|
|
|
+ StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
|
|
+ FinishTime: 0,
|
|
|
+ ImageUrl: "",
|
|
|
+ Status: "",
|
|
|
+ Progress: "0%",
|
|
|
+ FailReason: "",
|
|
|
+ ChannelId: c.GetInt("channel_id"),
|
|
|
+ Quota: quota,
|
|
|
+ }
|
|
|
+ err = midjourneyTask.Insert()
|
|
|
+ if err != nil {
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
|
|
|
+ }
|
|
|
+ c.Writer.WriteHeader(mjResp.StatusCode)
|
|
|
+ respBody, err := json.Marshal(midjResponse)
|
|
|
+ if err != nil {
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
|
|
+ }
|
|
|
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
|
|
+ if err != nil {
|
|
|
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
taskId := c.Param("id")
|
|
|
userId := c.GetInt("id")
|
|
|
@@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
|
|
|
|
|
requestURL := c.Request.URL.String()
|
|
|
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
|
|
- midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
|
|
|
+ midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
|
|
if err != nil {
|
|
|
return &midjResponseWithStatus.Response
|
|
|
}
|
|
|
+ //defer func(ctx context.Context) {
|
|
|
+ // err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
|
|
+ // if err != nil {
|
|
|
+ // common.SysError("error consuming token remain quota: " + err.Error())
|
|
|
+ // }
|
|
|
+ // err = model.CacheUpdateUserQuota(userId)
|
|
|
+ // if err != nil {
|
|
|
+ // common.SysError("error update user quota cache: " + err.Error())
|
|
|
+ // }
|
|
|
+ // if quota != 0 {
|
|
|
+ // tokenName := c.GetString("token_name")
|
|
|
+ // logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
|
|
+ // model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
|
|
+ // model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
|
+ // channelId := c.GetInt("channel_id")
|
|
|
+ // model.UpdateChannelUsedQuota(channelId, quota)
|
|
|
+ // }
|
|
|
+ //}(c.Request.Context())
|
|
|
midjResponse := &midjResponseWithStatus.Response
|
|
|
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
|
|
respBody, err := json.Marshal(midjResponse)
|
|
|
@@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
|
|
|
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
|
|
if err != nil {
|
|
|
return &midjResponseWithStatus.Response
|
|
|
}
|
|
|
midjResponse := &midjResponseWithStatus.Response
|
|
|
|
|
|
defer func(ctx context.Context) {
|
|
|
- if consumeQuota {
|
|
|
+ if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
|
|
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
|
|
if err != nil {
|
|
|
common.SysError("error consuming token remain quota: " + err.Error())
|