فهرست منبع

feat: support xunfei's v2 api (#442, close #440)

* 兼容讯飞v2接口

* Revert "兼容讯飞v2接口"

This reverts commit 21f05d1294b8693d0a21664a23ec04f028b9b117.

* fix: fix implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
滔哥 2 سال پیش
والد
کامیت
7e058bfb9b
4فایلهای تغییر یافته به همراه37 افزوده شده و 5 حذف شده
  1. 17 4
      controller/relay-xunfei.go
  2. 2 0
      i18n/en.json
  3. 1 1
      middleware/distributor.go
  4. 17 0
      web/src/pages/Channel/EditChannel.js

+ 17 - 4
controller/relay-xunfei.go

@@ -75,7 +75,7 @@ type XunfeiChatResponse struct {
 	} `json:"payload"`
 }
 
-func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
+func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
 	messages := make([]XunfeiMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
 		if message.Role == "system" {
@@ -96,7 +96,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *Xun
 	}
 	xunfeiRequest := XunfeiChatRequest{}
 	xunfeiRequest.Header.AppId = xunfeiAppId
-	xunfeiRequest.Parameter.Chat.Domain = "general"
+	xunfeiRequest.Parameter.Chat.Domain = domain
 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 	xunfeiRequest.Parameter.Chat.TopK = request.N
 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
@@ -178,15 +178,28 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 
 func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 	var usage Usage
+	query := c.Request.URL.Query()
+	apiVersion := query.Get("api-version")
+	if apiVersion == "" {
+		apiVersion = c.GetString("api_version")
+	}
+	if apiVersion == "" {
+		apiVersion = "v1.1"
+		common.SysLog("api_version not found, use default: " + apiVersion)
+	}
+	domain := "general"
+	if apiVersion == "v2.1" {
+		domain = "generalv2"
+	}
+	hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
 	d := websocket.Dialer{
 		HandshakeTimeout: 5 * time.Second,
 	}
-	hostUrl := "wss://aichat.xf-yun.com/v1/chat"
 	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
 	if err != nil || resp.StatusCode != 101 {
 		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
 	}
-	data := requestOpenAI2Xunfei(textRequest, appId)
+	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 	err = conn.WriteJSON(data)
 	if err != nil {
 		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil

+ 2 - 0
i18n/en.json

@@ -521,5 +521,7 @@
   "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
   "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
   "按照如下格式输入:": "Enter in the following format:",
+  "模型版本": "Model version",
+  "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
   "点击查看": "click to view"
 }

+ 1 - 1
middleware/distributor.go

@@ -107,7 +107,7 @@ func Distribute() func(c *gin.Context) {
 		c.Set("model_mapping", channel.ModelMapping)
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		c.Set("base_url", channel.BaseURL)
-		if channel.Type == common.ChannelTypeAzure {
+		if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
 			c.Set("api_version", channel.Other)
 		}
 		c.Next()

+ 17 - 0
web/src/pages/Channel/EditChannel.js

@@ -163,6 +163,9 @@ const EditChannel = () => {
     if (localInputs.type === 3 && localInputs.other === '') {
       localInputs.other = '2023-06-01-preview';
     }
+    if (localInputs.type === 18 && localInputs.other === '') {
+      localInputs.other = 'v2.1';
+    }
     if (localInputs.model_mapping === '') {
       localInputs.model_mapping = '{}';
     }
@@ -275,6 +278,20 @@ const EditChannel = () => {
               options={groupOptions}
             />
           </Form.Field>
+          {
+            inputs.type === 18 && (
+              <Form.Field>
+                <Form.Input
+                  label='模型版本'
+                  name='other'
+                  placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'}
+                  onChange={handleInputChange}
+                  value={inputs.other}
+                  autoComplete='new-password'
+                />
+              </Form.Field>
+            )
+          }
           <Form.Field>
             <Form.Dropdown
               label='模型'