Jelajahi Sumber

feat(channel): add support for Vertex AI key type configuration in settings

CaIon 5 bulan lalu
induk
melakukan
e68eed3d40

+ 4 - 3
controller/channel.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
+	"one-api/dto"
 	"one-api/model"
 	"one-api/model"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -560,7 +561,7 @@ func AddChannel(c *gin.Context) {
 	case "multi_to_single":
 	case "multi_to_single":
 		addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
 		addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
 		addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
 		addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
-		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
 			array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
 			array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
 			if err != nil {
 			if err != nil {
 				c.JSON(http.StatusOK, gin.H{
 				c.JSON(http.StatusOK, gin.H{
@@ -585,7 +586,7 @@ func AddChannel(c *gin.Context) {
 		}
 		}
 		keys = []string{addChannelRequest.Channel.Key}
 		keys = []string{addChannelRequest.Channel.Key}
 	case "batch":
 	case "batch":
-		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
 			// multi json
 			// multi json
 			keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
 			keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
 			if err != nil {
 			if err != nil {
@@ -840,7 +841,7 @@ func UpdateChannel(c *gin.Context) {
 				}
 				}
 
 
 				// 处理 Vertex AI 的特殊情况
 				// 处理 Vertex AI 的特殊情况
-				if channel.Type == constant.ChannelTypeVertexAi {
+				if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
 					// 尝试解析新密钥为JSON数组
 					// 尝试解析新密钥为JSON数组
 					if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
 					if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
 						array, err := getVertexArrayKeys(channel.Key)
 						array, err := getVertexArrayKeys(channel.Key)

+ 9 - 1
dto/channel_settings.go

@@ -9,6 +9,14 @@ type ChannelSettings struct {
 	SystemPromptOverride   bool   `json:"system_prompt_override,omitempty"`
 	SystemPromptOverride   bool   `json:"system_prompt_override,omitempty"`
 }
 }
 
 
+type VertexKeyType string
+
+const (
+	VertexKeyTypeJSON   VertexKeyType = "json"
+	VertexKeyTypeAPIKey VertexKeyType = "api_key"
+)
+
 type ChannelOtherSettings struct {
 type ChannelOtherSettings struct {
-	AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
+	AzureResponsesVersion string        `json:"azure_responses_version,omitempty"`
+	VertexKeyType         VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
 }
 }

+ 2 - 1
model/channel.go

@@ -42,7 +42,6 @@ type Channel struct {
 	Priority          *int64  `json:"priority" gorm:"bigint;default:0"`
 	Priority          *int64  `json:"priority" gorm:"bigint;default:0"`
 	AutoBan           *int    `json:"auto_ban" gorm:"default:1"`
 	AutoBan           *int    `json:"auto_ban" gorm:"default:1"`
 	OtherInfo         string  `json:"other_info"`
 	OtherInfo         string  `json:"other_info"`
-	OtherSettings     string  `json:"settings" gorm:"column:settings"` // 其他设置
 	Tag               *string `json:"tag" gorm:"index"`
 	Tag               *string `json:"tag" gorm:"index"`
 	Setting           *string `json:"setting" gorm:"type:text"` // 渠道额外设置
 	Setting           *string `json:"setting" gorm:"type:text"` // 渠道额外设置
 	ParamOverride     *string `json:"param_override" gorm:"type:text"`
 	ParamOverride     *string `json:"param_override" gorm:"type:text"`
@@ -51,6 +50,8 @@ type Channel struct {
 	// add after v0.8.5
 	// add after v0.8.5
 	ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
 	ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
 
 
+	OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings
+
 	// cache info
 	// cache info
 	Keys []string `json:"-" gorm:"-"`
 	Keys []string `json:"-" gorm:"-"`
 }
 }

+ 65 - 51
relay/channel/vertex/adaptor.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel"
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/claude"
@@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 	}
 	}
 }
 }
 
 
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	adc := &Credentials{}
-	if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
-		return "", fmt.Errorf("failed to decode credentials file: %w", err)
-	}
+func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
 	region := GetModelRegion(info.ApiVersion, info.OriginModelName)
 	region := GetModelRegion(info.ApiVersion, info.OriginModelName)
-	a.AccountCredentials = *adc
+	if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
+		adc := &Credentials{}
+		if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
+			return "", fmt.Errorf("failed to decode credentials file: %w", err)
+		}
+		a.AccountCredentials = *adc
+
+		if a.RequestMode == RequestModeLlama {
+			return fmt.Sprintf(
+				"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
+				region,
+				adc.ProjectID,
+				region,
+			), nil
+		}
+
+		if region == "global" {
+			return fmt.Sprintf(
+				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
+				adc.ProjectID,
+				modelName,
+				suffix,
+			), nil
+		} else {
+			return fmt.Sprintf(
+				"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+				region,
+				adc.ProjectID,
+				region,
+				modelName,
+				suffix,
+			), nil
+		}
+	} else {
+		if region == "global" {
+			return fmt.Sprintf(
+				"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
+				modelName,
+				suffix,
+				info.ApiKey,
+			), nil
+		} else {
+			return fmt.Sprintf(
+				"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
+				region,
+				modelName,
+				suffix,
+				info.ApiKey,
+			), nil
+		}
+	}
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	suffix := ""
 	suffix := ""
 	if a.RequestMode == RequestModeGemini {
 	if a.RequestMode == RequestModeGemini {
-
 		if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
 		if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
 			// 新增逻辑:处理 -thinking-<budget> 格式
 			// 新增逻辑:处理 -thinking-<budget> 格式
 			if strings.Contains(info.UpstreamModelName, "-thinking-") {
 			if strings.Contains(info.UpstreamModelName, "-thinking-") {
@@ -112,23 +161,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 			suffix = "predict"
 			suffix = "predict"
 		}
 		}
 
 
-		if region == "global" {
-			return fmt.Sprintf(
-				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
-				adc.ProjectID,
-				info.UpstreamModelName,
-				suffix,
-			), nil
-		} else {
-			return fmt.Sprintf(
-				"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
-				region,
-				adc.ProjectID,
-				region,
-				info.UpstreamModelName,
-				suffix,
-			), nil
-		}
+		return a.getRequestUrl(info, info.UpstreamModelName, suffix)
 	} else if a.RequestMode == RequestModeClaude {
 	} else if a.RequestMode == RequestModeClaude {
 		if info.IsStream {
 		if info.IsStream {
 			suffix = "streamRawPredict?alt=sse"
 			suffix = "streamRawPredict?alt=sse"
@@ -139,41 +172,22 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
 		if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
 			model = v
 			model = v
 		}
 		}
-		if region == "global" {
-			return fmt.Sprintf(
-				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
-				adc.ProjectID,
-				model,
-				suffix,
-			), nil
-		} else {
-			return fmt.Sprintf(
-				"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
-				region,
-				adc.ProjectID,
-				region,
-				model,
-				suffix,
-			), nil
-		}
+		return a.getRequestUrl(info, model, suffix)
 	} else if a.RequestMode == RequestModeLlama {
 	} else if a.RequestMode == RequestModeLlama {
-		return fmt.Sprintf(
-			"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
-			region,
-			adc.ProjectID,
-			region,
-		), nil
+		return a.getRequestUrl(info, "", "")
 	}
 	}
 	return "", errors.New("unsupported request mode")
 	return "", errors.New("unsupported request mode")
 }
 }
 
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	accessToken, err := getAccessToken(a, info)
-	if err != nil {
-		return err
+	if info.ChannelOtherSettings.VertexKeyType == "json" {
+		accessToken, err := getAccessToken(a, info)
+		if err != nil {
+			return err
+		}
+		req.Set("Authorization", "Bearer "+accessToken)
 	}
 	}
