Przeglądaj źródła

feat: support non-stream mode for xunfei (#498)

* feat:xunfei suport none stream

* fix:join content ignore seq

---------

Co-authored-by: igophper <admin@jialilgu.cn>
igophper 2 lat temu
rodzic
commit
24df3e5f62
2 zmienionych plików z 96 dodań i 81 usunięć
  1. 17 15
      controller/relay-text.go
  2. 79 66
      controller/relay-xunfei.go

+ 17 - 15
controller/relay-text.go

@@ -541,24 +541,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return nil
 		}
 	case APITypeXunfei:
+		auth := c.Request.Header.Get("Authorization")
+		auth = strings.TrimPrefix(auth, "Bearer ")
+		splits := strings.Split(auth, "|")
+		if len(splits) != 3 {
+			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+		}
+		var err *OpenAIErrorWithStatusCode
+		var usage *Usage
 		if isStream {
-			auth := c.Request.Header.Get("Authorization")
-			auth = strings.TrimPrefix(auth, "Bearer ")
-			splits := strings.Split(auth, "|")
-			if len(splits) != 3 {
-				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
-			}
-			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
+			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 		} else {
-			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
+			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
+		}
+		if err != nil {
+			return err
+		}
+		if usage != nil {
+			textResponse.Usage = *usage
 		}
+		return nil
 	case APITypeAIProxyLibrary:
 		if isStream {
 			err, usage := aiProxyLibraryStreamHandler(c, resp)

+ 79 - 66
controller/relay-xunfei.go

@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 			Role:    "assistant",
 			Content: response.Payload.Choices.Text[0].Content,
 		},
+		FinishReason: stopFinishReason,
 	}
 	fullTextResponse := OpenAITextResponse{
 		Object:  "chat.completion",
@@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 }
 
 func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
+	if err != nil {
+		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+	}
+	setEventStreamHeaders(c)
 	var usage Usage
-	query := c.Request.URL.Query()
-	apiVersion := query.Get("api-version")
-	if apiVersion == "" {
-		apiVersion = c.GetString("api_version")
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case xunfeiResponse := <-dataChan:
+			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
+			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
+			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
+			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	return nil, &usage
+}
+
+func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
+	if err != nil {
+		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 	}
-	if apiVersion == "" {
-		apiVersion = "v1.1"
-		common.SysLog("api_version not found, use default: " + apiVersion)
+	var usage Usage
+	var content string
+	var xunfeiResponse XunfeiChatResponse
+	stop := false
+	for !stop {
+		select {
+		case xunfeiResponse = <-dataChan:
+			content += xunfeiResponse.Payload.Choices.Text[0].Content
+			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
+			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
+			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
+		case stop = <-stopChan:
+		}
 	}
-	domain := "general"
-	if apiVersion == "v2.1" {
-		domain = "generalv2"
+
+	xunfeiResponse.Payload.Choices.Text[0].Content = content
+
+	response := responseXunfei2OpenAI(&xunfeiResponse)
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
+	c.Writer.Header().Set("Content-Type", "application/json")
+	_, _ = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}
+
+func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
 	d := websocket.Dialer{
 		HandshakeTimeout: 5 * time.Second,
 	}
-	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
+	conn, resp, err := d.Dial(authUrl, nil)
 	if err != nil || resp.StatusCode != 101 {
-		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
+		return nil, nil, err
 	}
 	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 	err = conn.WriteJSON(data)
 	if err != nil {
-		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
+		return nil, nil, err
 	}
+
 	dataChan := make(chan XunfeiChatResponse)
 	stopChan := make(chan bool)
 	go func() {
@@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case xunfeiResponse := <-dataChan:
-			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
-			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
-			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
-			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
-	return nil, &usage
+
+	return dataChan, stopChan, nil
 }
 
-func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var xunfeiResponse XunfeiChatResponse
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = json.Unmarshal(responseBody, &xunfeiResponse)
-	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
+	query := c.Request.URL.Query()
+	apiVersion := query.Get("api-version")
+	if apiVersion == "" {
+		apiVersion = c.GetString("api_version")
 	}
-	if xunfeiResponse.Header.Code != 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
-				Message: xunfeiResponse.Header.Message,
-				Type:    "xunfei_error",
-				Param:   "",
-				Code:    xunfeiResponse.Header.Code,
-			},
-			StatusCode: resp.StatusCode,
-		}, nil
+	if apiVersion == "" {
+		apiVersion = "v1.1"
+		common.SysLog("api_version not found, use default: " + apiVersion)
 	}
-	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
-	jsonResponse, err := json.Marshal(fullTextResponse)
-	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	domain := "general"
+	if apiVersion == "v2.1" {
+		domain = "generalv2"
 	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
-	return nil, &fullTextResponse.Usage
+	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
+	return domain, authUrl
 }