Browse Source

feat: support InsightFace (close #60)

CaIon 2 years ago
parent
commit
9b5353a81a

+ 6 - 4
Midjourney.md

@@ -19,8 +19,9 @@
 
 - mj_zoom (比例变焦)
 - mj_shorten (提示词缩短)
-- mj_inpaint_pre (发起局部重绘,必须和mj_inpaint一同添加)
-- mj_inpaint (局部重绘提交,必须和mj_inpaint_pre一同添加)
+- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
+- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
+- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
 - mj_high_variation (强变换)
 - mj_low_variation (弱变换)
 - mj_pan (平移)
@@ -32,13 +33,14 @@
   "mj_variation": 0.1,
   "mj_reroll": 0.1,
   "mj_blend": 0.1,
-  "mj_inpaint": 0.1,
+  "mj_modal": 0.1,
   "mj_zoom": 0.1,
   "mj_shorten": 0.1,
   "mj_high_variation": 0.1,
   "mj_low_variation": 0.1,
   "mj_pan": 0.1,
-  "mj_inpaint_pre": 0,
+  "mj_inpaint": 0,
+  "mj_custom_zoom": 0,
   "mj_describe": 0.05,
   "mj_upscale": 0.05,
   "swap_face": 0.05

+ 2 - 2
constant/midjourney.go

@@ -20,7 +20,7 @@ const (
 	MjActionHighVariation = "HIGH_VARIATION"
 	MjActionLowVariation  = "LOW_VARIATION"
 	MjActionPan           = "PAN"
-	SwapFace              = "SWAP_FACE"
+	MjActionSwapFace      = "SWAP_FACE"
 )
 
 var MidjourneyModel2Action = map[string]string{
@@ -38,5 +38,5 @@ var MidjourneyModel2Action = map[string]string{
 	"mj_high_variation": MjActionHighVariation,
 	"mj_low_variation":  MjActionLowVariation,
 	"mj_pan":            MjActionPan,
-	"swap_face":         SwapFace,
+	"swap_face":         MjActionSwapFace,
 }

+ 2 - 0
controller/relay.go

@@ -69,6 +69,8 @@ func RelayMidjourney(c *gin.Context) {
 		err = relay.RelayMidjourneyTask(c, relayMode)
 	case relayconstant.RelayModeMidjourneyTaskImageSeed:
 		err = relay.RelayMidjourneyTaskImageSeed(c)
+	case relayconstant.RelayModeSwapFace:
+		err = relay.RelaySwapFace(c)
 	default:
 		err = relay.RelayMidjourneySubmit(c, relayMode)
 	}

+ 5 - 0
dto/midjourney.go

@@ -7,6 +7,11 @@ package dto
 //	Content  string `json:"content"`
 //}
 
+type SwapFaceRequest struct {
+	SourceBase64 string `json:"sourceBase64"`
+	TargetBase64 string `json:"targetBase64"`
+}
+
 type MidjourneyRequest struct {
 	Prompt      string   `json:"prompt"`
 	CustomId    string   `json:"customId"`

+ 4 - 0
relay/constant/relay_mode.go

@@ -25,6 +25,7 @@ const (
 	RelayModeMidjourneyAction
 	RelayModeMidjourneyModal
 	RelayModeMidjourneyShorten
+	RelayModeSwapFace
 )
 
 func Path2RelayMode(path string) int {
@@ -64,6 +65,9 @@ func Path2RelayModeMidjourney(path string) int {
 	} else if strings.HasPrefix(path, "/mj/submit/shorten") {
 		// midjourney plus
 		relayMode = RelayModeMidjourneyShorten
+	} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
+		// midjourney plus
+		relayMode = RelayModeSwapFace
 	} else if strings.HasPrefix(path, "/mj/submit/imagine") {
 		relayMode = RelayModeMidjourneyImagine
 	} else if strings.HasPrefix(path, "/mj/submit/blend") {

+ 126 - 3
relay/relay-mj.go

@@ -138,6 +138,111 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
 	return
 }
 
+func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
+	startTime := time.Now().UnixNano() / int64(time.Millisecond)
+	tokenId := c.GetInt("token_id")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+	channelId := c.GetInt("channel_id")
+	var swapFaceRequest dto.SwapFaceRequest
+	err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+	}
+	if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
+	}
+	modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
+	modelPrice := common.GetModelPrice(modelName, true)
+	// 如果没有配置价格,则使用默认价格
+	if modelPrice == -1 {
+		defaultPrice, ok := common.DefaultModelPrice[modelName]
+		if !ok {
+			modelPrice = 0.1
+		} else {
+			modelPrice = defaultPrice
+		}
+	}
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelPrice * groupRatio
+	userQuota, err := model.CacheGetUserQuota(userId)
+	if err != nil {
+		return &dto.MidjourneyResponse{
+			Code:        4,
+			Description: err.Error(),
+		}
+	}
+	quota := int(ratio * common.QuotaPerUnit)
+
+	if userQuota-quota < 0 {
+		return &dto.MidjourneyResponse{
+			Code:        4,
+			Description: "quota_not_enough",
+		}
+	}
+	requestURL := c.Request.URL.String()
+	baseURL := c.GetString("base_url")
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*120, fullRequestURL)
+	if err != nil {
+		return &mjResp.Response
+	}
+	defer func(ctx context.Context) {
+		if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
+			err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+			if err != nil {
+				common.SysError("error consuming token remain quota: " + err.Error())
+			}
+			err = model.CacheUpdateUserQuota(userId)
+			if err != nil {
+				common.SysError("error update user quota cache: " + err.Error())
+			}
+			if quota != 0 {
+				tokenName := c.GetString("token_name")
+				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
+				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+				channelId := c.GetInt("channel_id")
+				model.UpdateChannelUsedQuota(channelId, quota)
+			}
+		}
+	}(c.Request.Context())
+	midjResponse := &mjResp.Response
+	midjourneyTask := &model.Midjourney{
+		UserId:      userId,
+		Code:        midjResponse.Code,
+		Action:      constant.MjActionSwapFace,
+		MjId:        midjResponse.Result,
+		Prompt:      "swap_face",
+		PromptEn:    "",
+		Description: midjResponse.Description,
+		State:       "",
+		SubmitTime:  startTime,
+		StartTime:   time.Now().UnixNano() / int64(time.Millisecond),
+		FinishTime:  0,
+		ImageUrl:    "",
+		Status:      "",
+		Progress:    "0%",
+		FailReason:  "",
+		ChannelId:   c.GetInt("channel_id"),
+		Quota:       quota,
+	}
+	err = midjourneyTask.Insert()
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
+	}
+	c.Writer.WriteHeader(mjResp.StatusCode)
+	respBody, err := json.Marshal(midjResponse)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+	}
+	_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+	}
+	return nil
+}
+
 func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
 	taskId := c.Param("id")
 	userId := c.GetInt("id")
