Просмотр исходного кода

feat: 支持部分渠道的system角色 (close #89)

1808837298@qq.com 2 лет назад
Родитель
Сommit
3ab4f145db

+ 7 - 6
relay/channel/ali/dto.go

@@ -1,8 +1,8 @@
 package ali
 
 type AliMessage struct {
-	User string `json:"user"`
-	Bot  string `json:"bot"`
+	Content string `json:"content"`
+	Role    string `json:"role"`
 }
 
 type AliInput struct {
@@ -11,10 +11,11 @@ type AliInput struct {
 }
 
 type AliParameters struct {
-	TopP         float64 `json:"top_p,omitempty"`
-	TopK         int     `json:"top_k,omitempty"`
-	Seed         uint64  `json:"seed,omitempty"`
-	EnableSearch bool    `json:"enable_search,omitempty"`
+	TopP              float64 `json:"top_p,omitempty"`
+	TopK              int     `json:"top_k,omitempty"`
+	Seed              uint64  `json:"seed,omitempty"`
+	EnableSearch      bool    `json:"enable_search,omitempty"`
+	IncrementalOutput bool    `json:"incremental_output,omitempty"`
 }
 
 type AliChatRequest struct {

+ 17 - 23
relay/channel/ali/relay-ali.go

@@ -14,28 +14,23 @@ import (
 
 // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 
+const EnableSearchModelSuffix = "-internet"
+
 func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
 	messages := make([]AliMessage, 0, len(request.Messages))
 	prompt := ""
 	for i := 0; i < len(request.Messages); i++ {
 		message := request.Messages[i]
-		if message.Role == "system" {
-			messages = append(messages, AliMessage{
-				User: message.StringContent(),
-				Bot:  "Okay",
-			})
-			continue
-		} else {
-			if i == len(request.Messages)-1 {
-				prompt = message.StringContent()
-				break
-			}
-			messages = append(messages, AliMessage{
-				User: message.StringContent(),
-				Bot:  string(request.Messages[i+1].Content),
-			})
-			i++
-		}
+		messages = append(messages, AliMessage{
+			Content: message.StringContent(),
+			Role:    strings.ToLower(message.Role),
+		})
+	}
+	enableSearch := false
+	aliModel := request.Model
+	if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
+		enableSearch = true
+		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
 	}
 	return &AliChatRequest{
 		Model: request.Model,
@@ -43,12 +38,11 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
 			Prompt:  prompt,
 			History: messages,
 		},
-		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's
-		//	TopP: request.TopP,
-		//	TopK: 50,
-		//	//Seed:         0,
-		//	//EnableSearch: false,
-		//},
+		Parameters: AliParameters{
+			IncrementalOutput: request.Stream,
+			Seed:              uint64(request.Seed),
+			EnableSearch:      enableSearch,
+		},
 	}
 }
 

+ 4 - 15
relay/channel/baidu/relay-baidu.go

@@ -24,21 +24,10 @@ var baiduTokenStore sync.Map
 func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 	messages := make([]BaiduMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
-		if message.Role == "system" {
-			messages = append(messages, BaiduMessage{
-				Role:    "user",
-				Content: message.StringContent(),
-			})
-			messages = append(messages, BaiduMessage{
-				Role:    "assistant",
-				Content: "Okay",
-			})
-		} else {
-			messages = append(messages, BaiduMessage{
-				Role:    message.Role,
-				Content: message.StringContent(),
-			})
-		}
+		messages = append(messages, BaiduMessage{
+			Role:    message.Role,
+			Content: message.StringContent(),
+		})
 	}
 	return &BaiduChatRequest{
 		Messages: messages,

+ 4 - 4
relay/channel/openai/adaptor.go

@@ -50,10 +50,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 		return nil
 	}
 	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
-	if info.ChannelType == common.ChannelTypeOpenRouter {
-		req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
-		req.Header.Set("X-Title", "One API")
-	}
+	//if info.ChannelType == common.ChannelTypeOpenRouter {
+	//	req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+	//	req.Header.Set("X-Title", "One API")
+	//}
 	return nil
 }
 

+ 39 - 14
relay/channel/xunfei/relay-xunfei.go

@@ -24,8 +24,9 @@ import (
 
 func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
 	messages := make([]XunfeiMessage, 0, len(request.Messages))
+	shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
 	for _, message := range request.Messages {
-		if message.Role == "system" {
+		if message.Role == "system" && shouldCovertSystemMessage {
 			messages = append(messages, XunfeiMessage{
 				Role:    "user",
 				Content: message.StringContent(),
@@ -126,7 +127,7 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 }
 
 func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
@@ -156,7 +157,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
 }
 
 func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
@@ -235,20 +236,44 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
 	return dataChan, stopChan, nil
 }
 
-func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
+func apiVersion2domain(apiVersion string) string {
+	switch apiVersion {
+	case "v1.1":
+		return "general"
+	case "v2.1":
+		return "generalv2"
+	case "v3.1":
+		return "generalv3"
+	case "v3.5":
+		return "generalv3.5"
+	}
+	return "general" + apiVersion
+}
+
+func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
+	apiVersion := getAPIVersion(c, modelName)
+	domain := apiVersion2domain(apiVersion)
+	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
+	return domain, authUrl
+}
+
+func getAPIVersion(c *gin.Context, modelName string) string {
 	query := c.Request.URL.Query()
 	apiVersion := query.Get("api-version")
-	if apiVersion == "" {
-		apiVersion = c.GetString("api_version")
+	if apiVersion != "" {
+		return apiVersion
 	}
-	if apiVersion == "" {
-		apiVersion = "v1.1"
-		common.SysLog("api_version not found, use default: " + apiVersion)
+	parts := strings.Split(modelName, "-")
+	if len(parts) == 2 {
+		apiVersion = parts[1]
+		return apiVersion
+
 	}
-	domain := "general"
-	if apiVersion != "v1.1" {
-		domain += strings.Split(apiVersion, ".")[0]
+	apiVersion = c.GetString("api_version")
+	if apiVersion != "" {
+		return apiVersion
 	}
-	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
-	return domain, authUrl
+	apiVersion = "v1.1"
+	common.SysLog("api_version not found, using default: " + apiVersion)
+	return apiVersion
 }

+ 6 - 3
web/src/pages/Channel/EditChannel.js

@@ -72,13 +72,13 @@ const EditChannel = (props) => {
                     localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
                     break;
                 case 17:
-                    localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
+                    localModels = ["qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", 'text-embedding-v1'];
                     break;
                 case 16:
                     localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
                     break;
                 case 18:
-                    localModels = ['SparkDesk'];
+                    localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'];
                     break;
                 case 19:
                     localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
@@ -87,7 +87,10 @@ const EditChannel = (props) => {
                     localModels = ['hunyuan'];
                     break;
                 case 24:
-                    localModels = ['gemini-pro'];
+                    localModels = ['gemini-pro', 'gemini-pro-vision'];
+                    break;
+                case 25:
+                    localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'];
                     break;
                 case 26:
                     localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];