Browse Source

Merge pull request #114 from Calcium-Ion/midjourney-proxy-plus

feat: support midjourney-proxy-plus
Calcium-Ion 2 years ago
parent
commit
f62dcbf669

+ 44 - 268
Midjourney.md

@@ -4,288 +4,64 @@
 
 ## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
 
+### 模型列表
+
+### midjourney-proxy支持
+
+- mj_imagine (绘图)
+- mj_variation (变换)
+- mj_reroll (重绘)
+- mj_blend (混合)
+- mj_upscale (放大)
+- mj_describe (图生文)
+
+### 仅midjourney-proxy-plus支持
+
+- mj_zoom (比例变焦)
+- mj_shorten (提示词缩短)
+- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
+- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
+- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
+- mj_high_variation (强变换)
+- mj_low_variation (弱变换)
+- mj_pan (平移)
+- swap_face (换脸)
+
 ```json
 {
-  "gpt-4-gizmo-*": 0.1,
   "mj_imagine": 0.1,
   "mj_variation": 0.1,
   "mj_reroll": 0.1,
   "mj_blend": 0.1,
+  "mj_modal": 0.1,
+  "mj_zoom": 0.1,
+  "mj_shorten": 0.1,
+  "mj_high_variation": 0.1,
+  "mj_low_variation": 0.1,
+  "mj_pan": 0.1,
+  "mj_inpaint": 0,
+  "mj_custom_zoom": 0,
   "mj_describe": 0.05,
-  "mj_upscale": 0.05
+  "mj_upscale": 0.05,
+  "swap_face": 0.05
 }
 ```
 
 ## 渠道设置
 
-### 对接 midjourney-proxy
-1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
-2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
-3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
-4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
-
-### 对接上游new api
-1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
-2. 地址填写上游new api的地址,例如:http://localhost:8080
-3. 密钥填写上游new api的密钥
-
-## 任务提交
-
-### 绘图变化
-
-**接口地址**:`/mj/submit/change`
-
-**请求方式**:`POST`
-
-**请求数据类型**:`application/json`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求示例**:
-
-```javascript
-{
-    "action"
-:
-    "UPSCALE",
-        "index"
-:
-    1,
-        "notifyHook"
-:
-    "",
-        "state"
-:
-    "",
-        "taskId"
-:
-    "1320098173412546"
-}
-```
-
-**请求参数**:
-
-| 参数名称                   | 参数说明                                                                  | 请求类型 | 是否必须  | 数据类型           | schema   |
-|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
-| changeDTO              | changeDTO                                                             | body | true  | 变化任务提交参数       | 变化任务提交参数 |
-|   action     | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL |      | true  | string         |          |
-|   index      | 序号(1~4), action为UPSCALE,VARIATION时必传                                  |      | false | integer(int32) |          |
-|   notifyHook | 回调地址, 为空时使用全局notifyHook                                               |      | false | string         |          |
-|   state      | 自定义参数                                                                 |      | false | string         |          |
-|   taskId     | 任务ID                                                                  |      | true  | string         |          |
-
-**响应状态**:
-
-| 状态码 | 说明           | schema |
-|-----|--------------|--------| 
-| 200 | OK           | 提交结果   |
-| 201 | Created      |        |
-| 401 | Unauthorized |        |
-| 403 | Forbidden    |        |
-| 404 | Not Found    |        |
-
-**响应参数**:
-
-| 参数名称        | 参数说明                                      | 类型             | schema         |
-|-------------|-------------------------------------------|----------------|----------------| 
-| code        | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
-| description | 描述                                        | string         |                |
-| properties  | 扩展字段                                      | object         |                |
-| result      | 任务ID                                      | string         |                |
-
-**响应示例**:
-
-```javascript
-{
-    "code"
-:
-    1,
-        "description"
-:
-    "提交成功",
-        "properties"
-:
-    {
-    }
-,
-    "result"
-:
-    1320098173412546
-}
-```
-
-### 提交Imagine任务
-
-**接口地址**:`/mj/submit/imagine`
-
-**请求方式**:`POST`
+### 对接 midjourney-proxy(plus)
 
-**请求数据类型**:`application/json`
+1.
 
-**响应数据类型**:`*/*`
+部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
 
-**接口描述**:
-
-**请求示例**:
-
-```javascript
-{
-    "base64"
-:
-    "",
-        "notifyHook"
-:
-    "",
-        "prompt"
-:
-    "Cat",
-        "state"
-:
-    ""
-}
-```
-
-**请求参数**:
-
-| 参数名称                   | 参数说明                    | 请求类型 | 是否必须  | 数据类型        | schema      |
-|------------------------|-------------------------|------|-------|-------------|-------------|
-| imagineDTO             | imagineDTO              | body | true  | Imagine提交参数 | Imagine提交参数 |
-|   base64     | 垫图base64                |      | false | string      |             |
-|   notifyHook | 回调地址, 为空时使用全局notifyHook |      | false | string      |             |
-|   prompt     | 提示词                     |      | true  | string      |             |
-|   state      | 自定义参数                   |      | false | string      |             |
-
-**响应状态**:
-
-| 状态码 | 说明           | schema |
-|-----|--------------|--------| 
-| 200 | OK           | 提交结果   |
-| 201 | Created      |        |
-| 401 | Unauthorized |        |
-| 403 | Forbidden    |        |
-| 404 | Not Found    |        |
-
-**响应参数**:
-
-| 参数名称        | 参数说明                                      | 类型             | schema         |
-|-------------|-------------------------------------------|----------------|----------------| 
-| code        | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
-| description | 描述                                        | string         |                |
-| properties  | 扩展字段                                      | object         |                |
-| result      | 任务ID                                      | string         |                |
-
-**响应示例**:
-
-```javascript
-{
-    "code"
-:
-    1,
-        "description"
-:
-    "提交成功",
-        "properties"
-:
-    {
-    }
-,
-    "result"
-:
-    1320098173412546
-}
-```
-
-## 任务查询
-
-### 指定ID获取任务
-
-**接口地址**:`/mj/task/{id}/fetch`
-
-**请求方式**:`GET`
-
-**请求数据类型**:`application/x-www-form-urlencoded`
-
-**响应数据类型**:`*/*`
-
-**接口描述**:
-
-**请求参数**:
-
-| 参数名称 | 参数说明 | 请求类型 | 是否必须  | 数据类型   | schema |
-|------|------|------|-------|--------|--------|
-| id   | 任务ID | path | false | string |        |
-
-**响应状态**:
-
-| 状态码 | 说明           | schema |
-|-----|--------------|--------| 
-| 200 | OK           | 任务     |
-| 401 | Unauthorized |        |
-| 403 | Forbidden    |        |
-| 404 | Not Found    |        |
-
-**响应参数**:
-
-| 参数名称        | 参数说明                                                     | 类型             | schema         |
-|-------------|----------------------------------------------------------|----------------|----------------| 
-| action      | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND      | string         |                |
-| description | 任务描述                                                     | string         |                |
-| failReason  | 失败原因                                                     | string         |                |
-| finishTime  | 结束时间                                                     | integer(int64) | integer(int64) |
-| id          | 任务ID                                                     | string         |                |
-| imageUrl    | 图片url                                                    | string         |                |
-| progress    | 任务进度                                                     | string         |                |
-| prompt      | 提示词                                                      | string         |                |
-| promptEn    | 提示词-英文                                                   | string         |                |
-| startTime   | 开始执行时间                                                   | integer(int64) | integer(int64) |
-| state       | 自定义参数                                                    | string         |                |
-| status      | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string         |                |
-| submitTime  | 提交时间                                                     | integer(int64) | integer(int64) |
+2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
+   ,模型选择midjourney,如果有换脸模型,可以选择swap_face
+3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
+4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
 
-**响应示例**:
+### 对接上游new api
 
-```javascript
-{
-    "action"
-:
-    "",
-        "description"
-:
-    "",
-        "failReason"
-:
-    "",
-        "finishTime"
-:
-    0,
-        "id"
-:
-    "",
-        "imageUrl"
-:
-    "",
-        "progress"
-:
-    "",
-        "prompt"
-:
-    "",
-        "promptEn"
-:
-    "",
-        "startTime"
-:
-    0,
-        "state"
-:
-    "",
-        "status"
-:
-    "",
-        "submitTime"
-:
-    0
-}
-```
+1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型选择midjourney,如果有换脸模型,可以选择swap_face
+2. 地址填写上游new api的地址,例如:http://localhost:3000
+3. 密钥填写上游new api的密钥

+ 7 - 1
README.md

@@ -18,7 +18,7 @@
 此分叉版本的主要变更如下:
 
 1. 全新的UI界面(部分界面还待更新)
-2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持
+2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持
    + [x] /mj/submit/imagine
    + [x] /mj/submit/change
    + [x] /mj/submit/blend
