Browse Source

feat: 兼容自定义变焦,完善modal操作

CaIon 2 years ago
parent
commit
d704902b70

+ 3 - 2
common/model-ratio.go

@@ -100,13 +100,14 @@ var DefaultModelPrice = map[string]float64{
 	"mj_variation":      0.1,
 	"mj_reroll":         0.1,
 	"mj_blend":          0.1,
-	"mj_inpaint":        0.1,
+	"mj_modal":          0.1,
 	"mj_zoom":           0.1,
 	"mj_shorten":        0.1,
 	"mj_high_variation": 0.1,
 	"mj_low_variation":  0.1,
 	"mj_pan":            0.1,
-	"mj_inpaint_pre":    0,
+	"mj_inpaint":        0,
+	"mj_custom_zoom":    0,
 	"mj_describe":       0.05,
 	"mj_upscale":        0.05,
 	"swap_face":         0.05,

+ 4 - 2
constant/midjourney.go

@@ -13,8 +13,9 @@ const (
 	MjActionVariation     = "VARIATION"
 	MjActionReRoll        = "REROLL"
 	MjActionInPaint       = "INPAINT"
-	MjActionInPaintPre    = "INPAINT_PRE"
+	MjActionModal         = "MODAL"
 	MjActionZoom          = "ZOOM"
+	MjActionCustomZoom    = "CUSTOM_ZOOM"
 	MjActionShorten       = "SHORTEN"
 	MjActionHighVariation = "HIGH_VARIATION"
 	MjActionLowVariation  = "LOW_VARIATION"
@@ -29,9 +30,10 @@ var MidjourneyModel2Action = map[string]string{
 	"mj_upscale":        MjActionUpscale,
 	"mj_variation":      MjActionVariation,
 	"mj_reroll":         MjActionReRoll,
+	"mj_modal":          MjActionModal,
 	"mj_inpaint":        MjActionInPaint,
-	"mj_inpaint_pre":    MjActionInPaintPre,
 	"mj_zoom":           MjActionZoom,
+	"mj_custom_zoom":    MjActionCustomZoom,
 	"mj_shorten":        MjActionShorten,
 	"mj_high_variation": MjActionHighVariation,
 	"mj_low_variation":  MjActionLowVariation,

+ 2 - 0
controller/relay.go

@@ -67,6 +67,8 @@ func RelayMidjourney(c *gin.Context) {
 		err = relay.RelayMidjourneyNotify(c)
 	case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
 		err = relay.RelayMidjourneyTask(c, relayMode)
+	case relayconstant.RelayModeMidjourneyTaskImageSeed:
+		err = relay.RelayMidjourneyTaskImageSeed(c)
 	default:
 		err = relay.RelayMidjourneySubmit(c, relayMode)
 	}

+ 5 - 0
dto/midjourney.go

@@ -28,6 +28,11 @@ type MidjourneyResponse struct {
 	Result      string      `json:"result"`
 }
 
+type MidjourneyResponseWithStatusCode struct {
+	StatusCode int `json:"statusCode"`
+	Response   MidjourneyResponse
+}
+
 type MidjourneyDto struct {
 	MjId        string      `json:"id"`
 	Action      string      `json:"action"`

+ 2 - 1
middleware/distributor.go

@@ -48,7 +48,8 @@ func Distribute() func(c *gin.Context) {
 				relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
 				if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
 					relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
-					relayMode == relayconstant.RelayModeMidjourneyNotify {
+					relayMode == relayconstant.RelayModeMidjourneyNotify ||
+					relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
 					shouldSelectChannel = false
 				} else {
 					midjourneyRequest := dto.MidjourneyRequest{}

+ 3 - 0
relay/constant/relay_mode.go

@@ -17,6 +17,7 @@ const (
 	RelayModeMidjourneySimpleChange
 	RelayModeMidjourneyNotify
 	RelayModeMidjourneyTaskFetch
+	RelayModeMidjourneyTaskImageSeed
 	RelayModeMidjourneyTaskFetchByCondition
 	RelayModeAudioSpeech
 	RelayModeAudioTranscription
@@ -77,6 +78,8 @@ func Path2RelayModeMidjourney(path string) int {
 		relayMode = RelayModeMidjourneyChange
 	} else if strings.HasSuffix(path, "/fetch") {
 		relayMode = RelayModeMidjourneyTaskFetch
+	} else if strings.HasSuffix(path, "/image-seed") {
+		relayMode = RelayModeMidjourneyTaskImageSeed
 	} else if strings.HasSuffix(path, "/list-by-condition") {
 		relayMode = RelayModeMidjourneyTaskFetchByCondition
 	}

+ 42 - 97
relay/relay-mj.go

@@ -138,6 +138,31 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
 	return
 }
 
+func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
+	//taskId := c.Param("id")
+	//userId := c.GetInt("id")
+	//originTask := model.GetByMJId(userId, taskId)
+	//if originTask == nil {
+	//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+	//}
+	//channel, err := model.GetChannelById(originTask.ChannelId, false)
+	//if err != nil {
+	//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+	//}
+	//if channel.Status != common.ChannelStatusEnabled {
+	//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+	//}
+	//c.Set("channel_id", originTask.ChannelId)
+	//requestURL := c.Request.URL.String()
+	//fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
+	//req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
+	//if err != nil {
+	//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "create_request_failed")
+	//}
+	log.Println("RelayMidjourneyTaskImageSeed")
+	return nil
+}
+
 func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 	userId := c.GetInt("id")
 	var err error
@@ -259,11 +284,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			mjId = params.TaskId
 			midjRequest.Action = params.Action
 		} else if relayMode == relayconstant.RelayModeMidjourneyModal {
-			if midjRequest.MaskBase64 == "" {
-				return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
-			}
+			//if midjRequest.MaskBase64 == "" {
+			//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
+			//}
 			mjId = midjRequest.TaskId
-			midjRequest.Action = constant.MjActionInPaint
+			midjRequest.Action = constant.MjActionModal
 		}
 
 		originTask := model.GetByMJId(userId, mjId)
@@ -293,29 +318,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		//}
 	}
 
-	if midjRequest.Action == constant.MjActionInPaintPre {
+	if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
 		consumeQuota = false
 	}
 
-	// map model name
-	//modelMapping := c.GetString("model_mapping")
-	//isModelMapped := false
-	//if modelMapping != "" {
-	//	modelMap := make(map[string]string)
-	//	err := json.Unmarshal([]byte(modelMapping), &modelMap)
-	//	if err != nil {
-	//		//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-	//		return &dto.MidjourneyResponse{
-	//			Code:        4,
-	//			Description: "unmarshal_model_mapping_failed",
-	//		}
-	//	}
-	//	if modelMap[imageModel] != "" {
-	//		imageModel = modelMap[imageModel]
-	//		isModelMapped = true
-	//	}
-	//}
-
 	//baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 
@@ -325,20 +331,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 
-	var requestBody io.Reader
-	//if isModelMapped {
-	//	jsonStr, err := json.Marshal(midjRequest)
-	//	if err != nil {
-	//		return &dto.MidjourneyResponse{
-	//			Code:        4,
-	//			Description: "marshal_text_request_failed",
-	//		}
-	//	}
-	//	requestBody = bytes.NewBuffer(jsonStr)
-	//} else {
-	//}
-	requestBody = c.Request.Body
-
 	modelName := service.CoverActionToModelName(midjRequest.Action)
 	modelPrice := common.GetModelPrice(modelName, true)
 	// 如果没有配置价格,则使用默认价格
@@ -368,40 +360,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		}
 	}
 
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
-	if err != nil {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "create_request_failed",
-		}
-	}
-	//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
-	timeout := time.Second * 30
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	// 使用带有超时的 context 创建新的请求
-	req = req.WithContext(ctx)
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
-	// print request header
-	//log.Printf("request header: %s", req.Header)
-	//log.Printf("request body: %s", midjRequest.Prompt)
-
-	defer cancel()
-	resp, err := service.GetHttpClient().Do(req)
-	if err != nil {
-		return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "do_request_failed")
-	}
-
-	err = req.Body.Close()
-	if err != nil {
-		return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
-	}
-	err = c.Request.Body.Close()
+	midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL, &midjRequest)
 	if err != nil {
-		return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_request_body_failed")
+		return &midjResponseWithStatus.Response
 	}
