|
@@ -24,8 +24,9 @@ import (
|
|
|
|
|
|
|
|
func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
|
|
func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
|
|
|
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
|
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
|
|
|
|
+ shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
|
|
|
for _, message := range request.Messages {
|
|
for _, message := range request.Messages {
|
|
|
- if message.Role == "system" {
|
|
|
|
|
|
|
+ if message.Role == "system" && shouldCovertSystemMessage {
|
|
|
messages = append(messages, XunfeiMessage{
|
|
messages = append(messages, XunfeiMessage{
|
|
|
Role: "user",
|
|
Role: "user",
|
|
|
Content: message.StringContent(),
|
|
Content: message.StringContent(),
|
|
@@ -126,7 +127,7 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
- domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
|
|
|
|
|
|
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
|
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
@@ -156,7 +157,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
- domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
|
|
|
|
|
|
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
|
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
@@ -235,20 +236,44 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
|
|
return dataChan, stopChan, nil
|
|
return dataChan, stopChan, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
|
|
|
|
|
|
+func apiVersion2domain(apiVersion string) string {
|
|
|
|
|
+ switch apiVersion {
|
|
|
|
|
+ case "v1.1":
|
|
|
|
|
+ return "general"
|
|
|
|
|
+ case "v2.1":
|
|
|
|
|
+ return "generalv2"
|
|
|
|
|
+ case "v3.1":
|
|
|
|
|
+ return "generalv3"
|
|
|
|
|
+ case "v3.5":
|
|
|
|
|
+ return "generalv3.5"
|
|
|
|
|
+ }
|
|
|
|
|
+ return "general" + apiVersion
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
|
|
|
|
|
+ apiVersion := getAPIVersion(c, modelName)
|
|
|
|
|
+ domain := apiVersion2domain(apiVersion)
|
|
|
|
|
+ authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
|
|
|
|
+ return domain, authUrl
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func getAPIVersion(c *gin.Context, modelName string) string {
|
|
|
query := c.Request.URL.Query()
|
|
query := c.Request.URL.Query()
|
|
|
apiVersion := query.Get("api-version")
|
|
apiVersion := query.Get("api-version")
|
|
|
- if apiVersion == "" {
|
|
|
|
|
- apiVersion = c.GetString("api_version")
|
|
|
|
|
|
|
+ if apiVersion != "" {
|
|
|
|
|
+ return apiVersion
|
|
|
}
|
|
}
|
|
|
- if apiVersion == "" {
|
|
|
|
|
- apiVersion = "v1.1"
|
|
|
|
|
- common.SysLog("api_version not found, use default: " + apiVersion)
|
|
|
|
|
|
|
+ parts := strings.Split(modelName, "-")
|
|
|
|
|
+ if len(parts) == 2 {
|
|
|
|
|
+ apiVersion = parts[1]
|
|
|
|
|
+ return apiVersion
|
|
|
|
|
+
|
|
|
}
|
|
}
|
|
|
- domain := "general"
|
|
|
|
|
- if apiVersion != "v1.1" {
|
|
|
|
|
- domain += strings.Split(apiVersion, ".")[0]
|
|
|
|
|
|
|
+ apiVersion = c.GetString("api_version")
|
|
|
|
|
+ if apiVersion != "" {
|
|
|
|
|
+ return apiVersion
|
|
|
}
|
|
}
|
|
|
- authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
|
|
|
|
- return domain, authUrl
|
|
|
|
|
|
|
+ apiVersion = "v1.1"
|
|
|
|
|
+ common.SysLog("api_version not found, using default: " + apiVersion)
|
|
|
|
|
+ return apiVersion
|
|
|
}
|
|
}
|