-	req.Set("Authorization", "Bearer "+accessToken)
 	return nil
 	return nil
 }
 }
 
 

+ 86 - 47
web/src/components/table/channels/modals/EditChannelModal.jsx

@@ -142,6 +142,8 @@ const EditChannelModal = (props) => {
     system_prompt: '',
     system_prompt: '',
     system_prompt_override: false,
     system_prompt_override: false,
     settings: '',
     settings: '',
+    // 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
+    vertex_key_type: 'json',
   };
   };
   const [batch, setBatch] = useState(false);
   const [batch, setBatch] = useState(false);
   const [multiToSingle, setMultiToSingle] = useState(false);
   const [multiToSingle, setMultiToSingle] = useState(false);
@@ -409,11 +411,17 @@ const EditChannelModal = (props) => {
           const parsedSettings = JSON.parse(data.settings);
           const parsedSettings = JSON.parse(data.settings);
           data.azure_responses_version =
           data.azure_responses_version =
             parsedSettings.azure_responses_version || '';
             parsedSettings.azure_responses_version || '';
+          // 读取 Vertex 密钥格式
+          data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
         } catch (error) {
         } catch (error) {
           console.error('解析其他设置失败:', error);
           console.error('解析其他设置失败:', error);
           data.azure_responses_version = '';
           data.azure_responses_version = '';
           data.region = '';
           data.region = '';
+          data.vertex_key_type = 'json';
         }
         }