-	var midjResponse dto.MidjourneyResponse
+	midjResponse := &midjResponseWithStatus.Response
 
 	defer func(ctx context.Context) {
 		if consumeQuota {
@@ -424,25 +387,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		}
 	}(c.Request.Context())
 
-	responseBody, err := io.ReadAll(resp.Body)
-
-	if err != nil {
-		return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "read_response_body_failed")
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "close_response_body_failed")
-	}
-	if resp.StatusCode != 200 {
-		return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unexpected_response_status")
-	}
-	err = json.Unmarshal(responseBody, &midjResponse)
-	log.Printf("responseBody: %s", string(responseBody))
-	log.Printf("midjResponse: %v", midjResponse)
-	if err != nil {
-		return service.MidjourneyErrorWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed")
-	}
-
 	// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
 	//1-提交成功
 	// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
@@ -494,7 +438,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			}
 		}
 		//修改返回值
-		if midjRequest.Action != constant.MjActionInPaintPre {
+		if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
 			newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
 			responseBody = []byte(newBody)
 		}
@@ -514,21 +458,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		responseBody = []byte(newBody)
 	}
 
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
 
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.WriteHeader(resp.StatusCode)
+	//for k, v := range resp.Header {
+	//	c.Writer.Header().Set(k, v[0])
+	//}
+	c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
 
-	_, err = io.Copy(c.Writer, resp.Body)
+	_, err = io.Copy(c.Writer, bodyReader)
 	if err != nil {
 		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "copy_response_body_failed",
 		}
 	}
