Explorar o código

Merge remote-tracking branch 'origin/main'

CaIon hai 1 ano
pai
achega
d80a7d3c97
Modificáronse 4 ficheiros con 35 adicións e 15 borrados
  1. 1 1
      middleware/distributor.go
  2. 10 10
      relay/constant/relay_mode.go
  3. 15 3
      relay/relay-mj.go
  4. 9 1
      router/relay-router.go

+ 1 - 1
middleware/distributor.go

@@ -44,7 +44,7 @@ func Distribute() func(c *gin.Context) {
 			// Select a channel for the user
 			// Select a channel for the user
 			var modelRequest ModelRequest
 			var modelRequest ModelRequest
 			var err error
 			var err error
-			if strings.HasPrefix(c.Request.URL.Path, "/mj") {
+			if strings.Contains(c.Request.URL.Path, "/mj/") {
 				relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
 				relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
 				if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
 				if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
 					relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
 					relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||

+ 10 - 10
relay/constant/relay_mode.go

@@ -56,29 +56,29 @@ func Path2RelayMode(path string) int {
 
 
 func Path2RelayModeMidjourney(path string) int {
 func Path2RelayModeMidjourney(path string) int {
 	relayMode := RelayModeUnknown
 	relayMode := RelayModeUnknown
-	if strings.HasPrefix(path, "/mj/submit/action") {
+	if strings.HasSuffix(path, "/mj/submit/action") {
 		// midjourney plus
 		// midjourney plus
 		relayMode = RelayModeMidjourneyAction
 		relayMode = RelayModeMidjourneyAction
-	} else if strings.HasPrefix(path, "/mj/submit/modal") {
+	} else if strings.HasSuffix(path, "/mj/submit/modal") {
 		// midjourney plus
 		// midjourney plus
 		relayMode = RelayModeMidjourneyModal
 		relayMode = RelayModeMidjourneyModal
-	} else if strings.HasPrefix(path, "/mj/submit/shorten") {
+	} else if strings.HasSuffix(path, "/mj/submit/shorten") {
 		// midjourney plus
 		// midjourney plus
 		relayMode = RelayModeMidjourneyShorten
 		relayMode = RelayModeMidjourneyShorten
-	} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
+	} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
 		// midjourney plus
 		// midjourney plus
 		relayMode = RelayModeSwapFace
 		relayMode = RelayModeSwapFace
-	} else if strings.HasPrefix(path, "/mj/submit/imagine") {
+	} else if strings.HasSuffix(path, "/mj/submit/imagine") {
 		relayMode = RelayModeMidjourneyImagine
 		relayMode = RelayModeMidjourneyImagine
-	} else if strings.HasPrefix(path, "/mj/submit/blend") {
+	} else if strings.HasSuffix(path, "/mj/submit/blend") {
 		relayMode = RelayModeMidjourneyBlend
 		relayMode = RelayModeMidjourneyBlend
-	} else if strings.HasPrefix(path, "/mj/submit/describe") {
+	} else if strings.HasSuffix(path, "/mj/submit/describe") {
 		relayMode = RelayModeMidjourneyDescribe
 		relayMode = RelayModeMidjourneyDescribe
-	} else if strings.HasPrefix(path, "/mj/notify") {
+	} else if strings.HasSuffix(path, "/mj/notify") {
 		relayMode = RelayModeMidjourneyNotify
 		relayMode = RelayModeMidjourneyNotify
-	} else if strings.HasPrefix(path, "/mj/submit/change") {
+	} else if strings.HasSuffix(path, "/mj/submit/change") {
 		relayMode = RelayModeMidjourneyChange
 		relayMode = RelayModeMidjourneyChange
-	} else if strings.HasPrefix(path, "/mj/submit/simple-change") {
+	} else if strings.HasSuffix(path, "/mj/submit/simple-change") {
 		relayMode = RelayModeMidjourneyChange
 		relayMode = RelayModeMidjourneyChange
 	} else if strings.HasSuffix(path, "/fetch") {
 	} else if strings.HasSuffix(path, "/fetch") {
 		relayMode = RelayModeMidjourneyTaskFetch
 		relayMode = RelayModeMidjourneyTaskFetch

+ 15 - 3
relay/relay-mj.go

@@ -180,7 +180,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 			Description: "quota_not_enough",
 			Description: "quota_not_enough",
 		}
 		}
 	}
 	}
-	requestURL := c.Request.URL.String()
+	requestURL := getMjRequestPath(c.Request.URL.String())
 	baseURL := c.GetString("base_url")
 	baseURL := c.GetString("base_url")
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 	mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
 	mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
@@ -260,7 +260,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
 	c.Set("channel_id", originTask.ChannelId)
 	c.Set("channel_id", originTask.ChannelId)
 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 
 
-	requestURL := c.Request.URL.String()
+	requestURL := getMjRequestPath(c.Request.URL.String())
 	fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
 	fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
 	midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
 	midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
 	if err != nil {
 	if err != nil {
@@ -440,7 +440,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 	}
 	}
 
 
 	//baseURL := common.ChannelBaseURLs[channelType]
 	//baseURL := common.ChannelBaseURLs[channelType]
-	requestURL := c.Request.URL.String()
+	requestURL := getMjRequestPath(c.Request.URL.String())
 
 
 	baseURL := c.GetString("base_url")
 	baseURL := c.GetString("base_url")
 
 
@@ -605,3 +605,15 @@ type taskChangeParams struct {
 	Action string
 	Action string
 	Index  int
 	Index  int
 }
 }
+
+func getMjRequestPath(path string) string {
+	requestURL := path
+	if strings.Contains(requestURL, "/mj-") {
+		urls := strings.Split(requestURL, "/mj/")
+		if len(urls) < 2 {
+			return requestURL
+		}
+		requestURL = "/mj/" + urls[1]
+	}
+	return requestURL
+}

+ 9 - 1
router/relay-router.go

@@ -43,7 +43,16 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
 		relayV1Router.POST("/moderations", controller.Relay)
 		relayV1Router.POST("/moderations", controller.Relay)
 	}
 	}
+
 	relayMjRouter := router.Group("/mj")
 	relayMjRouter := router.Group("/mj")
+	registerMjRouterGroup(relayMjRouter)
+
+	relayMjModeRouter := router.Group("/:mode/mj")
+	registerMjRouterGroup(relayMjModeRouter)
+	//relayMjRouter.Use()
+}
+
+func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
 	relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
 	relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
 	relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 	relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
 	{
@@ -61,5 +70,4 @@ func SetRelayRouter(router *gin.Engine) {
 		relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
 		relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
 		relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
 		relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
 	}
 	}
-	//relayMjRouter.Use()
 }
 }