@@ -26,6 +26,11 @@
    + [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
    + [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
    + [x] /task/list-by-condition
+   + [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
+   + [x] /mj/submit/modal
+   + [x] /mj/submit/shorten
+   + [x] /mj/task/{id}/image-seed
+   + [x] /mj/insight-face/swap (InsightFace)
 3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
    + [x] 易支付
 4. 支持用key查询使用额度:
@@ -49,6 +54,7 @@
 2. 智谱glm-4v,glm-4v识图
 3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
+5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口
 
 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
 

+ 1 - 1
common/constants.go

@@ -189,7 +189,7 @@ const (
 	ChannelTypeMidjourney     = 2
 	ChannelTypeAzure          = 3
 	ChannelTypeOllama         = 4
-	ChannelTypeOpenAISB       = 5
+	ChannelTypeMidjourneyPlus = 5
 	ChannelTypeOpenAIMax      = 6
 	ChannelTypeOhMyGPT        = 7
 	ChannelTypeCustom         = 8

+ 25 - 8
common/model-ratio.go

@@ -95,17 +95,31 @@ var ModelRatio = map[string]float64{
 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
 }
 
-var ModelPrice = map[string]float64{
-	"gpt-4-gizmo-*": 0.1,
-	"mj_imagine":    0.1,
-	"mj_variation":  0.1,
-	"mj_reroll":     0.1,
-	"mj_blend":      0.1,
-	"mj_describe":   0.05,
-	"mj_upscale":    0.05,
+var DefaultModelPrice = map[string]float64{
+	"gpt-4-gizmo-*":     0.1,
+	"mj_imagine":        0.1,
+	"mj_variation":      0.1,
+	"mj_reroll":         0.1,
+	"mj_blend":          0.1,
+	"mj_modal":          0.1,
+	"mj_zoom":           0.1,
+	"mj_shorten":        0.1,
+	"mj_high_variation": 0.1,
+	"mj_low_variation":  0.1,
+	"mj_pan":            0.1,
+	"mj_inpaint":        0,
+	"mj_custom_zoom":    0,
+	"mj_describe":       0.05,
+	"mj_upscale":        0.05,
+	"swap_face":         0.05,
 }
 
+var ModelPrice = map[string]float64{}
+
 func ModelPrice2JSONString() string {
+	if len(ModelPrice) == 0 {
+		ModelPrice = DefaultModelPrice
+	}
 	jsonBytes, err := json.Marshal(ModelPrice)
 	if err != nil {
 		SysError("error marshalling model price: " + err.Error())
@@ -119,6 +133,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
 }
 
 func GetModelPrice(name string, printErr bool) float64 {
+	if len(ModelPrice) == 0 {
+		ModelPrice = DefaultModelPrice
+	}
 	if strings.HasPrefix(name, "gpt-4-gizmo") {
 		name = "gpt-4-gizmo-*"
 	}

+ 42 - 0
constant/midjourney.go

@@ -0,0 +1,42 @@
+package constant
+
+const (
+	MjErrorUnknown = 5
+	MjRequestError = 4
+)
+
+const (
+	MjActionImagine       = "IMAGINE"
+	MjActionDescribe      = "DESCRIBE"
+	MjActionBlend         = "BLEND"
+	MjActionUpscale       = "UPSCALE"
+	MjActionVariation     = "VARIATION"
+	MjActionReRoll        = "REROLL"
+	MjActionInPaint       = "INPAINT"
+	MjActionModal         = "MODAL"
+	MjActionZoom          = "ZOOM"
+	MjActionCustomZoom    = "CUSTOM_ZOOM"
+	MjActionShorten       = "SHORTEN"
+	MjActionHighVariation = "HIGH_VARIATION"
+	MjActionLowVariation  = "LOW_VARIATION"
+	MjActionPan           = "PAN"
+	MjActionSwapFace      = "SWAP_FACE"
+)
+
+var MidjourneyModel2Action = map[string]string{
+	"mj_imagine":        MjActionImagine,
+	"mj_describe":       MjActionDescribe,
+	"mj_blend":          MjActionBlend,
+	"mj_upscale":        MjActionUpscale,
+	"mj_variation":      MjActionVariation,
+	"mj_reroll":         MjActionReRoll,
+	"mj_modal":          MjActionModal,
+	"mj_inpaint":        MjActionInPaint,
+	"mj_zoom":           MjActionZoom,
+	"mj_custom_zoom":    MjActionCustomZoom,
+	"mj_shorten":        MjActionShorten,
+	"mj_high_variation": MjActionHighVariation,
+	"mj_low_variation":  MjActionLowVariation,
+	"mj_pan":            MjActionPan,
+	"swap_face":         MjActionSwapFace,
+}

+ 2 - 2
controller/channel-billing.go

@@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		return 0, errors.New("尚未实现")
 	case common.ChannelTypeCustom:
 		baseURL = channel.GetBaseURL()
-	case common.ChannelTypeOpenAISB:
-		return updateChannelOpenAISBBalance(channel)
+	//case common.ChannelTypeOpenAISB:
+	//	return updateChannelOpenAISBBalance(channel)
 	case common.ChannelTypeAIProxy:
 		return updateChannelAIProxyBalance(channel)
 	case common.ChannelTypeAPI2GPT:

+ 23 - 135
controller/midjourney.go

@@ -10,145 +10,14 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/model"
-	relay2 "one-api/relay"
 	"one-api/service"
 	"strconv"
 	"strings"
 	"time"
 )
 
-/*func UpdateMidjourneyTask() {
-	//revocer
-	//imageModel := "midjourney"
-	ctx := context.TODO()
-	imageModel := "midjourney"
-	defer func() {
-		if err := recover(); err != nil {
-			log.Printf("UpdateMidjourneyTask panic: %v", err)
-		}
-	}()
-	for {
-		time.Sleep(time.Duration(15) * time.Second)
-		tasks := model.GetAllUnFinishTasks()
-		if len(tasks) != 0 {
-			common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
-			for _, task := range tasks {
-				common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
-				midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
-				if err != nil {
-					common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
-					task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
-					task.Status = "FAILURE"
-					task.Progress = "100%"
-					err := task.Update()
-					if err != nil {
-						common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
-						continue
-					}
-					continue
-				}
-				requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
-				common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
-
-				req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
-				if err != nil {
-					common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
-					continue
-				}
-
-				// 设置超时时间
-				timeout := time.Second * 5
-				ctx, cancel := context.WithTimeout(context.Background(), timeout)
-
-				// 使用带有超时的 context 创建新的请求
-				req = req.WithContext(ctx)
-
-				req.Header.Set("Content-Type", "application/json")
-				//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
-				req.Header.Set("mj-api-secret", midjourneyChannel.Key)
-				resp, err := httpClient.Do(req)
-				if err != nil {
-					log.Printf("UpdateMidjourneyTask error: %v", err)
-					continue
-				}
-				responseBody, err := io.ReadAll(resp.Body)
-				resp.Body.Close()
-				log.Printf("responseBody: %s", string(responseBody))
-				var responseItem Midjourney
-				// err = json.NewDecoder(resp.Body).Decode(&responseItem)
-				err = json.Unmarshal(responseBody, &responseItem)
-				if err != nil {
-					if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
-						var responseWithoutStatus MidjourneyWithoutStatus
-						var responseStatus MidjourneyStatus
-						err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
-						err2 := json.Unmarshal(responseBody, &responseStatus)
-						if err1 == nil && err2 == nil {
-							jsonData, err3 := json.Marshal(responseWithoutStatus)
-							if err3 != nil {
-								log.Printf("UpdateMidjourneyTask error1: %v", err3)
-								continue
-							}
-							err4 := json.Unmarshal(jsonData, &responseStatus)
-							if err4 != nil {
-								log.Printf("UpdateMidjourneyTask error2: %v", err4)
-								continue
-							}
-							responseItem.Status = strconv.Itoa(responseStatus.Status)
-						} else {
-							log.Printf("UpdateMidjourneyTask error3: %v", err)
-							continue
-						}
-					} else {
-						log.Printf("UpdateMidjourneyTask error4: %v", err)
-						continue
-					}
-				}
-				task.Code = 1
-				task.Progress = responseItem.Progress
-				task.PromptEn = responseItem.PromptEn
-				task.State = responseItem.State
-				task.SubmitTime = responseItem.SubmitTime
-				task.StartTime = responseItem.StartTime
-				task.FinishTime = responseItem.FinishTime
-				task.ImageUrl = responseItem.ImageUrl
-				task.Status = responseItem.Status
-				task.FailReason = responseItem.FailReason
-				if task.Progress != "100%" && responseItem.FailReason != "" {
-					common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
-					task.Progress = "100%"
-					err = model.CacheUpdateUserQuota(task.UserId)
-					if err != nil {
-						log.Println("error update user quota cache: " + err.Error())
-					} else {
-						modelRatio := common.GetModelRatio(imageModel)
-						groupRatio := common.GetGroupRatio("default")
-						ratio := modelRatio * groupRatio
-						quota := int(ratio * 1 * 1000)
-						if quota != 0 {
-							err := model.IncreaseUserQuota(task.UserId, quota)
-							if err != nil {
-								log.Println("fail to increase user quota")
-							}
-							logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
-							model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-						}
-					}
-				}
-
-				err = task.Update()
-				if err != nil {
-					log.Printf("UpdateMidjourneyTask error5: %v", err)
-				}
-				log.Printf("UpdateMidjourneyTask success: %v", task)
-				cancel()
-			}
-		}
-	}
-}
-*/
-
 func UpdateMidjourneyTaskBulk() {
 	//imageModel := "midjourney"
 	ctx := context.TODO()
@@ -228,12 +97,16 @@ func UpdateMidjourneyTaskBulk() {
 				common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
 				continue
 			}
+			if resp.StatusCode != http.StatusOK {
+				common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+				continue
+			}
 			responseBody, err := io.ReadAll(resp.Body)
 			if err != nil {
 				common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
 				continue
 			}
-			var responseItems []relay2.Midjourney
+			var responseItems []dto.MidjourneyDto
 			err = json.Unmarshal(responseBody, &responseItems)
 			if err != nil {
 				common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -245,10 +118,16 @@ func UpdateMidjourneyTaskBulk() {
 
 			for _, responseItem := range responseItems {
 				task := taskM[responseItem.MjId]
+
+				useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
+				// 如果时间超过一小时,且进度不是100%,则认为任务失败
+				if useTime > 3600000 && task.Progress != "100%" {
+					responseItem.FailReason = "上游任务超时(超过1小时)"
+					responseItem.Status = "FAILURE"
+				}
 				if !checkMjTaskNeedUpdate(task, responseItem) {
 					continue
 				}
-
 				task.Code = 1
 				task.Progress = responseItem.Progress
 				task.PromptEn = responseItem.PromptEn
@@ -259,6 +138,15 @@ func UpdateMidjourneyTaskBulk() {
 				task.ImageUrl = responseItem.ImageUrl
 				task.Status = responseItem.Status
 				task.FailReason = responseItem.FailReason
+				if responseItem.Properties != nil {
+					propertiesStr, _ := json.Marshal(responseItem.Properties)
+					task.Properties = string(propertiesStr)
+				}
+				if responseItem.Buttons != nil {
+					buttonStr, _ := json.Marshal(responseItem.Buttons)
+					task.Buttons = string(buttonStr)
+				}
+
 				if task.Progress != "100%" && responseItem.FailReason != "" {
 					common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
 					task.Progress = "100%"
@@ -286,7 +174,7 @@ func UpdateMidjourneyTaskBulk() {
 	}
 }
 
-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
 	if oldTask.Code != 1 {
 		return true
 	}

+ 15 - 3
controller/model.go

@@ -4,12 +4,13 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"net/http"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	"one-api/relay"
 	"one-api/relay/channel/ai360"
 	"one-api/relay/channel/moonshot"
-	"one-api/relay/constant"
+	relayconstant "one-api/relay/constant"
 )
 
 // https://platform.openai.com/docs/api-reference/models/list
@@ -59,8 +60,8 @@ func init() {
 		IsBlocking:         false,
 	})
 	// https://platform.openai.com/docs/models/model-endpoint-compatibility
-	for i := 0; i < constant.APITypeDummy; i++ {
-		if i == constant.APITypeAIProxyLibrary {
+	for i := 0; i < relayconstant.APITypeDummy; i++ {
+		if i == relayconstant.APITypeAIProxyLibrary {
 			continue
 		}
 		adaptor := relay.GetAdaptor(i)
@@ -100,6 +101,17 @@ func init() {
 			Parent:     nil,
 		})
 	}
+	for modelName, _ := range constant.MidjourneyModel2Action {
+		openAIModels = append(openAIModels, OpenAIModels{
+			Id:         modelName,
+			Object:     "model",
+			Created:    1626777600,
+			OwnedBy:    "midjourney",
+			Permission: permission,
+			Root:       modelName,
+			Parent:     nil,
+		})
+	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {
 		openAIModelsMap[model.Id] = model

+ 14 - 40
controller/relay.go

@@ -12,7 +12,6 @@ import (
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
 	"strconv"
-	"strings"
 )
 
 func Relay(c *gin.Context) {
@@ -61,60 +60,35 @@ func Relay(c *gin.Context) {
 }
 
 func RelayMidjourney(c *gin.Context) {
-	relayMode := relayconstant.RelayModeUnknown
-	if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
-		relayMode = relayconstant.RelayModeMidjourneyImagine
-	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
-		relayMode = relayconstant.RelayModeMidjourneyBlend
-	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
-		relayMode = relayconstant.RelayModeMidjourneyDescribe
-	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
-		relayMode = relayconstant.RelayModeMidjourneyNotify
-	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
-		relayMode = relayconstant.RelayModeMidjourneyChange
-	} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
-		relayMode = relayconstant.RelayModeMidjourneyChange
-	} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
-		relayMode = relayconstant.RelayModeMidjourneyTaskFetch
-	} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
-		relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
-	}
-
+	relayMode := c.GetInt("relay_mode")
 	var err *dto.MidjourneyResponse
 	switch relayMode {
 	case relayconstant.RelayModeMidjourneyNotify:
 		err = relay.RelayMidjourneyNotify(c)
 	case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
 		err = relay.RelayMidjourneyTask(c, relayMode)
+	case relayconstant.RelayModeMidjourneyTaskImageSeed:
+		err = relay.RelayMidjourneyTaskImageSeed(c)
+	case relayconstant.RelayModeSwapFace:
+		err = relay.RelaySwapFace(c)
 	default:
 		err = relay.RelayMidjourneySubmit(c, relayMode)
 	}
 	//err = relayMidjourneySubmit(c, relayMode)
 	log.Println(err)
 	if err != nil {
-		retryTimesStr := c.Query("retry")
-		retryTimes, _ := strconv.Atoi(retryTimesStr)
-		if retryTimesStr == "" {
-			retryTimes = common.RetryTimes
-		}
-		if retryTimes > 0 {
-			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
-		} else {
-			if err.Code == 30 {
-				err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
-			}
-			c.JSON(429, gin.H{
-				"error": fmt.Sprintf("%s %s", err.Description, err.Result),
-				"type":  "upstream_error",
-			})
+		statusCode := http.StatusBadRequest
+		if err.Code == 30 {
+			err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+			statusCode = http.StatusTooManyRequests
 		}
+		c.JSON(statusCode, gin.H{
+			"description": fmt.Sprintf("%s %s", err.Description, err.Result),
+			"type":        "upstream_error",
+			"code":        err.Code,
+		})
 		channelId := c.GetInt("channel_id")
 		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
-		//if shouldDisableChannel(&err.Error) {
-		//	channelId := c.GetInt("channel_id")
-		//	channelName := c.GetString("channel_name")
-		//	disableChannel(channelId, channelName, err.Result)
-		//};''''''''''''''''''''''''''''''''
 	}
 }
 

+ 76 - 0
dto/midjourney.go

@@ -1,7 +1,21 @@
 package dto
 
+//type SimpleMjRequest struct {
+//	Prompt   string `json:"prompt"`
+//	CustomId string `json:"customId"`
+//	Action   string `json:"action"`
+//	Content  string `json:"content"`
+//}
+
+type SwapFaceRequest struct {
+	SourceBase64 string `json:"sourceBase64"`
+	TargetBase64 string `json:"targetBase64"`
+}
+
 type MidjourneyRequest struct {
 	Prompt      string   `json:"prompt"`
+	CustomId    string   `json:"customId"`
+	BotType     string   `json:"botType"`
 	NotifyHook  string   `json:"notifyHook"`
 	Action      string   `json:"action"`
 	Index       int      `json:"index"`
@@ -9,6 +23,7 @@ type MidjourneyRequest struct {
 	TaskId      string   `json:"taskId"`
 	Base64Array []string `json:"base64Array"`
 	Content     string   `json:"content"`
+	MaskBase64  string   `json:"maskBase64"`
 }
 
 type MidjourneyResponse struct {
@@ -17,3 +32,64 @@ type MidjourneyResponse struct {
 	Properties  interface{} `json:"properties"`
 	Result      string      `json:"result"`
 }
+
+type MidjourneyResponseWithStatusCode struct {
+	StatusCode int `json:"statusCode"`
+	Response   MidjourneyResponse
+}
+
+type MidjourneyDto struct {
+	MjId        string      `json:"id"`
+	Action      string      `json:"action"`
+	CustomId    string      `json:"customId"`
+	BotType     string      `json:"botType"`
+	Prompt      string      `json:"prompt"`
+	PromptEn    string      `json:"promptEn"`
+	Description string      `json:"description"`
+	State       string      `json:"state"`
+	SubmitTime  int64       `json:"submitTime"`
+	StartTime   int64       `json:"startTime"`
+	FinishTime  int64       `json:"finishTime"`
+	ImageUrl    string      `json:"imageUrl"`
+	Status      string      `json:"status"`
+	Progress    string      `json:"progress"`
+	FailReason  string      `json:"failReason"`
+	Buttons     any         `json:"buttons"`
+	MaskBase64  string      `json:"maskBase64"`
+	Properties  *Properties `json:"properties"`
+}
+
+type MidjourneyStatus struct {
+	Status int `json:"status"`
+}
+type MidjourneyWithoutStatus struct {
+	Id          int    `json:"id"`
+	Code        int    `json:"code"`
+	UserId      int    `json:"user_id" gorm:"index"`
+	Action      string `json:"action"`
+	MjId        string `json:"mj_id" gorm:"index"`
+	Prompt      string `json:"prompt"`
+	PromptEn    string `json:"prompt_en"`
+	Description string `json:"description"`
+	State       string `json:"state"`
+	SubmitTime  int64  `json:"submit_time"`
+	StartTime   int64  `json:"start_time"`
+	FinishTime  int64  `json:"finish_time"`
+	ImageUrl    string `json:"image_url"`
+	Progress    string `json:"progress"`
+	FailReason  string `json:"fail_reason"`
+	ChannelId   int    `json:"channel_id"`
+}
+
+type ActionButton struct {
+	CustomId any `json:"customId"`
+	Emoji    any `json:"emoji"`
+	Label    any `json:"label"`
+	Type     any `json:"type"`
+	Style    any `json:"style"`
+}
+
+type Properties struct {
+	FinalPrompt   string `json:"finalPrompt"`
+	FinalZhPrompt string `json:"finalZhPrompt"`
+}

+ 4 - 10
middleware/auth.go

@@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) {
 		}
 		token, err := model.ValidateUserToken(key)
 		if err != nil {
-			abortWithMessage(c, http.StatusUnauthorized, err.Error())
+			abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
 			return
 		}
 		userEnabled, err := model.CacheIsUserEnabled(token.UserId)
 		if err != nil {
-			abortWithMessage(c, http.StatusInternalServerError, err.Error())
+			abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
 			return
 		}
 		if !userEnabled {
-			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
+			abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
 			return
 		}
 		c.Set("id", token.UserId)
@@ -125,17 +125,11 @@ func TokenAuth() func(c *gin.Context) {
 		} else {
 			c.Set("token_model_limit_enabled", false)
 		}
-		requestURL := c.Request.URL.String()
-		consumeQuota := true
-		if strings.HasPrefix(requestURL, "/v1/models") {
-			consumeQuota = false
-		}
-		c.Set("consume_quota", consumeQuota)
 		if len(parts) > 1 {
 			if model.IsAdmin(token.UserId) {
 				c.Set("channelId", parts[1])
 			} else {
-				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
+				abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 				return
 			}
 		}

+ 81 - 49
middleware/distributor.go

@@ -4,7 +4,11 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
 	"one-api/model"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
 	"strconv"
 	"strings"
 
@@ -23,32 +27,59 @@ func Distribute() func(c *gin.Context) {
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
+				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 				return
 			}
 			channel, err = model.GetChannelById(id, true)
 			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
+				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 				return
 			}
 			if channel.Status != common.ChannelStatusEnabled {
-				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
+				abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
 				return
 			}
 		} else {
+			shouldSelectChannel := true
 			// Select a channel for the user
 			var modelRequest ModelRequest
 			var err error
 			if strings.HasPrefix(c.Request.URL.Path, "/mj") {
-				// Midjourney
-				if modelRequest.Model == "" {
-					modelRequest.Model = "midjourney"
+				relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
+				if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
+					relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
+					relayMode == relayconstant.RelayModeMidjourneyNotify ||
+					relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
+					shouldSelectChannel = false
+				} else {
+					midjourneyRequest := dto.MidjourneyRequest{}
+					err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
+					if err != nil {
+						abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
+						return
+					}
+					midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
+					if mjErr != nil {
+						abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
+						return
+					}
+					if midjourneyModel == "" {
+						if !success {
+							abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
+							return
+						} else {
+							// task fetch, task fetch by condition, notify
+							shouldSelectChannel = false
+						}
+					}
+					modelRequest.Model = midjourneyModel
 				}
+				c.Set("relay_mode", relayMode)
 			} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 				err = common.UnmarshalBodyReusable(c, &modelRequest)
 			}
 			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
+				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 				return
 			}
 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -87,61 +118,62 @@ func Distribute() func(c *gin.Context) {
 				}
 				if tokenModelLimit != nil {
 					if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
-						abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
+						abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
 						return
 					}
 				} else {
 					// token model limit is empty, all models are not allowed
-					abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
 					return
 				}
 			}
 
 			userGroup, _ := model.CacheGetUserGroup(userId)
 			c.Set("group", userGroup)
-
-			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
-			if err != nil {
-				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
-				// 如果错误,但是渠道不为空,说明是数据库一致性问题
-				if channel != nil {
-					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
-					message = "数据库一致性已被破坏,请联系管理员"
+			if shouldSelectChannel {
+				channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+				if err != nil {
+					message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+					// 如果错误,但是渠道不为空,说明是数据库一致性问题
+					if channel != nil {
+						common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+						message = "数据库一致性已被破坏,请联系管理员"
+					}
+					// 如果错误,而且渠道为空,说明是没有可用渠道
+					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
+					return
+				}
+				if channel == nil {
+					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
+					return
+				}
+				c.Set("channel", channel.Type)
+				c.Set("channel_id", channel.Id)
+				c.Set("channel_name", channel.Name)
+				ban := true
+				// parse *int to bool
+				if channel.AutoBan != nil && *channel.AutoBan == 0 {
+					ban = false
+				}
+				c.Set("auto_ban", ban)
+				c.Set("model_mapping", channel.GetModelMapping())
+				c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+				c.Set("base_url", channel.GetBaseURL())
+				// TODO: api_version统一
+				switch channel.Type {
+				case common.ChannelTypeAzure:
+					c.Set("api_version", channel.Other)
+				case common.ChannelTypeXunfei:
+					c.Set("api_version", channel.Other)
+				//case common.ChannelTypeAIProxyLibrary:
+				//	c.Set("library_id", channel.Other)
+				case common.ChannelTypeGemini:
+					c.Set("api_version", channel.Other)
+				case common.ChannelTypeAli:
+					c.Set("plugin", channel.Other)
 				}
-				// 如果错误,而且渠道为空,说明是没有可用渠道
-				abortWithMessage(c, http.StatusServiceUnavailable, message)
-				return
-			}
-			if channel == nil {
-				abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
-				return
 			}
 		}
-		c.Set("channel", channel.Type)
-		c.Set("channel_id", channel.Id)
-		c.Set("channel_name", channel.Name)
-		ban := true
-		// parse *int to bool
-		if channel.AutoBan != nil && *channel.AutoBan == 0 {
-			ban = false
-		}
-		c.Set("auto_ban", ban)
-		c.Set("model_mapping", channel.GetModelMapping())
-		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-		c.Set("base_url", channel.GetBaseURL())
-		// TODO: api_version统一
-		switch channel.Type {
-		case common.ChannelTypeAzure:
-			c.Set("api_version", channel.Other)
-		case common.ChannelTypeXunfei:
-			c.Set("api_version", channel.Other)
-		//case common.ChannelTypeAIProxyLibrary:
-		//	c.Set("library_id", channel.Other)
-		case common.ChannelTypeGemini:
-			c.Set("api_version", channel.Other)
-		case common.ChannelTypeAli:
-			c.Set("plugin", channel.Other)
-		}
 		c.Next()
 	}
 }

+ 11 - 1
middleware/utils.go

@@ -5,7 +5,7 @@ import (
 	"one-api/common"
 )
 
-func abortWithMessage(c *gin.Context, statusCode int, message string) {
+func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
 	c.JSON(statusCode, gin.H{
 		"error": gin.H{
 			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
 	c.Abort()
 	common.LogError(c.Request.Context(), message)
 }
+
+func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
+	c.JSON(statusCode, gin.H{
+		"description": description,
+		"type":        "new_api_error",
+		"code":        code,
+	})
+	c.Abort()
+	common.LogError(c.Request.Context(), description)
+}

+ 6 - 1
model/ability.go

@@ -147,7 +147,12 @@ func FixAbility() (int, error) {
 		return 0, err
 	}
 	var channels []Channel
-	err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
+
+	if len(abilityChannelIds) == 0 {
+		err = DB.Find(&channels).Error
+	} else {
+		err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
+	}
 	if err != nil {
 		return 0, err
 	}

+ 2 - 0
model/midjourney.go

@@ -19,6 +19,8 @@ type Midjourney struct {
 	FailReason  string `json:"fail_reason"`
 	ChannelId   int    `json:"channel_id"`
 	Quota       int    `json:"quota"`
+	Buttons     string `json:"buttons"`
+	Properties  string `json:"properties"`
 }
 
 // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段

+ 41 - 0
relay/constant/relay_mode.go

@@ -17,10 +17,15 @@ const (
 	RelayModeMidjourneySimpleChange
 	RelayModeMidjourneyNotify
 	RelayModeMidjourneyTaskFetch
+	RelayModeMidjourneyTaskImageSeed
 	RelayModeMidjourneyTaskFetchByCondition
 	RelayModeAudioSpeech
 	RelayModeAudioTranscription
 	RelayModeAudioTranslation
+	RelayModeMidjourneyAction
+	RelayModeMidjourneyModal
+	RelayModeMidjourneyShorten
+	RelayModeSwapFace
 )
 
 func Path2RelayMode(path string) int {
@@ -48,3 +53,39 @@ func Path2RelayMode(path string) int {
 	}
 	return relayMode
 }
+
+func Path2RelayModeMidjourney(path string) int {
+	relayMode := RelayModeUnknown
+	if strings.HasPrefix(path, "/mj/submit/action") {
+		// midjourney plus
+		relayMode = RelayModeMidjourneyAction
+	} else if strings.HasPrefix(path, "/mj/submit/modal") {
+		// midjourney plus
+		relayMode = RelayModeMidjourneyModal
+	} else if strings.HasPrefix(path, "/mj/submit/shorten") {
+		// midjourney plus
+		relayMode = RelayModeMidjourneyShorten
+	} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
+		// midjourney plus
+		relayMode = RelayModeSwapFace
+	} else if strings.HasPrefix(path, "/mj/submit/imagine") {
+		relayMode = RelayModeMidjourneyImagine
+	} else if strings.HasPrefix(path, "/mj/submit/blend") {
+		relayMode = RelayModeMidjourneyBlend
+	} else if strings.HasPrefix(path, "/mj/submit/describe") {
+		relayMode = RelayModeMidjourneyDescribe
+	} else if strings.HasPrefix(path, "/mj/notify") {
+		relayMode = RelayModeMidjourneyNotify
+	} else if strings.HasPrefix(path, "/mj/submit/change") {
+		relayMode = RelayModeMidjourneyChange
+	} else if strings.HasPrefix(path, "/mj/submit/simple-change") {
+		relayMode = RelayModeMidjourneyChange
+	} else if strings.HasSuffix(path, "/fetch") {
+		relayMode = RelayModeMidjourneyTaskFetch
+	} else if strings.HasSuffix(path, "/image-seed") {
+		relayMode = RelayModeMidjourneyTaskImageSeed
+	} else if strings.HasSuffix(path, "/list-by-condition") {
+		relayMode = RelayModeMidjourneyTaskFetchByCondition
+	}
+	return relayMode
+}

+ 32 - 39
relay/relay-image.go

@@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 	channelType := c.GetInt("channel")
 	channelId := c.GetInt("channel_id")
 	userId := c.GetInt("id")
-	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
 	startTime := time.Now()
 
 	var imageRequest dto.ImageRequest
-	if consumeQuota {
-		err := common.UnmarshalBodyReusable(c, &imageRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-		}
+	err := common.UnmarshalBodyReusable(c, &imageRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 	}
 
 	if imageRequest.Model == "" {
@@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 
 	quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
 
-	if consumeQuota && userQuota-quota < 0 {
+	if userQuota-quota < 0 {
 		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
 
@@ -176,47 +173,43 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 	var textResponse dto.ImageResponse
 	defer func(ctx context.Context) {
 		useTimeSeconds := time.Now().Unix() - startTime.Unix()
-		if consumeQuota {
-			if resp.StatusCode != http.StatusOK {
-				return
-			}
-			err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
-			if err != nil {
-				common.SysError("error consuming token remain quota: " + err.Error())
-			}
-			err = model.CacheUpdateUserQuota(userId)
-			if err != nil {
-				common.SysError("error update user quota cache: " + err.Error())
-			}
-			if quota != 0 {
-				tokenName := c.GetString("token_name")
-				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
-				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-				channelId := c.GetInt("channel_id")
-				model.UpdateChannelUsedQuota(channelId, quota)
-			}
+		if resp.StatusCode != http.StatusOK {
+			return
 		}
-	}(c.Request.Context())
-
-	if consumeQuota {
-		responseBody, err := io.ReadAll(resp.Body)
-
+		err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+			common.SysError("error consuming token remain quota: " + err.Error())
 		}
-		err = resp.Body.Close()
+		err = model.CacheUpdateUserQuota(userId)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+			common.SysError("error update user quota cache: " + err.Error())
 		}
-		err = json.Unmarshal(responseBody, &textResponse)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+		if quota != 0 {
+			tokenName := c.GetString("token_name")
+			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
+			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+			channelId := c.GetInt("channel_id")
+			model.UpdateChannelUsedQuota(channelId, quota)
 		}
+	}(c.Request.Context())
+
+	responseBody, err := io.ReadAll(resp.Body)
 
-		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	err = json.Unmarshal(responseBody, &textResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
 
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
 	for k, v := range resp.Header {
 		c.Writer.Header().Set(k, v[0])
 	}

+ 252 - 280
relay/relay-mj.go

@@ -9,6 +9,7 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	relayconstant "one-api/relay/constant"
@@ -20,53 +21,6 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-type Midjourney struct {
-	MjId        string `json:"id"`
-	Action      string `json:"action"`
-	Prompt      string `json:"prompt"`
-	PromptEn    string `json:"promptEn"`
-	Description string `json:"description"`
-	State       string `json:"state"`
-	SubmitTime  int64  `json:"submitTime"`
-	StartTime   int64  `json:"startTime"`
-	FinishTime  int64  `json:"finishTime"`
-	ImageUrl    string `json:"imageUrl"`
-	Status      string `json:"status"`
-	Progress    string `json:"progress"`
-	FailReason  string `json:"failReason"`
-}
-
-type MidjourneyStatus struct {
-	Status int `json:"status"`
-}
-type MidjourneyWithoutStatus struct {
-	Id          int    `json:"id"`
-	Code        int    `json:"code"`
-	UserId      int    `json:"user_id" gorm:"index"`
-	Action      string `json:"action"`
-	MjId        string `json:"mj_id" gorm:"index"`
-	Prompt      string `json:"prompt"`
-	PromptEn    string `json:"prompt_en"`
-	Description string `json:"description"`
-	State       string `json:"state"`
-	SubmitTime  int64  `json:"submit_time"`
-	StartTime   int64  `json:"start_time"`
-	FinishTime  int64  `json:"finish_time"`
-	ImageUrl    string `json:"image_url"`
-	Progress    string `json:"progress"`
-	FailReason  string `json:"fail_reason"`
-	ChannelId   int    `json:"channel_id"`
-}
-
-var DefaultModelPrice = map[string]float64{
-	"mj_imagine":   0.1,
-	"mj_variation": 0.1,
-	"mj_reroll":    0.1,
-	"mj_blend":     0.1,
-	"mj_describe":  0.05,
-	"mj_upscale":   0.05,
-}
-
 func RelayMidjourneyImage(c *gin.Context) {
 	taskId := c.Param("id")
 	midjourneyTask := model.GetByOnlyMJId(taskId)
@@ -108,7 +62,7 @@ func RelayMidjourneyImage(c *gin.Context) {
 }
 
 func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
-	var midjRequest Midjourney
+	var midjRequest dto.MidjourneyDto
 	err := common.UnmarshalBodyReusable(c, &midjRequest)
 	if err != nil {
 		return &dto.MidjourneyResponse{
@@ -147,7 +101,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
 	return nil
 }
 
-func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
+func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
 	midjourneyTask.MjId = originTask.MjId
 	midjourneyTask.Progress = originTask.Progress
 	midjourneyTask.PromptEn = originTask.PromptEn
@@ -167,9 +121,182 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
 	midjourneyTask.Action = originTask.Action
 	midjourneyTask.Description = originTask.Description
 	midjourneyTask.Prompt = originTask.Prompt
+	if originTask.Buttons != "" {
+		var buttons []dto.ActionButton
+		err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
+		if err == nil {
+			midjourneyTask.Buttons = buttons
+		}
+	}
+	if originTask.Properties != "" {
+		var properties dto.Properties
+		err := json.Unmarshal([]byte(originTask.Properties), &properties)
+		if err == nil {
+			midjourneyTask.Properties = &properties
+		}
+	}
 	return
 }
 
+func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
+	startTime := time.Now().UnixNano() / int64(time.Millisecond)
+	tokenId := c.GetInt("token_id")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+	channelId := c.GetInt("channel_id")
+	var swapFaceRequest dto.SwapFaceRequest
+	err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+	}
+	if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
+	}
+	modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
+	modelPrice := common.GetModelPrice(modelName, true)
+	// 如果没有配置价格,则使用默认价格
+	if modelPrice == -1 {
+		defaultPrice, ok := common.DefaultModelPrice[modelName]
+		if !ok {
+			modelPrice = 0.1
+		} else {
+			modelPrice = defaultPrice
+		}
+	}
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelPrice * groupRatio
+	userQuota, err := model.CacheGetUserQuota(userId)
+	if err != nil {
+		return &dto.MidjourneyResponse{
+			Code:        4,
+			Description: err.Error(),
+		}
+	}
+	quota := int(ratio * common.QuotaPerUnit)
+
+	if userQuota-quota < 0 {
+		return &dto.MidjourneyResponse{
+			Code:        4,
+			Description: "quota_not_enough",
+		}
+	}
+	requestURL := c.Request.URL.String()
+	baseURL := c.GetString("base_url")
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
+	if err != nil {
+		return &mjResp.Response
+	}
+	defer func(ctx context.Context) {
+		if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
+			err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+			if err != nil {
+				common.SysError("error consuming token remain quota: " + err.Error())
+			}
+			err = model.CacheUpdateUserQuota(userId)
+			if err != nil {
+				common.SysError("error update user quota cache: " + err.Error())
+			}
+			if quota != 0 {
+				tokenName := c.GetString("token_name")
+				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
+				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+				channelId := c.GetInt("channel_id")
+				model.UpdateChannelUsedQuota(channelId, quota)
+			}
+		}
+	}(c.Request.Context())
+	midjResponse := &mjResp.Response
+	midjourneyTask := &model.Midjourney{
+		UserId:      userId,
+		Code:        midjResponse.Code,
+		Action:      constant.MjActionSwapFace,
+		MjId:        midjResponse.Result,
+		Prompt:      "InsightFace",
+		PromptEn:    "",
+		Description: midjResponse.Description,
+		State:       "",
+		SubmitTime:  startTime,
+		StartTime:   time.Now().UnixNano() / int64(time.Millisecond),
+		FinishTime:  0,
+		ImageUrl:    "",
+		Status:      "",
+		Progress:    "0%",
+		FailReason:  "",
+		ChannelId:   c.GetInt("channel_id"),
+		Quota:       quota,
+	}
+	err = midjourneyTask.Insert()
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
+	}
+	c.Writer.WriteHeader(mjResp.StatusCode)
+	respBody, err := json.Marshal(midjResponse)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+	}
+	_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+	}
+	return nil
+}
+
+func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
+	taskId := c.Param("id")
+	userId := c.GetInt("id")
+	originTask := model.GetByMJId(userId, taskId)
+	if originTask == nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+	}
+	channel, err := model.GetChannelById(originTask.ChannelId, true)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+	}
+	if channel.Status != common.ChannelStatusEnabled {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+	}
+	c.Set("channel_id", originTask.ChannelId)
+	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+	requestURL := c.Request.URL.String()
+	fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
+	midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
+	if err != nil {
+		return &midjResponseWithStatus.Response
+	}
+	//defer func(ctx context.Context) {
+	//	err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
+	//	if err != nil {
+	//		common.SysError("error consuming token remain quota: " + err.Error())
+	//	}
+	//	err = model.CacheUpdateUserQuota(userId)
+	//	if err != nil {
+	//		common.SysError("error update user quota cache: " + err.Error())
+	//	}
+	//	if quota != 0 {
+	//		tokenName := c.GetString("token_name")
+	//		logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
+	//		model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+	//		model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+	//		channelId := c.GetInt("channel_id")
+	//		model.UpdateChannelUsedQuota(channelId, quota)
+	//	}
+	//}(c.Request.Context())
+	midjResponse := &midjResponseWithStatus.Response
+	c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
+	respBody, err := json.Marshal(midjResponse)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+	}
+	_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+	}
+	return nil
+}
+
 func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 	userId := c.GetInt("id")
 	var err error
