فهرست منبع

feat: 支持vertex ai渠道多个部署地区

CalciumIon 1 سال پیش
والد
کامیت
e60f200192

+ 5 - 8
common/str.go

@@ -31,14 +31,6 @@ func MapToJsonStr(m map[string]interface{}) string {
 	return string(bytes)
 }
 
-func MapToJsonStrFloat(m map[string]float64) string {
-	bytes, err := json.Marshal(m)
-	if err != nil {
-		return ""
-	}
-	return string(bytes)
-}
-
 func StrToMap(str string) map[string]interface{} {
 	m := make(map[string]interface{})
 	err := json.Unmarshal([]byte(str), &m)
@@ -48,6 +40,11 @@ func StrToMap(str string) map[string]interface{} {
 	return m
 }
 
+func IsJsonStr(str string) bool {
+	var js map[string]interface{}
+	return json.Unmarshal([]byte(str), &js) == nil
+}
+
 func String2Int(str string) int {
 	num, err := strconv.Atoi(str)
 	if err != nil {

+ 19 - 0
controller/channel.go

@@ -199,6 +199,25 @@ func AddChannel(c *gin.Context) {
 	channel.CreatedTime = common.GetTimestamp()
 	keys := strings.Split(channel.Key, "\n")
 	if channel.Type == common.ChannelTypeVertexAi {
+		if channel.Other == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "部署地区不能为空",
+			})
+			return
+		} else {
+			if common.IsJsonStr(channel.Other) {
+				// must have default
+				regionMap := common.StrToMap(channel.Other)
+				if regionMap["default"] == nil {
+					c.JSON(http.StatusOK, gin.H{
+						"success": false,
+						"message": "必须包含default字段",
+					})
+					return
+				}
+			}
+		}
 		keys = []string{channel.Key}
 	}
 	channels := make([]model.Channel, 0, len(keys))

+ 7 - 6
relay/channel/vertex/adaptor.go

@@ -62,6 +62,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
 		return "", fmt.Errorf("failed to decode credentials file: %w", err)
 	}
+	region := GetModelRegion(info.ApiVersion, info.OriginModelName)
 	a.AccountCredentials = *adc
 	suffix := ""
 	if a.RequestMode == RequestModeGemini {
@@ -72,9 +73,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		}
 		return fmt.Sprintf(
 			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
-			info.ApiVersion,
+			region,
 			adc.ProjectID,
-			info.ApiVersion,
+			region,
 			info.UpstreamModelName,
 			suffix,
 		), nil
@@ -89,18 +90,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		}
 		return fmt.Sprintf(
 			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
-			info.ApiVersion,
+			region,
 			adc.ProjectID,
-			info.ApiVersion,
+			region,
 			info.UpstreamModelName,
 			suffix,
 		), nil
 	} else if a.RequestMode == RequestModeLlama {
 		return fmt.Sprintf(
 			"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
-			info.ApiVersion,
+			region,
 			adc.ProjectID,
-			info.ApiVersion,
+			region,
 		), nil
 	}
 	return "", errors.New("unsupported request mode")

+ 16 - 0
relay/channel/vertex/relay-vertex.go

@@ -0,0 +1,16 @@
+package vertex
+
+import "one-api/common"
+
+func GetModelRegion(other string, localModelName string) string {
+	// if other is json string
+	if common.IsJsonStr(other) {
+		m := common.StrToMap(other)
+		if m[localModelName] != nil {
+			return m[localModelName].(string)
+		} else {
+			return m["default"].(string)
+		}
+	}
+	return other
+}

+ 1 - 1
relay/channel/vertex/service_account.go

@@ -83,7 +83,7 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) {
 		"iss":   email,
 		"scope": "https://www.googleapis.com/auth/cloud-platform",
 		"aud":   "https://www.googleapis.com/oauth2/v4/token",
-		"exp":   now.Add(time.Minute * 30).Unix(),
+		"exp":   now.Add(time.Minute * 35).Unix(),
 		"iat":   now.Unix(),
 	}
 

+ 27 - 2
web/src/pages/Channel/EditChannel.js

@@ -37,6 +37,11 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
   400: '500',
 };
 
+const REGION_EXAMPLE = {
+  "default": "us-central1",
+  "claude-3-5-sonnet-20240620": "europe-west1"
+}
+
 const fetchButtonTips = "1. 新建渠道时,请求通过当前浏览器发出;2. 编辑已有渠道,请求通过后端服务器发出"
 
 function type2secretPrompt(type) {
@@ -593,17 +598,37 @@ const EditChannel = (props) => {
               <div style={{ marginTop: 10 }}>
                 <Typography.Text strong>部署地区:</Typography.Text>
               </div>
-              <Input
+              <TextArea
                 name='other'
                 placeholder={
-                  '请输入部署地区,例如:us-central1'
+                  '请输入部署地区,例如:us-central1\n支持使用模型映射格式\n' +
+                  '{\n' +
+                  '    "default": "us-central1",\n' +
+                  '    "claude-3-5-sonnet-20240620": "europe-west1"\n' +
+                  '}'
                 }
+                autosize={{ minRows: 2 }}
                 onChange={(value) => {
                   handleInputChange('other', value);
                 }}
                 value={inputs.other}
                 autoComplete='new-password'
               />
+              <Typography.Text
+                style={{
+                  color: 'rgba(var(--semi-blue-5), 1)',
+                  userSelect: 'none',
+                  cursor: 'pointer',
+                }}
+                onClick={() => {
+                  handleInputChange(
+                    'other',
+                    JSON.stringify(REGION_EXAMPLE, null, 2),
+                  );
+                }}
+              >
+                填入模板
+              </Typography.Text>
             </>
           )}
           {inputs.type === 21 && (