Selaa lähdekoodia

✨ feat(channel): implement multi-key mode handling and improve channel update logic

CaIon 7 kuukautta sitten
vanhempi
commit
85efea3fb8

+ 19 - 1
controller/channel.go

@@ -718,8 +718,13 @@ func DeleteChannelBatch(c *gin.Context) {
 	return
 }
 
+type PatchChannel struct {
+	model.Channel
+	MultiKeyMode *string `json:"multi_key_mode"`
+}
+
 func UpdateChannel(c *gin.Context) {
-	channel := model.Channel{}
+	channel := PatchChannel{}
 	err := c.ShouldBindJSON(&channel)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
@@ -761,6 +766,19 @@ func UpdateChannel(c *gin.Context) {
 			}
 		}
 	}
+	if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
+		originChannel, err := model.GetChannelById(channel.Id, false)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+		}
+		if originChannel.ChannelInfo.IsMultiKey {
+			channel.ChannelInfo = originChannel.ChannelInfo
+			channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
+		}
+	}
 	err = channel.Update()
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{

+ 19 - 4
model/channel.go

@@ -117,7 +117,15 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
 		// Randomly pick one enabled key
 		return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
 	case constant.MultiKeyModePolling:
+		defer func() {
+			if !common.MemoryCacheEnabled {
+				_ = channel.Save()
+			} else {
+				CacheUpdateChannel(channel)
+			}
+		}()
 		// Start from the saved polling index and look for the next enabled key
+		println(channel.ChannelInfo.MultiKeyPollingIndex)
 		start := channel.ChannelInfo.MultiKeyPollingIndex
 		if start < 0 || start >= len(keys) {
 			start = 0
@@ -127,6 +135,7 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
 			if getStatus(idx) == common.ChannelStatusEnabled {
 				// update polling index for next call (point to the next position)
 				channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
+				println(channel.ChannelInfo.MultiKeyPollingIndex)
 				return keys[idx], nil
 			}
 		}
@@ -273,14 +282,20 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
 }
 
 func GetChannelById(id int, selectAll bool) (*Channel, error) {
-	channel := Channel{Id: id}
+	channel := &Channel{Id: id}
 	var err error = nil
 	if selectAll {
-		err = DB.First(&channel, "id = ?", id).Error
+		err = DB.First(channel, "id = ?", id).Error
 	} else {
-		err = DB.Omit("key").First(&channel, "id = ?", id).Error
+		err = DB.Omit("key").First(channel, "id = ?", id).Error
+	}
+	if err != nil {
+		return nil, err
+	}
+	if channel == nil {
+		return nil, errors.New("channel not found")
 	}
-	return &channel, err
+	return channel, nil
 }
 
 func BatchInsertChannels(channels []Channel) error {

+ 22 - 12
model/cache.go → model/channel_cache.go

@@ -14,7 +14,7 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-var group2model2channels map[string]map[string][]*Channel
+var group2model2channels map[string]map[string][]int
 var channelsIDM map[int]*Channel
 var channelSyncLock sync.RWMutex
 
@@ -34,10 +34,10 @@ func InitChannelCache() {
 	for _, ability := range abilities {
 		groups[ability.Group] = true
 	}
-	newGroup2model2channels := make(map[string]map[string][]*Channel)
+	newGroup2model2channels := make(map[string]map[string][]int)
 	newChannelsIDM := make(map[int]*Channel)
 	for group := range groups {
-		newGroup2model2channels[group] = make(map[string][]*Channel)
+		newGroup2model2channels[group] = make(map[string][]int)
 	}
 	for _, channel := range channels {
 		newChannelsIDM[channel.Id] = channel
@@ -46,9 +46,9 @@ func InitChannelCache() {
 			models := strings.Split(channel.Models, ",")
 			for _, model := range models {
 				if _, ok := newGroup2model2channels[group][model]; !ok {
-					newGroup2model2channels[group][model] = make([]*Channel, 0)
+					newGroup2model2channels[group][model] = make([]int, 0)
 				}
-				newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
+				newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
 			}
 		}
 	}
@@ -57,7 +57,7 @@ func InitChannelCache() {
 	for group, model2channels := range newGroup2model2channels {
 		for model, channels := range model2channels {
 			sort.Slice(channels, func(i, j int) bool {
-				return channels[i].GetPriority() > channels[j].GetPriority()
+				return newChannelsIDM[channels[i]].GetPriority() > newChannelsIDM[channels[j]].GetPriority()
 			})
 			newGroup2model2channels[group][model] = channels
 		}
@@ -136,8 +136,12 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
 	}
 
 	uniquePriorities := make(map[int]bool)
-	for _, channel := range channels {
-		uniquePriorities[int(channel.GetPriority())] = true
+	for _, channelId := range channels {
+		if channel, ok := channelsIDM[channelId]; ok {
+			uniquePriorities[int(channel.GetPriority())] = true
+		} else {
+			return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
+		}
 	}
 	var sortedUniquePriorities []int
 	for priority := range uniquePriorities {
@@ -152,9 +156,13 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
 
 	// get the priority for the given retry number
 	var targetChannels []*Channel
-	for _, channel := range channels {
-		if channel.GetPriority() == targetPriority {
-			targetChannels = append(targetChannels, channel)
+	for _, channelId := range channels {
+		if channel, ok := channelsIDM[channelId]; ok {
+			if channel.GetPriority() == targetPriority {
+				targetChannels = append(targetChannels, channel)
+			}
+		} else {
+			return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
 		}
 	}
 
@@ -210,9 +218,11 @@ func CacheUpdateChannel(channel *Channel) {
 	}
 	channelSyncLock.Lock()
 	defer channelSyncLock.Unlock()
-
 	if channel == nil {
 		return
 	}
+
+	println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
+
 	channelsIDM[channel.Id] = channel
 }

+ 83 - 24
web/src/components/table/ChannelsTable.js

@@ -42,19 +42,20 @@ import {
   IconTreeTriangleDown,
   IconSearch,
   IconMore,
-  IconList
+  IconList, IconDescend2
 } from '@douyinfe/semi-icons';
 import { loadChannelModels, isMobile, copy } from '../../helpers';
 import EditTagModal from '../../pages/Channel/EditTagModal.js';
 import { useTranslation } from 'react-i18next';
 import { useTableCompactMode } from '../../hooks/useTableCompactMode';
+import { FaRandom } from 'react-icons/fa';
 
 const ChannelsTable = () => {
   const { t } = useTranslation();
 
   let type2label = undefined;
 
-  const renderType = (type, multiKey = false) => {
+  const renderType = (type, channelInfo = undefined) => {
     if (!type2label) {
       type2label = new Map();
       for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
@@ -65,13 +66,20 @@ const ChannelsTable = () => {
     
     let icon = getChannelIcon(type);
     
-    if (multiKey) {
+    if (channelInfo?.is_multi_key) {
       icon = (
-        <div className="flex items-center gap-1">
-          <IconList className="text-blue-500" />
-          {icon}
-        </div>
-      );
+        channelInfo?.multi_key_mode === 'random' ? (
+          <div className="flex items-center gap-1">
+            <FaRandom className="text-blue-500" />
+            {icon}
+          </div>
+        ) : (
+          <div className="flex items-center gap-1">
+            <IconDescend2 className="text-blue-500" />
+            {icon}
+          </div>
+        )
+      )
     }
     
     return (
@@ -587,24 +595,70 @@ const ChannelsTable = () => {
                 />
               </SplitButtonGroup>
 
-              {record.status === 1 ? (
-                <Button
-                  theme='light'
-                  type='warning'
-                  size="small"
-                  onClick={() => manageChannel(record.id, 'disable', record)}
+              {record.channel_info?.is_multi_key ? (
+                <SplitButtonGroup
+                  aria-label={t('多密钥渠道操作项目组')}
                 >
-                  {t('禁用')}
-                </Button>
+                  {
+                    record.status === 1 ? (
+                      <Button
+                        theme='light'
+                        type='warning'
+                        size="small"
+                        onClick={() => manageChannel(record.id, 'disable', record)}
+                      >
+                        {t('禁用')}
+                      </Button>
+                    ) : (
+                      <Button
+                        theme='light'
+                        type='secondary'
+                        size="small"
+                        onClick={() => manageChannel(record.id, 'enable', record)}
+                      >
+                        {t('启用')}
+                      </Button>
+                    )
+                  }
+                  <Dropdown
+                    trigger='click'
+                    position='bottomRight'
+                    menu={[
+                      {
+                        node: 'item',
+                        name: t('启用全部密钥'),
+                        onClick: () => manageChannel(record.id, 'enable_all', record),
+                      }
+                    ]}
+                  >
+                    <Button
+                      theme='light'
+                      type='secondary'
+                      size="small"
+                      icon={<IconTreeTriangleDown />}
+                    />
+                  </Dropdown>
+                </SplitButtonGroup>
               ) : (
-                <Button
-                  theme='light'
-                  type='secondary'
-                  size="small"
-                  onClick={() => manageChannel(record.id, 'enable', record)}
-                >
-                  {t('启用')}
-                </Button>
+                record.status === 1 ? (
+                  <Button
+                    theme='light'
+                    type='warning'
+                    size="small"
+                    onClick={() => manageChannel(record.id, 'disable', record)}
+                  >
+                    {t('禁用')}
+                  </Button>
+                ) : (
+                  <Button
+                    theme='light'
+                    type='secondary'
+                    size="small"
+                    onClick={() => manageChannel(record.id, 'enable', record)}
+                  >
+                    {t('启用')}
+                  </Button>
+                )
               )}
 
               <Button
@@ -1014,6 +1068,11 @@ const ChannelsTable = () => {
         }
         res = await API.put('/api/channel/', data);
         break;
+      case 'enable_all':
+        data.channel_info = record.channel_info;
+        data.channel_info.multi_key_status_list = {};
+        res = await API.put('/api/channel/', data);
+        break;
     }
     const { success, message } = res.data;
     if (success) {

+ 49 - 45
web/src/pages/Channel/EditChannel.js

@@ -435,7 +435,7 @@ const EditChannel = (props) => {
     const formValues = formApiRef.current ? formApiRef.current.getValues() : {};
     let localInputs = { ...formValues };
 
-    if (localInputs.type === 41 && batch) {
+    if (localInputs.type === 41) {
       let keys = vertexKeys;
       if (keys.length === 0) {
         // 确保提交时也能解析,避免因异步延迟导致 keys 为空
@@ -460,7 +460,11 @@ const EditChannel = (props) => {
         return;
       }
 
-      localInputs.key = JSON.stringify(keys);
+      if (batch) {
+        localInputs.key = JSON.stringify(keys);
+      } else {
+        localInputs.key = JSON.stringify(keys[0]);
+      }
     }
     delete localInputs.vertex_files;
 
@@ -561,7 +565,7 @@ const EditChannel = (props) => {
   const batchAllowed = !isEdit || isMultiKeyChannel;
   const batchExtra = batchAllowed ? (
     <Space>
-      <Checkbox checked={batch} onChange={() => {
+      <Checkbox disabled={isEdit} checked={batch} onChange={() => {
         setBatch(!batch);
         if (batch) {
           setMultiToSingle(false);
@@ -569,7 +573,7 @@ const EditChannel = (props) => {
         }
       }}>{t('批量创建')}</Checkbox>
       {batch && (
-        <Checkbox checked={multiToSingle} onChange={() => {
+        <Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {
           setMultiToSingle(prev => !prev);
           setInputs(prev => {
             const newInputs = { ...prev };
@@ -702,35 +706,26 @@ const EditChannel = (props) => {
                   ) : (
                     <>
                       {inputs.type === 41 ? (
-                        <Form.TextArea
-                          field='key'
-                          label={t('密钥')}
-                          placeholder={
-                            '{\n' +
-                            '  "type": "service_account",\n' +
-                            '  "project_id": "abc-bcd-123-456",\n' +
-                            '  "private_key_id": "123xxxxx456",\n' +
-                            '  "private_key": "-----BEGIN PRIVATE KEY-----xxxx\n' +
-                            '  "client_email": "xxx@developer.gserviceaccount.com",\n' +
-                            '  "client_id": "111222333",\n' +
-                            '  "auth_uri": "https://accounts.google.com/o/oauth2/auth",\n' +
-                            '  "token_uri": "https://oauth2.googleapis.com/token",\n' +
-                            '  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",\n' +
-                            '  "client_x509_cert_url": "https://xxxxx.gserviceaccount.com",\n' +
-                            '  "universe_domain": "googleapis.com"\n' +
-                            '}'
-                          }
-                          rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
-                          autosize
-                          autoComplete='new-password'
-                          onChange={(value) => handleInputChange('key', value)}
-                          extraText={batchExtra}
-                          showClear
-                        />
+                        <Form.Upload
+                        field='vertex_files'
+                        label={t('密钥文件 (.json)')}
+                        accept='.json'
+                        draggable
+                        dragIcon={<IconBolt />}
+                        dragMainText={t('点击上传文件或拖拽文件到这里')}
+                        dragSubText={t('仅支持 JSON 文件')}
+                        style={{ marginTop: 10 }}
+                        uploadTrigger='custom'
+                        beforeUpload={() => false}
+                        onChange={handleVertexUploadChange}
+                        fileList={vertexFileList}
+                        rules={isEdit ? [] : [{ required: true, message: t('请上传密钥文件') }]}
+                        extraText={batchExtra}
+                      />
                       ) : (
                         <Form.Input
                           field='key'
-                          label={t('密钥')}
+                          label={isEdit ? t('密钥(编辑模式下,保存的密钥不会显示)') : t('密钥')}
                           placeholder={t(type2secretPrompt(inputs.type))}
                           rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
                           autoComplete='new-password'
@@ -743,21 +738,30 @@ const EditChannel = (props) => {
                   )}
 
                   {batch && multiToSingle && (
-                    <Form.Select
-                      field='multi_key_mode'
-                      label={t('密钥聚合模式')}
-                      placeholder={t('请选择多密钥使用策略')}
-                      optionList={[
-                        { label: t('随机'), value: 'random' },
-                        { label: t('轮询'), value: 'polling' },
-                      ]}
-                      style={{ width: '100%' }}
-                      value={inputs.multi_key_mode || 'random'}
-                      onChange={(value) => {
-                        setMultiKeyMode(value);
-                        handleInputChange('multi_key_mode', value);
-                      }}
-                    />
+                    <>
+                      <Form.Select
+                        field='multi_key_mode'
+                        label={t('密钥聚合模式')}
+                        placeholder={t('请选择多密钥使用策略')}
+                        optionList={[
+                          { label: t('随机'), value: 'random' },
+                          { label: t('轮询'), value: 'polling' },
+                        ]}
+                        style={{ width: '100%' }}
+                        value={inputs.multi_key_mode || 'random'}
+                        onChange={(value) => {
+                          setMultiKeyMode(value);
+                          handleInputChange('multi_key_mode', value);
+                        }}
+                      />
+                      {inputs.multi_key_mode === 'polling' && (
+                        <Banner
+                          type='warning'
+                          description={t('轮询模式必须搭配Redis和内存缓存功能使用,否则性能将大幅降低,并且无法实现轮询功能')}
+                          className='!rounded-lg mt-2'
+                        />
+                      )}
+                    </>
                   )}
 
                   {inputs.type === 18 && (