@@ -184,7 +311,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
 				Description: "task_no_found",
 			}
 		}
-		midjourneyTask := getMidjourneyTaskModel(c, originTask)
+		midjourneyTask := coverMidjourneyTaskDto(c, originTask)
 		respBody, err = json.Marshal(midjourneyTask)
 		if err != nil {
 			return &dto.MidjourneyResponse{
@@ -203,16 +330,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
 				Description: "do_request_failed",
 			}
 		}
-		var tasks []Midjourney
+		var tasks []dto.MidjourneyDto
 		if len(condition.IDs) != 0 {
 			originTasks := model.GetByMJIds(userId, condition.IDs)
 			for _, originTask := range originTasks {
-				midjourneyTask := getMidjourneyTaskModel(c, originTask)
+				midjourneyTask := coverMidjourneyTaskDto(c, originTask)
 				tasks = append(tasks, midjourneyTask)
 			}
 		}
 		if tasks == nil {
-			tasks = make([]Midjourney, 0)
+			tasks = make([]dto.MidjourneyDto, 0)
 		}
 		respBody, err = json.Marshal(tasks)
 		if err != nil {
@@ -235,170 +362,115 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
 	return nil
 }
 
-const (
-	// type 1 根据 mode 价格不同
-	MJSubmitActionImagine   = "IMAGINE"
-	MJSubmitActionVariation = "VARIATION" //变换
-	MJSubmitActionBlend     = "BLEND"     //混图
-
-	MJSubmitActionReroll = "REROLL" //重新生成
-	// type 2 固定价格
-	MJSubmitActionDescribe = "DESCRIBE"
-	MJSubmitActionUpscale  = "UPSCALE" // 放大
-)
-
 func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