+      } else {
+        // 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
+        data.vertex_key_type = 'json';
       }
       }
 
 
       setInputs(data);
       setInputs(data);
@@ -745,59 +753,56 @@ const EditChannelModal = (props) => {
     let localInputs = { ...formValues };
     let localInputs = { ...formValues };
 
 
     if (localInputs.type === 41) {
     if (localInputs.type === 41) {
-      if (useManualInput) {
-        // 手动输入模式
-        if (localInputs.key && localInputs.key.trim() !== '') {
-          try {
-            // 验证 JSON 格式
-            const parsedKey = JSON.parse(localInputs.key);
-            // 确保是有效的密钥格式
-            localInputs.key = JSON.stringify(parsedKey);
-          } catch (err) {
-            showError(t('密钥格式无效,请输入有效的 JSON 格式密钥'));
-            return;
-          }
-        } else if (!isEdit) {
+      const keyType = localInputs.vertex_key_type || 'json';
+      if (keyType === 'api_key') {
+        // 直接作为普通字符串密钥处理
+        if (!isEdit && (!localInputs.key || localInputs.key.trim() === '')) {
           showInfo(t('请输入密钥!'));
           showInfo(t('请输入密钥!'));
           return;
           return;
         }
         }
       } else {
       } else {
-        // 文件上传模式
-        let keys = vertexKeys;
-
-        // 若当前未选择文件,尝试从已上传文件列表解析(异步读取)
-        if (keys.length === 0 && vertexFileList.length > 0) {
-          try {
-            const parsed = await Promise.all(
-              vertexFileList.map(async (item) => {
-                const fileObj = item.fileInstance;
-                if (!fileObj) return null;
-                const txt = await fileObj.text();
-                return JSON.parse(txt);
-              }),
-            );
-            keys = parsed.filter(Boolean);
-          } catch (err) {
-            showError(t('解析密钥文件失败: {{msg}}', { msg: err.message }));
-            return;
-          }
-        }
-
-        // 创建模式必须上传密钥;编辑模式可选
-        if (keys.length === 0) {
-          if (!isEdit) {
-            showInfo(t('请上传密钥文件!'));
+        // JSON 服务账号密钥
+        if (useManualInput) {
+          if (localInputs.key && localInputs.key.trim() !== '') {
+            try {
+              const parsedKey = JSON.parse(localInputs.key);
+              localInputs.key = JSON.stringify(parsedKey);
+            } catch (err) {
+              showError(t('密钥格式无效,请输入有效的 JSON 格式密钥'));
+              return;
+            }
+          } else if (!isEdit) {
+            showInfo(t('请输入密钥!'));
             return;
             return;
-          } else {
-            // 编辑模式且未上传新密钥,不修改 key
-            delete localInputs.key;
           }
           }
         } else {
         } else {
-          // 有新密钥,则覆盖
-          if (batch) {
-            localInputs.key = JSON.stringify(keys);
+          // 文件上传模式
+          let keys = vertexKeys;
+          if (keys.length === 0 && vertexFileList.length > 0) {
+            try {
+              const parsed = await Promise.all(
+                vertexFileList.map(async (item) => {
+                  const fileObj = item.fileInstance;
+                  if (!fileObj) return null;
+                  const txt = await fileObj.text();
+                  return JSON.parse(txt);
+                }),
+              );
+              keys = parsed.filter(Boolean);
+            } catch (err) {
+              showError(t('解析密钥文件失败: {{msg}}', { msg: err.message }));
+              return;
+            }
+          }
+          if (keys.length === 0) {
+            if (!isEdit) {
+              showInfo(t('请上传密钥文件!'));
+              return;
+            } else {
+              delete localInputs.key;
+            }
           } else {
           } else {
-            localInputs.key = JSON.stringify(keys[0]);
+            localInputs.key = batch ? JSON.stringify(keys) : JSON.stringify(keys[0]);
           }
           }
         }
         }
       }
       }
@@ -853,6 +858,8 @@ const EditChannelModal = (props) => {
     delete localInputs.pass_through_body_enabled;
     delete localInputs.pass_through_body_enabled;
     delete localInputs.system_prompt;
     delete localInputs.system_prompt;
     delete localInputs.system_prompt_override;
     delete localInputs.system_prompt_override;
+    // 顶层的 vertex_key_type 不应发送给后端
+    delete localInputs.vertex_key_type;
 
 
     let res;
     let res;
     localInputs.auto_ban = localInputs.auto_ban ? 1 : 0;
     localInputs.auto_ban = localInputs.auto_ban ? 1 : 0;
@@ -1178,8 +1185,40 @@ const EditChannelModal = (props) => {
                     autoComplete='new-password'
                     autoComplete='new-password'
                   />
                   />
 
 
+                  {inputs.type === 41 && (
+                    <Form.Select
+                      field='vertex_key_type'
+                      label={t('密钥格式')}
+                      placeholder={t('请选择密钥格式')}
+                      optionList={[
+                        { label: 'JSON', value: 'json' },
+                        { label: 'API Key', value: 'api_key' },
+                      ]}
+                      style={{ width: '100%' }}
+                      value={inputs.vertex_key_type || 'json'}
+                      onChange={(value) => {
+                        // 更新设置中的 vertex_key_type
+                        handleChannelOtherSettingsChange('vertex_key_type', value);
+                        // 切换为 api_key 时,关闭批量与手动/文件切换,并清理已选文件
+                        if (value === 'api_key') {
+                          setBatch(false);
+                          setUseManualInput(false);
+                          setVertexKeys([]);
+                          setVertexFileList([]);
+                          if (formApiRef.current) {
+                            formApiRef.current.setValue('vertex_files', []);
+                          }
+                        }
+                      }}
+                      extraText={
+                        inputs.vertex_key_type === 'api_key'
+                          ? t('API Key 模式下不支持批量创建')
+                          : t('JSON 模式支持手动输入或上传服务账号 JSON')
+                      }
+                    />
+                  )}
                   {batch ? (
                   {batch ? (
-                    inputs.type === 41 ? (
+                    inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? (
                       <Form.Upload
                       <Form.Upload
                         field='vertex_files'
                         field='vertex_files'
                         label={t('密钥文件 (.json)')}
                         label={t('密钥文件 (.json)')}
@@ -1243,7 +1282,7 @@ const EditChannelModal = (props) => {
                     )
                     )
                   ) : (
                   ) : (
                     <>
                     <>
-                      {inputs.type === 41 ? (
+                      {inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? (
                         <>
                         <>
                           {!batch && (
                           {!batch && (
                             <div className='flex items-center justify-between mb-3'>
                             <div className='flex items-center justify-between mb-3'>