relay-xunfei.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package xunfei
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha256"
  5. "encoding/base64"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "net/url"
  10. "strings"
  11. "time"
  12. "github.com/QuantumNous/new-api/common"
  13. "github.com/QuantumNous/new-api/constant"
  14. "github.com/QuantumNous/new-api/dto"
  15. "github.com/QuantumNous/new-api/relay/helper"
  16. "github.com/QuantumNous/new-api/types"
  17. "github.com/samber/lo"
  18. "github.com/gin-gonic/gin"
  19. "github.com/gorilla/websocket"
  20. )
  21. // https://console.xfyun.cn/services/cbm
  22. // https://www.xfyun.cn/doc/spark/Web.html
  23. func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
  24. messages := make([]XunfeiMessage, 0, len(request.Messages))
  25. shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
  26. for _, message := range request.Messages {
  27. if message.Role == "system" && shouldCovertSystemMessage {
  28. messages = append(messages, XunfeiMessage{
  29. Role: "user",
  30. Content: message.StringContent(),
  31. })
  32. messages = append(messages, XunfeiMessage{
  33. Role: "assistant",
  34. Content: "Okay",
  35. })
  36. } else {
  37. messages = append(messages, XunfeiMessage{
  38. Role: message.Role,
  39. Content: message.StringContent(),
  40. })
  41. }
  42. }
  43. xunfeiRequest := XunfeiChatRequest{}
  44. xunfeiRequest.Header.AppId = xunfeiAppId
  45. xunfeiRequest.Parameter.Chat.Domain = domain
  46. xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
  47. xunfeiRequest.Parameter.Chat.TopK = lo.FromPtrOr(request.N, 0)
  48. xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
  49. xunfeiRequest.Payload.Message.Text = messages
  50. return &xunfeiRequest
  51. }
  52. func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
  53. if len(response.Payload.Choices.Text) == 0 {
  54. response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
  55. {
  56. Content: "",
  57. },
  58. }
  59. }
  60. choice := dto.OpenAITextResponseChoice{
  61. Index: 0,
  62. Message: dto.Message{
  63. Role: "assistant",
  64. Content: response.Payload.Choices.Text[0].Content,
  65. },
  66. FinishReason: constant.FinishReasonStop,
  67. }
  68. fullTextResponse := dto.OpenAITextResponse{
  69. Object: "chat.completion",
  70. Created: common.GetTimestamp(),
  71. Choices: []dto.OpenAITextResponseChoice{choice},
  72. Usage: response.Payload.Usage.Text,
  73. }
  74. return &fullTextResponse
  75. }
  76. func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
  77. if len(xunfeiResponse.Payload.Choices.Text) == 0 {
  78. xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
  79. {
  80. Content: "",
  81. },
  82. }
  83. }
  84. var choice dto.ChatCompletionsStreamResponseChoice
  85. choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
  86. if xunfeiResponse.Payload.Choices.Status == 2 {
  87. choice.FinishReason = &constant.FinishReasonStop
  88. }
  89. response := dto.ChatCompletionsStreamResponse{
  90. Object: "chat.completion.chunk",
  91. Created: common.GetTimestamp(),
  92. Model: "SparkDesk",
  93. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  94. }
  95. return &response
  96. }
  97. func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
  98. HmacWithShaToBase64 := func(algorithm, data, key string) string {
  99. mac := hmac.New(sha256.New, []byte(key))
  100. mac.Write([]byte(data))
  101. encodeData := mac.Sum(nil)
  102. return base64.StdEncoding.EncodeToString(encodeData)
  103. }
  104. ul, err := url.Parse(hostUrl)
  105. if err != nil {
  106. fmt.Println(err)
  107. }
  108. date := time.Now().UTC().Format(time.RFC1123)
  109. signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
  110. sign := strings.Join(signString, "\n")
  111. sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
  112. authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
  113. "hmac-sha256", "host date request-line", sha)
  114. authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
  115. v := url.Values{}
  116. v.Add("host", ul.Host)
  117. v.Add("date", date)
  118. v.Add("authorization", authorization)
  119. callUrl := hostUrl + "?" + v.Encode()
  120. return callUrl
  121. }
  122. func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
  123. domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
  124. dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
  125. if err != nil {
  126. return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
  127. }
  128. helper.SetEventStreamHeaders(c)
  129. var usage dto.Usage
  130. c.Stream(func(w io.Writer) bool {
  131. select {
  132. case xunfeiResponse := <-dataChan:
  133. usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
  134. usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
  135. usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
  136. response := streamResponseXunfei2OpenAI(&xunfeiResponse)
  137. jsonResponse, err := json.Marshal(response)
  138. if err != nil {
  139. common.SysLog("error marshalling stream response: " + err.Error())
  140. return true
  141. }
  142. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  143. return true
  144. case <-stopChan:
  145. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  146. return false
  147. }
  148. })
  149. return &usage, nil
  150. }
  151. func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
  152. domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
  153. dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
  154. if err != nil {
  155. return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
  156. }
  157. var usage dto.Usage
  158. var content string
  159. var xunfeiResponse XunfeiChatResponse
  160. stop := false
  161. for !stop {
  162. select {
  163. case xunfeiResponse = <-dataChan:
  164. if len(xunfeiResponse.Payload.Choices.Text) == 0 {
  165. continue
  166. }
  167. content += xunfeiResponse.Payload.Choices.Text[0].Content
  168. usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
  169. usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
  170. usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
  171. case stop = <-stopChan:
  172. }
  173. }
  174. if len(xunfeiResponse.Payload.Choices.Text) == 0 {
  175. xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
  176. {
  177. Content: "",
  178. },
  179. }
  180. }
  181. xunfeiResponse.Payload.Choices.Text[0].Content = content
  182. response := responseXunfei2OpenAI(&xunfeiResponse)
  183. jsonResponse, err := json.Marshal(response)
  184. if err != nil {
  185. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  186. }
  187. c.Writer.Header().Set("Content-Type", "application/json")
  188. _, _ = c.Writer.Write(jsonResponse)
  189. return &usage, nil
  190. }
  191. func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
  192. d := websocket.Dialer{
  193. HandshakeTimeout: 5 * time.Second,
  194. }
  195. conn, resp, err := d.Dial(authUrl, nil)
  196. if err != nil || resp.StatusCode != 101 {
  197. return nil, nil, err
  198. }
  199. data := requestOpenAI2Xunfei(textRequest, appId, domain)
  200. err = conn.WriteJSON(data)
  201. if err != nil {
  202. return nil, nil, err
  203. }
  204. dataChan := make(chan XunfeiChatResponse)
  205. stopChan := make(chan bool)
  206. go func() {
  207. defer func() {
  208. conn.Close()
  209. }()
  210. for {
  211. _, msg, err := conn.ReadMessage()
  212. if err != nil {
  213. common.SysLog("error reading stream response: " + err.Error())
  214. break
  215. }
  216. var response XunfeiChatResponse
  217. err = json.Unmarshal(msg, &response)
  218. if err != nil {
  219. common.SysLog("error unmarshalling stream response: " + err.Error())
  220. break
  221. }
  222. dataChan <- response
  223. if response.Payload.Choices.Status == 2 {
  224. if err != nil {
  225. common.SysLog("error closing websocket connection: " + err.Error())
  226. }
  227. break
  228. }
  229. }
  230. stopChan <- true
  231. }()
  232. return dataChan, stopChan, nil
  233. }
  234. func apiVersion2domain(apiVersion string) string {
  235. switch apiVersion {
  236. case "v1.1":
  237. return "lite"
  238. case "v2.1":
  239. return "generalv2"
  240. case "v3.1":
  241. return "generalv3"
  242. case "v3.5":
  243. return "generalv3.5"
  244. case "v4.0":
  245. return "4.0Ultra"
  246. }
  247. return "general" + apiVersion
  248. }
  249. func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
  250. apiVersion := getAPIVersion(c, modelName)
  251. domain := apiVersion2domain(apiVersion)
  252. authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
  253. return domain, authUrl
  254. }
  255. func getAPIVersion(c *gin.Context, modelName string) string {
  256. query := c.Request.URL.Query()
  257. apiVersion := query.Get("api-version")
  258. if apiVersion != "" {
  259. return apiVersion
  260. }
  261. parts := strings.Split(modelName, "-")
  262. if len(parts) == 2 {
  263. apiVersion = parts[1]
  264. return apiVersion
  265. }
  266. apiVersion = c.GetString("api_version")
  267. if apiVersion != "" {
  268. return apiVersion
  269. }
  270. apiVersion = "v1.1"
  271. common.SysLog("api_version not found, using default: " + apiVersion)
  272. return apiVersion
  273. }