-	imageModel := "midjourney"
 
 	tokenId := c.GetInt("token_id")
-	channelType := c.GetInt("channel")
+	//channelType := c.GetInt("channel")
 	userId := c.GetInt("id")
-	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
 	channelId := c.GetInt("channel_id")
+	consumeQuota := true
 	var midjRequest dto.MidjourneyRequest
-	if consumeQuota {
-		err := common.UnmarshalBodyReusable(c, &midjRequest)
-		if err != nil {
-			return &dto.MidjourneyResponse{
-				Code:        4,
-				Description: "bind_request_body_failed",
-			}
+	err := common.UnmarshalBodyReusable(c, &midjRequest)
+	if err != nil {
+		return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+	}
+
+	if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
+		mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
+		if mjErr != nil {
+			return mjErr
 		}
+		relayMode = relayconstant.RelayModeMidjourneyChange
 	}
 
 	if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
 		if midjRequest.Prompt == "" {
-			return &dto.MidjourneyResponse{
-				Code:        4,
-				Description: "prompt_is_required",
-			}
+			return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
 		}
-		midjRequest.Action = "IMAGINE"
+		midjRequest.Action = constant.MjActionImagine
 	} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
-		midjRequest.Action = "DESCRIBE"
+		midjRequest.Action = constant.MjActionDescribe
+	} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
+		midjRequest.Action = constant.MjActionShorten
 	} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