@@ -157,10 +262,28 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
 
 	requestURL := c.Request.URL.String()
 	fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
-	midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, nil)
+	midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
 	if err != nil {
 		return &midjResponseWithStatus.Response
 	}
+	//defer func(ctx context.Context) {
+	//	err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+	//	if err != nil {
+	//		common.SysError("error consuming token remain quota: " + err.Error())
+	//	}
+	//	err = model.CacheUpdateUserQuota(userId)
+	//	if err != nil {
+	//		common.SysError("error update user quota cache: " + err.Error())
+	//	}
+	//	if quota != 0 {
+	//		tokenName := c.GetString("token_name")
+	//		logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
+	//		model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+	//		model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+	//		channelId := c.GetInt("channel_id")
+	//		model.UpdateChannelUsedQuota(channelId, quota)
+	//	}
+	//}(c.Request.Context())
 	midjResponse := &midjResponseWithStatus.Response
 	c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
 	respBody, err := json.Marshal(midjResponse)
@@ -372,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		}
 	}
 
-	midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
+	midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
 	if err != nil {
 		return &midjResponseWithStatus.Response
 	}
 	midjResponse := &midjResponseWithStatus.Response
 
 	defer func(ctx context.Context) {
-		if consumeQuota {
+		if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
 			err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())

+ 1 - 0
router/relay-router.go

@@ -59,6 +59,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
 		relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
 		relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
+		relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
 	}
 	//relayMjRouter.Use()
 }

+ 23 - 4
service/midjourney.go

@@ -17,6 +17,9 @@ import (
 
 func CoverActionToModelName(mjAction string) string {
 	modelName := "mj_" + strings.ToLower(mjAction)
+	if mjAction == constant.MjActionSwapFace {
+		modelName = "swap_face"
+	}
 	return modelName
 }
 
@@ -43,6 +46,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
 			action = midjRequest.Action
 		case relayconstant.RelayModeMidjourneyModal:
 			action = constant.MjActionModal
+		case relayconstant.RelayModeSwapFace:
+			action = constant.MjActionSwapFace
 		case relayconstant.RelayModeMidjourneySimpleChange:
 			params := ConvertSimpleChangeParams(midjRequest.Content)
 			if params == nil {
@@ -147,11 +152,25 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
 	return changeParams
 }
 
-func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
 	var nullBytes []byte
-	var requestBody io.Reader
-	requestBody = c.Request.Body
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	//var requestBody io.Reader
+	//requestBody = c.Request.Body
+	// read request body to json, delete accountFilter and notifyHook
+	var mapResult map[string]interface{}
+	err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
+	}
+	delete(mapResult, "accountFilter")
+	delete(mapResult, "notifyHook")
+	//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	// make new request with mapResult
+	reqBody, err := json.Marshal(mapResult)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
 	if err != nil {
 		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
 	}

+ 4 - 0
web/src/components/MjLogsTable.js

@@ -53,6 +53,8 @@ function renderType(type) {
             return <Tag color="teal" size='large'>自定义变焦-提交</Tag>;
         case 'MODAL':
             return <Tag color="green" size='large'>窗口处理</Tag>;
+        case 'SWAP_FACE':
+            return <Tag color="light-green" size='large'>换脸</Tag>;
         default:
             return <Tag color="white" size='large'>未知</Tag>;
     }
@@ -67,6 +69,8 @@ function renderCode(code) {
             return <Tag color="lime" size='large'>等待中</Tag>;
         case 22:
             return <Tag color="orange" size='large'>重复提交</Tag>;
+        case 0:
+            return <Tag color="yellow" size='large'>未提交</Tag>;
         default:
             return <Tag color="white" size='large'>未知</Tag>;
     }