Просмотр исходного кода

✨ feat(channel): enhance AddChannel functionality with structured request handling

CaIon 8 месяцев назад
Родитель
Сommit
0089157b83
3 измененных файлов с 93 добавлено и 23 удалено
  1. 84 21
      controller/channel.go
  2. 8 1
      model/channel.go
  3. 1 1
      model/main.go

+ 84 - 21
controller/channel.go

@@ -250,9 +250,14 @@ func GetChannel(c *gin.Context) {
 	return
 }
 
+type AddChannelRequest struct {
+	Mode    string         `json:"mode"`
+	Channel *model.Channel `json:"channel"`
+}
+
 func AddChannel(c *gin.Context) {
-	channel := model.Channel{}
-	err := c.ShouldBindJSON(&channel)
+	addChannelRequest := AddChannelRequest{}
+	err := c.ShouldBindJSON(&addChannelRequest)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -260,19 +265,35 @@ func AddChannel(c *gin.Context) {
 		})
 		return
 	}
-	channel.CreatedTime = common.GetTimestamp()
-	keys := strings.Split(channel.Key, "\n")
-	if channel.Type == common.ChannelTypeVertexAi {
-		if channel.Other == "" {
+	if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "channel cannot be empty",
+		})
+		return
+	}
+
+	// Validate the length of the model name
+	for _, m := range addChannelRequest.Channel.GetModels() {
+		if len(m) > 255 {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": fmt.Sprintf("模型名称过长: %s", m),
+			})
+			return
+		}
+	}
+	if addChannelRequest.Channel.Type == common.ChannelTypeVertexAi {
+		if addChannelRequest.Channel.Other == "" {
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,
 				"message": "部署地区不能为空",
 			})
 			return
 		} else {
-			if common.IsJsonStr(channel.Other) {
+			if common.IsJsonStr(addChannelRequest.Channel.Other) {
 				// must have default
-				regionMap := common.StrToMap(channel.Other)
+				regionMap := common.StrToMap(addChannelRequest.Channel.Other)
 				if regionMap["default"] == nil {
 					c.JSON(http.StatusOK, gin.H{
 						"success": false,
@@ -282,27 +303,69 @@ func AddChannel(c *gin.Context) {
 				}
 			}
 		}
-		keys = []string{channel.Key}
 	}
-	channels := make([]model.Channel, 0, len(keys))
-	for _, key := range keys {
-		if key == "" {
-			continue
+
+	addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
+	keys := make([]string, 0)
+	switch addChannelRequest.Mode {
+	case "multi_to_single":
+		addChannelRequest.Channel.ChannelInfo.MultiKeyMode = true
+		if addChannelRequest.Channel.Type == common.ChannelTypeVertexAi {
+			if !common.IsJsonStr(addChannelRequest.Channel.Key) {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
+				})
+				return
+			}
 		}
-		localChannel := channel
-		localChannel.Key = key
-		// Validate the length of the model name
-		models := strings.Split(localChannel.Models, ",")
-		for _, model := range models {
-			if len(model) > 255 {
+		keys = []string{addChannelRequest.Channel.Key}
+	case "batch":
+		if addChannelRequest.Channel.Type == common.ChannelTypeVertexAi {
+			// multi json
+			if !common.IsJsonStr(addChannelRequest.Channel.Key) {
 				c.JSON(http.StatusOK, gin.H{
 					"success": false,
-					"message": fmt.Sprintf("模型名称过长: %s", model),
+					"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
 				})
 				return
 			}
+			toMap := common.StrToMap(addChannelRequest.Channel.Key)
+			if toMap == nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
+				})
+				return
+			}
+			keys = make([]string, 0, len(toMap))
+			for k := range toMap {
+				if k == "" {
+					continue
+				}
+				keys = append(keys, k)
+			}
+		} else {
+			keys = strings.Split(addChannelRequest.Channel.Key, "\n")
 		}
-		channels = append(channels, localChannel)
+	case "single":
+		keys = []string{addChannelRequest.Channel.Key}
+	default:
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "不支持的添加模式",
+		})
+		return
+	}
+
+	channels := make([]model.Channel, 0, len(keys))
+	for _, key := range keys {
+		if key == "" {
+			continue
+		}
+		localChannel := addChannelRequest.Channel
+		localChannel.Key = key
+		channels = append(channels, *localChannel)
 	}
 	err = model.BatchInsertChannels(channels)
 	if err != nil {

+ 8 - 1
model/channel.go

@@ -9,6 +9,11 @@ import (
 	"gorm.io/gorm"
 )
 
+type ChannelInfo struct {
+	MultiKeyMode       bool        `json:"multi_key_mode"`        // 是否多Key模式
+	MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
+}
+
 type Channel struct {
 	Id                 int     `json:"id"`
 	Type               int     `json:"type" gorm:"default:0"`
@@ -35,8 +40,10 @@ type Channel struct {
 	AutoBan           *int    `json:"auto_ban" gorm:"default:1"`
 	OtherInfo         string  `json:"other_info"`
 	Tag               *string `json:"tag" gorm:"index"`
-	Setting           *string `json:"setting" gorm:"type:text"`
+	Setting           *string `json:"setting" gorm:"type:text"` // 渠道额外设置
 	ParamOverride     *string `json:"param_override" gorm:"type:text"`
+	// add after v0.8.5
+	ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
 }
 
 func (channel *Channel) GetModels() []string {

+ 1 - 1
model/main.go

@@ -48,7 +48,7 @@ func initCol() {
 		}
 	}
 	// log sql type and database type
-	common.SysLog("Using Log SQL Type: " + common.LogSqlType)
+	//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
 }
 
 var DB *gorm.DB