-		midjRequest.Action = "BLEND"
+		midjRequest.Action = constant.MjActionBlend
 	} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
 		mjId := ""
 		if relayMode == relayconstant.RelayModeMidjourneyChange {
 			if midjRequest.TaskId == "" {
-				return &dto.MidjourneyResponse{
-					Code:        4,
-					Description: "taskId_is_required",
-				}
+				return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
 			} else if midjRequest.Action == "" {
-				return &dto.MidjourneyResponse{
-					Code:        4,
-					Description: "action_is_required",
-				}
+				return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
 			} else if midjRequest.Index == 0 {
-				return &dto.MidjourneyResponse{
-					Code:        4,
-					Description: "index_can_only_be_1_2_3_4",
-				}
+				return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
 			}
 			//action = midjRequest.Action
 			mjId = midjRequest.TaskId
 		} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
 			if midjRequest.Content == "" {
-				return &dto.MidjourneyResponse{
-					Code:        4,
-					Description: "content_is_required",
-				}
+				return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
 			}
-			params := convertSimpleChangeParams(midjRequest.Content)
+			params := service.ConvertSimpleChangeParams(midjRequest.Content)
 			if params == nil {
-				return &dto.MidjourneyResponse{
-					Code:        4,
-					Description: "content_parse_failed",
-				}
+				return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
 			}