-	err = resp.Body.Close()
+	err = bodyReader.Close()
 	if err != nil {
 		return &dto.MidjourneyResponse{
 			Code:        4,

+ 1 - 0
router/relay-router.go

@@ -57,6 +57,7 @@ func SetRelayRouter(router *gin.Engine) {
 		relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
 		relayMjRouter.POST("/notify", controller.RelayMidjourney)
 		relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
+		relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
 		relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
 	}
 	//relayMjRouter.Use()

+ 7 - 0
service/error.go

@@ -18,6 +18,13 @@ func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
 	}
 }
 
+func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
+	return &dto.MidjourneyResponseWithStatusCode{
+		StatusCode: statusCode,
+		Response:   *MidjourneyErrorWrapper(code, desc),
+	}
+}
+
 // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
 func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
 	text := err.Error()

+ 70 - 3
service/midjourney.go

@@ -1,11 +1,18 @@
 package service
 
 import (
+	"context"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"log"
+	"net/http"
 	"one-api/constant"
 	"one-api/dto"
 	relayconstant "one-api/relay/constant"
 	"strconv"
 	"strings"
+	"time"
 )
 
 func CoverActionToModelName(mjAction string) string {
@@ -35,7 +42,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
 		case relayconstant.RelayModeMidjourneyChange:
 			action = midjRequest.Action
 		case relayconstant.RelayModeMidjourneyModal:
-			action = constant.MjActionInPaint
+			action = constant.MjActionModal
 		case relayconstant.RelayModeMidjourneySimpleChange:
 			params := ConvertSimpleChangeParams(midjRequest.Content)
 			if params == nil {
@@ -96,11 +103,14 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
 	} else if strings.Contains(action, "reroll") {
 		midjRequest.Action = constant.MjActionReRoll
 		midjRequest.Index = 1
-	} else if action == "Outpaint" || action == "CustomZoom" {
+	} else if action == "Outpaint" {
 		midjRequest.Action = constant.MjActionZoom
 		midjRequest.Index = 1
+	} else if action == "CustomZoom" {
+		midjRequest.Action = constant.MjActionCustomZoom
+		midjRequest.Index = 1
 	} else if action == "Inpaint" {
-		midjRequest.Action = constant.MjActionInPaintPre
+		midjRequest.Action = constant.MjActionInPaint
 		midjRequest.Index = 1
 	} else {
 		return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
@@ -136,3 +146,60 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
 	changeParams.Index = index
 	return changeParams
 }
+
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+	var nullBytes []byte
+	var requestBody io.Reader
+	requestBody = c.Request.Body
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	// 使用带有超时的 context 创建新的请求
+	req = req.WithContext(ctx)
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+	req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
+	defer cancel()
+	resp, err := GetHttpClient().Do(req)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
+	}
+	statusCode := resp.StatusCode
+	//if statusCode != 200 {
+	//	return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
+	//}
+	err = req.Body.Close()
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+	}
+	var midjResponse dto.MidjourneyResponse
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
+	}
+
+	err = json.Unmarshal(responseBody, &midjResponse)
+	log.Printf("responseBody: %s", string(responseBody))
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
+	}
+	//log.Printf("midjResponse: %v", midjResponse)
+	//for k, v := range resp.Header {
+	//	c.Writer.Header().Set(k, v[0])
+	//}
+	return &dto.MidjourneyResponseWithStatusCode{
+		StatusCode: statusCode,
+		Response:   midjResponse,
+	}, responseBody, nil
+}

+ 6 - 4
web/src/components/MjLogsTable.js

@@ -46,11 +46,13 @@ function renderType(type) {
         case 'REROLL':
             return <Tag color="indigo" size='large'>重绘</Tag>;
         case 'INPAINT':
-            return <Tag color="violet" size='large'>局部重绘</Tag>;
+            return <Tag color="violet" size='large'>局部重绘-提交</Tag>;
         case 'ZOOM':
             return <Tag color="teal" size='large'>变焦</Tag>;
-        case 'INPAINT_PRE':
-            return <Tag color="violet" size='large'>局部重绘-预处理</Tag>;
+        case 'CUSTOM_ZOOM':
+            return <Tag color="teal" size='large'>自定义变焦-提交</Tag>;
+        case 'MODAL':
+            return <Tag color="green" size='large'>窗口处理</Tag>;
         default:
             return <Tag color="white" size='large'>未知</Tag>;
     }
@@ -62,7 +64,7 @@ function renderCode(code) {
         case 1:
             return <Tag color="green" size='large'>已提交</Tag>;
         case 21:
-            return <Tag color="lime" size='large'>排队中</Tag>;
+            return <Tag color="lime" size='large'>等待中</Tag>;
         case 22:
             return <Tag color="orange" size='large'>重复提交</Tag>;
         default:

+ 2 - 1
web/src/pages/Channel/EditChannel.js

@@ -109,8 +109,9 @@ const EditChannel = (props) => {
                         'mj_describe',
                         'mj_zoom',
                         'mj_shorten',
-                        'mj_inpaint_pre',
+                        'mj_modal',
                         'mj_inpaint',
+                        'mj_custom_zoom',
                         'mj_high_variation',
                         'mj_low_variation',
                         'mj_pan',