|
@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
|
|
Role: "assistant",
|
|
Role: "assistant",
|
|
|
Content: response.Payload.Choices.Text[0].Content,
|
|
Content: response.Payload.Choices.Text[0].Content,
|
|
|
},
|
|
},
|
|
|
|
|
+ FinishReason: stopFinishReason,
|
|
|
}
|
|
}
|
|
|
fullTextResponse := OpenAITextResponse{
|
|
fullTextResponse := OpenAITextResponse{
|
|
|
Object: "chat.completion",
|
|
Object: "chat.completion",
|
|
@@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
|
|
|
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
|
|
|
|
+ dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
|
|
|
+ }
|
|
|
|
|
+ setEventStreamHeaders(c)
|
|
|
var usage Usage
|
|
var usage Usage
|
|
|
- query := c.Request.URL.Query()
|
|
|
|
|
- apiVersion := query.Get("api-version")
|
|
|
|
|
- if apiVersion == "" {
|
|
|
|
|
- apiVersion = c.GetString("api_version")
|
|
|
|
|
|
|
+ c.Stream(func(w io.Writer) bool {
|
|
|
|
|
+ select {
|
|
|
|
|
+ case xunfeiResponse := <-dataChan:
|
|
|
|
|
+ usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
|
|
|
+ usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
|
|
|
+ usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
|
|
|
+ response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
|
|
|
|
+ jsonResponse, err := json.Marshal(response)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ common.SysError("error marshalling stream response: " + err.Error())
|
|
|
|
|
+ return true
|
|
|
|
|
+ }
|
|
|
|
|
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
|
|
|
+ return true
|
|
|
|
|
+ case <-stopChan:
|
|
|
|
|
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
|
|
+ })
|
|
|
|
|
+ return nil, &usage
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
|
|
|
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
|
|
|
|
+ dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
|
}
|
|
}
|
|
|
- if apiVersion == "" {
|
|
|
|
|
- apiVersion = "v1.1"
|
|
|
|
|
- common.SysLog("api_version not found, use default: " + apiVersion)
|
|
|
|
|
|
|
+ var usage Usage
|
|
|
|
|
+ var content string
|
|
|
|
|
+ var xunfeiResponse XunfeiChatResponse
|
|
|
|
|
+ stop := false
|
|
|
|
|
+ for !stop {
|
|
|
|
|
+ select {
|
|
|
|
|
+ case xunfeiResponse = <-dataChan:
|
|
|
|
|
+ content += xunfeiResponse.Payload.Choices.Text[0].Content
|
|
|
|
|
+ usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
|
|
|
+ usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
|
|
|
+ usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
|
|
|
+ case stop = <-stopChan:
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
- domain := "general"
|
|
|
|
|
- if apiVersion == "v2.1" {
|
|
|
|
|
- domain = "generalv2"
|
|
|
|
|
|
|
+
|
|
|
|
|
+ xunfeiResponse.Payload.Choices.Text[0].Content = content
|
|
|
|
|
+
|
|
|
|
|
+ response := responseXunfei2OpenAI(&xunfeiResponse)
|
|
|
|
|
+ jsonResponse, err := json.Marshal(response)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
}
|
|
}
|
|
|
- hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
|
|
|
|
|
|
|
+ c.Writer.Header().Set("Content-Type", "application/json")
|
|
|
|
|
+ _, _ = c.Writer.Write(jsonResponse)
|
|
|
|
|
+ return nil, &usage
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
|
|
d := websocket.Dialer{
|
|
d := websocket.Dialer{
|
|
|
HandshakeTimeout: 5 * time.Second,
|
|
HandshakeTimeout: 5 * time.Second,
|
|
|
}
|
|
}
|
|
|
- conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
|
|
|
|
|
|
+ conn, resp, err := d.Dial(authUrl, nil)
|
|
|
if err != nil || resp.StatusCode != 101 {
|
|
if err != nil || resp.StatusCode != 101 {
|
|
|
- return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
|
|
|
|
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
}
|
|
|
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
|
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
|
|
err = conn.WriteJSON(data)
|
|
err = conn.WriteJSON(data)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
|
|
|
|
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
dataChan := make(chan XunfeiChatResponse)
|
|
dataChan := make(chan XunfeiChatResponse)
|
|
|
stopChan := make(chan bool)
|
|
stopChan := make(chan bool)
|
|
|
go func() {
|
|
go func() {
|
|
@@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
|
|
|
}
|
|
}
|
|
|
stopChan <- true
|
|
stopChan <- true
|
|
|
}()
|
|
}()
|
|
|
- setEventStreamHeaders(c)
|
|
|
|
|
- c.Stream(func(w io.Writer) bool {
|
|
|
|
|
- select {
|
|
|
|
|
- case xunfeiResponse := <-dataChan:
|
|
|
|
|
- usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
|
|
|
- usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
|
|
|
- usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
|
|
|
- response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
|
|
|
|
- jsonResponse, err := json.Marshal(response)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- common.SysError("error marshalling stream response: " + err.Error())
|
|
|
|
|
- return true
|
|
|
|
|
- }
|
|
|
|
|
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
|
|
|
- return true
|
|
|
|
|
- case <-stopChan:
|
|
|
|
|
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
|
|
|
- return false
|
|
|
|
|
- }
|
|
|
|
|
- })
|
|
|
|
|
- return nil, &usage
|
|
|
|
|
|
|
+
|
|
|
|
|
+ return dataChan, stopChan, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
|
|
|
- var xunfeiResponse XunfeiChatResponse
|
|
|
|
|
- responseBody, err := io.ReadAll(resp.Body)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
|
|
|
- }
|
|
|
|
|
- err = resp.Body.Close()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
|
|
|
- }
|
|
|
|
|
- err = json.Unmarshal(responseBody, &xunfeiResponse)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
|
|
|
|
+func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
|
|
|
|
+ query := c.Request.URL.Query()
|
|
|
|
|
+ apiVersion := query.Get("api-version")
|
|
|
|
|
+ if apiVersion == "" {
|
|
|
|
|
+ apiVersion = c.GetString("api_version")
|
|
|
}
|
|
}
|
|
|
- if xunfeiResponse.Header.Code != 0 {
|
|
|
|
|
- return &OpenAIErrorWithStatusCode{
|
|
|
|
|
- OpenAIError: OpenAIError{
|
|
|
|
|
- Message: xunfeiResponse.Header.Message,
|
|
|
|
|
- Type: "xunfei_error",
|
|
|
|
|
- Param: "",
|
|
|
|
|
- Code: xunfeiResponse.Header.Code,
|
|
|
|
|
- },
|
|
|
|
|
- StatusCode: resp.StatusCode,
|
|
|
|
|
- }, nil
|
|
|
|
|
|
|
+ if apiVersion == "" {
|
|
|
|
|
+ apiVersion = "v1.1"
|
|
|
|
|
+ common.SysLog("api_version not found, use default: " + apiVersion)
|
|
|
}
|
|
}
|
|
|
- fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
|
|
|
|
|
- jsonResponse, err := json.Marshal(fullTextResponse)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
|
|
|
|
+ domain := "general"
|
|
|
|
|
+ if apiVersion == "v2.1" {
|
|
|
|
|
+ domain = "generalv2"
|
|
|
}
|
|
}
|
|
|
- c.Writer.Header().Set("Content-Type", "application/json")
|
|
|
|
|
- c.Writer.WriteHeader(resp.StatusCode)
|
|
|
|
|
- _, err = c.Writer.Write(jsonResponse)
|
|
|
|
|
- return nil, &fullTextResponse.Usage
|
|
|
|
|
|
|
+ authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
|
|
|
|
+ return domain, authUrl
|
|
|
}
|
|
}
|