-			mjId = params.ID
+			mjId = params.TaskId
 			midjRequest.Action = params.Action
+		} else if relayMode == relayconstant.RelayModeMidjourneyModal {
+			//if midjRequest.MaskBase64 == "" {
+			//	return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
+			//}
+			mjId = midjRequest.TaskId
+			midjRequest.Action = constant.MjActionModal
 		}
 
 		originTask := model.GetByMJId(userId, mjId)
 		if originTask == nil {
-			return &dto.MidjourneyResponse{
-				Code:        4,
-				Description: "task_no_found",
-			}
-		} else if originTask.Action == "UPSCALE" {
-			//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
-			return &dto.MidjourneyResponse{
-				Code:        4,
-				Description: "upscale_task_can_not_be_change",
-			}
-		} else if originTask.Status != "SUCCESS" {
-			return &dto.MidjourneyResponse{
-				Code:        4,
-				Description: "task_status_is_not_success",
-			}
+			return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
+		} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
+			return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
 		} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
-			channel, err := model.GetChannelById(originTask.ChannelId, false)
+			channel, err := model.GetChannelById(originTask.ChannelId, true)
 			if err != nil {
-				return &dto.MidjourneyResponse{
-					Code:        4,
-					Description: "channel_not_found",
-				}
+				return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+			}
+			if channel.Status != common.ChannelStatusEnabled {
+				return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
 			}
 			c.Set("base_url", channel.GetBaseURL())
 			c.Set("channel_id", originTask.ChannelId)
