Просмотр исходного кода

feat: support specific default api version now (#57)

JustSong 2 лет назад
Родитель
Сommit
83e86b9f8a
4 измененных файлов с 18 добавлено и 2 удалено
  1. 2 1
      controller/relay.go
  2. 3 0
      middleware/distributor.go
  3. 1 0
      model/channel.go
  4. 12 1
      web/src/pages/Channel/EditChannel.js

+ 2 - 1
controller/relay.go

@@ -95,7 +95,8 @@ func relayHelper(c *gin.Context) error {
 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 		query := c.Request.URL.Query()
 		query := c.Request.URL.Query()
 		if query.Get("api-version") == "" {
 		if query.Get("api-version") == "" {
-			requestURL = fmt.Sprintf("%s?api-version=2023-03-15-preview", requestURL)
+			apiVersion := c.GetString("api_version")
+			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
 		}
 		}
 		baseURL = c.GetString("base_url")
 		baseURL = c.GetString("base_url")
 		task := strings.TrimPrefix(requestURL, "/v1/")
 		task := strings.TrimPrefix(requestURL, "/v1/")

+ 3 - 0
middleware/distributor.go

@@ -65,6 +65,9 @@ func Distribute() func(c *gin.Context) {
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
 		if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
 			c.Set("base_url", channel.BaseURL)
 			c.Set("base_url", channel.BaseURL)
+			if channel.Type == common.ChannelTypeAzure {
+				c.Set("api_version", channel.Other)
+			}
 		}
 		}
 		c.Next()
 		c.Next()
 	}
 	}

+ 1 - 0
model/channel.go

@@ -15,6 +15,7 @@ type Channel struct {
 	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
 	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
 	AccessedTime int64  `json:"accessed_time" gorm:"bigint"`
 	AccessedTime int64  `json:"accessed_time" gorm:"bigint"`
 	BaseURL      string `json:"base_url" gorm:"column:base_url"`
 	BaseURL      string `json:"base_url" gorm:"column:base_url"`
+	Other        string `json:"other"`
 }
 }
 
 
 func GetAllChannels(startIdx int, num int) ([]*Channel, error) {
 func GetAllChannels(startIdx int, num int) ([]*Channel, error) {

+ 12 - 1
web/src/pages/Channel/EditChannel.js

@@ -13,7 +13,8 @@ const EditChannel = () => {
     name: '',
     name: '',
     type: 1,
     type: 1,
     key: '',
     key: '',
-    base_url: ''
+    base_url: '',
+    other: ''
   };
   };
   const [inputs, setInputs] = useState(originInputs);
   const [inputs, setInputs] = useState(originInputs);
   const handleInputChange = (e, { name, value }) => {
   const handleInputChange = (e, { name, value }) => {
@@ -92,6 +93,16 @@ const EditChannel = () => {
                     autoComplete='new-password'
                     autoComplete='new-password'
                   />
                   />
                 </Form.Field>
                 </Form.Field>
+                <Form.Field>
+                  <Form.Input
+                    label='默认 API 版本'
+                    name='other'
+                    placeholder={'请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖'}
+                    onChange={handleInputChange}
+                    value={inputs.other}
+                    autoComplete='new-password'
+                  />
+                </Form.Field>
               </>
               </>
             )
             )
           }
           }