-			log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
+			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+			log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
 		}
 		midjRequest.Prompt = originTask.Prompt
+
+		//if channelType == common.ChannelTypeMidjourneyPlus {
+		//	// plus
+		//} else {
+		//	// 普通版渠道
+		//
+		//}
 	}
 
-	// map model name
-	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
-	if modelMapping != "" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-			return &dto.MidjourneyResponse{
-				Code:        4,
-				Description: "unmarshal_model_mapping_failed",
-			}
-		}
-		if modelMap[imageModel] != "" {
-			imageModel = modelMap[imageModel]
-			isModelMapped = true
-		}
+	if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
+		consumeQuota = false
 	}
 
-	baseURL := common.ChannelBaseURLs[channelType]
+	//baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 
-	if c.GetString("base_url") != "" {
-		baseURL = c.GetString("base_url")
-	}
+	baseURL := c.GetString("base_url")
 
 	//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
 
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-	log.Printf("fullRequestURL: %s", fullRequestURL)
 
-	var requestBody io.Reader
-	if isModelMapped {
-		jsonStr, err := json.Marshal(midjRequest)
-		if err != nil {
-			return &dto.MidjourneyResponse{
-				Code:        4,
-				Description: "marshal_text_request_failed",
-			}
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	} else {
-		requestBody = c.Request.Body
-	}
-	mjAction := "mj_" + strings.ToLower(midjRequest.Action)
-	modelPrice := common.GetModelPrice(mjAction, true)
+	modelName := service.CoverActionToModelName(midjRequest.Action)
+	modelPrice := common.GetModelPrice(modelName, true)
 	// 如果没有配置价格,则使用默认价格
 	if modelPrice == -1 {
-		defaultPrice, ok := DefaultModelPrice[mjAction]
+		defaultPrice, ok := common.DefaultModelPrice[modelName]
 		if !ok {
 			modelPrice = 0.1
 		} else {
@@ -423,53 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		}
 	}
 
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
-	if err != nil {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "create_request_failed",
-		}
-	}
-	//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
-
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	//mjToken := ""
-	//if c.Request.Header.Get("ApiKey") != "" {
-	//	mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
-	//}
-	//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
-	req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
-	// print request header
-	log.Printf("request header: %s", req.Header)
-	log.Printf("request body: %s", midjRequest.Prompt)
-
-	resp, err := service.GetHttpClient().Do(req)
+	midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
 	if err != nil {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "do_request_failed",
-		}
+		return &midjResponseWithStatus.Response
 	}
-
-	err = req.Body.Close()
-	if err != nil {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "close_request_body_failed",
-		}
-	}
-	err = c.Request.Body.Close()
-	if err != nil {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "close_request_body_failed",
-		}
-	}
-	var midjResponse dto.MidjourneyResponse
+	midjResponse := &midjResponseWithStatus.Response
 
 	defer func(ctx context.Context) {
-		if consumeQuota {
+		if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
 			err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
 			if err != nil {
 				common.SysError("error consuming token remain quota: " + err.Error())
@@ -481,7 +514,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
-				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
+				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
@@ -489,41 +522,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		}
 	}(c.Request.Context())
 
-	//if consumeQuota {
-	//
-	//}
-	responseBody, err := io.ReadAll(resp.Body)
-
-	if err != nil {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "read_response_body_failed",
-		}
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "close_response_body_failed",
-		}
-	}
-
-	err = json.Unmarshal(responseBody, &midjResponse)
-	log.Printf("responseBody: %s", string(responseBody))
-	log.Printf("midjResponse: %v", midjResponse)
-	if resp.StatusCode != 200 {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
-		}
-	}
-	if err != nil {
-		return &dto.MidjourneyResponse{
-			Code:        4,
-			Description: "unmarshal_response_body_failed",
-		}
-	}
-
 	// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
 	//1-提交成功
 	// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
@@ -575,8 +573,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			}
 		}
 		//修改返回值
-		newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
-		responseBody = []byte(newBody)
+		if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
+			newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
+			responseBody = []byte(newBody)
+		}
 	}
 
 	err = midjourneyTask.Insert()
@@ -593,21 +593,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		responseBody = []byte(newBody)
 	}
 
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
 
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.WriteHeader(resp.StatusCode)
+	//for k, v := range resp.Header {
+	//	c.Writer.Header().Set(k, v[0])
+	//}
+	c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
 
-	_, err = io.Copy(c.Writer, resp.Body)
+	_, err = io.Copy(c.Writer, bodyReader)
 	if err != nil {
 		return &dto.MidjourneyResponse{
 			Code:        4,
 			Description: "copy_response_body_failed",
 		}
 	}
-	err = resp.Body.Close()
+	err = bodyReader.Close()
 	if err != nil {
 		return &dto.MidjourneyResponse{
 			Code:        4,
@@ -622,32 +623,3 @@ type taskChangeParams struct {
 	Action string
 	Index  int
 }
-
-func convertSimpleChangeParams(content string) *taskChangeParams {
-	split := strings.Split(content, " ")
-	if len(split) != 2 {
-		return nil
-	}
-
-	action := strings.ToLower(split[1])
-	changeParams := &taskChangeParams{}
-	changeParams.ID = split[0]
-
-	if action[0] == 'u' {
-		changeParams.Action = "UPSCALE"
-	} else if action[0] == 'v' {
-		changeParams.Action = "VARIATION"
-	} else if action == "r" {
-		changeParams.Action = "REROLL"
-		return changeParams
-	} else {
-		return nil
-	}
-
-	index, err := strconv.Atoi(action[1:2])
-	if err != nil || index < 1 || index > 4 {
-		return nil
-	}
-	changeParams.Index = index
-	return changeParams
-}

+ 5 - 0
router/relay-router.go

@@ -47,6 +47,9 @@ func SetRelayRouter(router *gin.Engine) {
 	relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
 	relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
+		relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
+		relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
+		relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
 		relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
 		relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
 		relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
@@ -54,7 +57,9 @@ func SetRelayRouter(router *gin.Engine) {
 		relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
 		relayMjRouter.POST("/notify", controller.RelayMidjourney)
 		relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
+		relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
 		relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
+		relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
 	}
 	//relayMjRouter.Use()
 }

+ 14 - 0
service/error.go

@@ -11,6 +11,20 @@ import (
 	"strings"
 )
 
+func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
+	return &dto.MidjourneyResponse{
+		Code:        code,
+		Description: desc,
+	}
+}
+
+func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
+	return &dto.MidjourneyResponseWithStatusCode{
+		StatusCode: statusCode,
+		Response:   *MidjourneyErrorWrapper(code, desc),
+	}
+}
+
 // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
 func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
 	text := err.Error()

+ 224 - 0
service/midjourney.go

@@ -0,0 +1,224 @@
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"log"
+	"net/http"
+	"one-api/constant"
+	"one-api/dto"
+	relayconstant "one-api/relay/constant"
+	"strconv"
+	"strings"
+	"time"
+)
+
+func CoverActionToModelName(mjAction string) string {
+	modelName := "mj_" + strings.ToLower(mjAction)
+	if mjAction == constant.MjActionSwapFace {
+		modelName = "swap_face"
+	}
+	return modelName
+}
+
+func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
+	action := ""
+	if relayMode == relayconstant.RelayModeMidjourneyAction {
+		// plus request
+		err := CoverPlusActionToNormalAction(midjRequest)
+		if err != nil {
+			return "", err, false
+		}
+		action = midjRequest.Action
+	} else {
+		switch relayMode {
+		case relayconstant.RelayModeMidjourneyImagine:
+			action = constant.MjActionImagine
+		case relayconstant.RelayModeMidjourneyDescribe:
+			action = constant.MjActionDescribe
+		case relayconstant.RelayModeMidjourneyBlend:
+			action = constant.MjActionBlend
+		case relayconstant.RelayModeMidjourneyShorten:
+			action = constant.MjActionShorten
+		case relayconstant.RelayModeMidjourneyChange:
+			action = midjRequest.Action
+		case relayconstant.RelayModeMidjourneyModal:
+			action = constant.MjActionModal
+		case relayconstant.RelayModeSwapFace:
+			action = constant.MjActionSwapFace
+		case relayconstant.RelayModeMidjourneySimpleChange:
+			params := ConvertSimpleChangeParams(midjRequest.Content)
+			if params == nil {
+				return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
+			}
+			action = params.Action
+		case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
+			return "", nil, true
+		default:
+			return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
+		}
+	}
+	modelName := CoverActionToModelName(action)
+	return modelName, nil, true
+}
+
+func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
+	// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
+	customId := midjRequest.CustomId
+	if customId == "" {
+		return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
+	}
+	splits := strings.Split(customId, "::")
+	var action string
+	if splits[1] == "JOB" {
+		action = splits[2]
+	} else {
+		action = splits[1]
+	}
+
+	if action == "" {
+		return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+	}
+	if strings.Contains(action, "upsample") {
+		index, err := strconv.Atoi(splits[3])
+		if err != nil {
+			return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+		}
+		midjRequest.Index = index
+		midjRequest.Action = constant.MjActionUpscale
+	} else if strings.Contains(action, "variation") {
+		midjRequest.Index = 1
+		if action == "variation" {
+			index, err := strconv.Atoi(splits[3])
+			if err != nil {
+				return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+			}
+			midjRequest.Index = index
+			midjRequest.Action = constant.MjActionVariation
+		} else if action == "low_variation" {
+			midjRequest.Action = constant.MjActionLowVariation
+		} else if action == "high_variation" {
+			midjRequest.Action = constant.MjActionHighVariation
+		}
+	} else if strings.Contains(action, "pan") {
+		midjRequest.Action = constant.MjActionPan
+		midjRequest.Index = 1
+	} else if strings.Contains(action, "reroll") {
+		midjRequest.Action = constant.MjActionReRoll
+		midjRequest.Index = 1
+	} else if action == "Outpaint" {
+		midjRequest.Action = constant.MjActionZoom
+		midjRequest.Index = 1
+	} else if action == "CustomZoom" {
+		midjRequest.Action = constant.MjActionCustomZoom
+		midjRequest.Index = 1
+	} else if action == "Inpaint" {
+		midjRequest.Action = constant.MjActionInPaint
+		midjRequest.Index = 1
+	} else {
+		return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
+	}
+	return nil
+}
+
+func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
+	split := strings.Split(content, " ")
+	if len(split) != 2 {
+		return nil
+	}
+
+	action := strings.ToLower(split[1])
+	changeParams := &dto.MidjourneyRequest{}
+	changeParams.TaskId = split[0]
+
+	if action[0] == 'u' {
+		changeParams.Action = "UPSCALE"
+	} else if action[0] == 'v' {
+		changeParams.Action = "VARIATION"
+	} else if action == "r" {
+		changeParams.Action = "REROLL"
+		return changeParams
+	} else {
+		return nil
+	}
+
+	index, err := strconv.Atoi(action[1:2])
+	if err != nil || index < 1 || index > 4 {
+		return nil
+	}
+	changeParams.Index = index
+	return changeParams
+}
+
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+	var nullBytes []byte
+	//var requestBody io.Reader
+	//requestBody = c.Request.Body
+	// read request body to json, delete accountFilter and notifyHook
+	var mapResult map[string]interface{}
+	err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
+	}
+	delete(mapResult, "accountFilter")
+	delete(mapResult, "notifyHook")
+	//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	// make new request with mapResult
+	reqBody, err := json.Marshal(mapResult)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	// 使用带有超时的 context 创建新的请求
+	req = req.WithContext(ctx)
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+	req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
+	defer cancel()
+	resp, err := GetHttpClient().Do(req)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
+	}
+	statusCode := resp.StatusCode
+	//if statusCode != 200 {
+	//	return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
+	//}
+	err = req.Body.Close()
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+	}
+	var midjResponse dto.MidjourneyResponse
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
+	}
+
+	err = json.Unmarshal(responseBody, &midjResponse)
+	log.Printf("responseBody: %s", string(responseBody))
+	if err != nil {
+		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
+	}
+	//log.Printf("midjResponse: %v", midjResponse)
+	//for k, v := range resp.Header {
+	//	c.Writer.Header().Set(k, v[0])
+	//}
+	return &dto.MidjourneyResponseWithStatusCode{
+		StatusCode: statusCode,
+		Response:   midjResponse,
+	}, responseBody, nil
+}

+ 26 - 2
web/src/components/MjLogsTable.js

@@ -31,10 +31,30 @@ function renderType(type) {
             return <Tag color="orange" size='large'>放大</Tag>;
         case 'VARIATION':
             return <Tag color="purple" size='large'>变换</Tag>;
+        case 'HIGH_VARIATION':
+            return <Tag color="purple" size='large'>强变换</Tag>;
+        case 'LOW_VARIATION':
+            return <Tag color="purple" size='large'>弱变换</Tag>;
+        case 'PAN':
+            return <Tag color="cyan" size='large'>平移</Tag>;
         case 'DESCRIBE':
             return <Tag color="yellow" size='large'>图生文</Tag>;
-        case 'BLEAND':
+        case 'BLEND':
             return <Tag color="lime" size='large'>图混合</Tag>;
+        case 'SHORTEN':
+            return <Tag color="pink" size='large'>缩词</Tag>;
+        case 'REROLL':
+            return <Tag color="indigo" size='large'>重绘</Tag>;
+        case 'INPAINT':
+            return <Tag color="violet" size='large'>局部重绘-提交</Tag>;
+        case 'ZOOM':
+            return <Tag color="teal" size='large'>变焦</Tag>;
+        case 'CUSTOM_ZOOM':
+            return <Tag color="teal" size='large'>自定义变焦-提交</Tag>;
+        case 'MODAL':
+            return <Tag color="green" size='large'>窗口处理</Tag>;
+        case 'SWAP_FACE':
+            return <Tag color="light-green" size='large'>换脸</Tag>;
         default:
             return <Tag color="white" size='large'>未知</Tag>;
     }
@@ -46,9 +66,11 @@ function renderCode(code) {
         case 1:
             return <Tag color="green" size='large'>已提交</Tag>;
         case 21:
-            return <Tag color="lime" size='large'>排队中</Tag>;
+            return <Tag color="lime" size='large'>等待中</Tag>;
         case 22:
             return <Tag color="orange" size='large'>重复提交</Tag>;
+        case 0:
+            return <Tag color="yellow" size='large'>未提交</Tag>;
         default:
             return <Tag color="white" size='large'>未知</Tag>;
     }
@@ -68,6 +90,8 @@ function renderStatus(type) {
             return <Tag color="blue" size='large'>执行中</Tag>;
         case 'FAILURE':
             return <Tag color="red" size='large'>失败</Tag>;
+        case 'MODAL':
+            return <Tag color="yellow" size='large'>窗口等待</Tag>;
         default:
             return <Tag color="white" size='large'>未知</Tag>;
     }

+ 1 - 0
web/src/constants/channel.constants.js

@@ -1,6 +1,7 @@
 export const CHANNEL_OPTIONS = [
     {key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
     {key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
+    {key: 5, text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus'},
     {key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama'},
     {key: 14, text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude'},
     {key: 3, text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI'},

+ 22 - 0
web/src/pages/Channel/EditChannel.js

@@ -95,6 +95,28 @@ const EditChannel = (props) => {
                 case 26:
                     localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
                     break;
+                case 2:
+                    localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe'];
+                    break;
+                case 5:
+                    localModels = [
+                        'swap_face',
+                        'mj_imagine',
+                        'mj_variation',
+                        'mj_reroll',
+                        'mj_blend',
+                        'mj_upscale',
+                        'mj_describe',
+                        'mj_zoom',
+                        'mj_shorten',
+                        'mj_modal',
+                        'mj_inpaint',
+                        'mj_custom_zoom',
+                        'mj_high_variation',
+                        'mj_low_variation',
+                        'mj_pan',
+                    ];
+                    break;
             }
             setInputs((inputs) => ({...inputs, models: localModels}));
         }