relay-baidu.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. package baidu
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "io"
  9. "net/http"
  10. "one-api/common"
  11. "one-api/dto"
  12. relaycommon "one-api/relay/common"
  13. "one-api/service"
  14. "strings"
  15. "sync"
  16. "time"
  17. )
  // Upstream API reference (Baidu Wenxin Workshop):
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2

// baiduTokenStore caches OAuth access tokens per channel key.
// Keys are the raw "clientID|clientSecret" API key strings and values are
// BaiduAccessToken; entries are written by getBaiduAccessTokenHelper.
var baiduTokenStore sync.Map
  20. func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
  21. messages := make([]BaiduMessage, 0, len(request.Messages))
  22. for _, message := range request.Messages {
  23. messages = append(messages, BaiduMessage{
  24. Role: message.Role,
  25. Content: message.StringContent(),
  26. })
  27. }
  28. return &BaiduChatRequest{
  29. Messages: messages,
  30. Stream: request.Stream,
  31. }
  32. }
  33. func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
  34. content, _ := json.Marshal(response.Result)
  35. choice := dto.OpenAITextResponseChoice{
  36. Index: 0,
  37. Message: dto.Message{
  38. Role: "assistant",
  39. Content: content,
  40. },
  41. FinishReason: "stop",
  42. }
  43. fullTextResponse := dto.OpenAITextResponse{
  44. Id: response.Id,
  45. Object: "chat.completion",
  46. Created: response.Created,
  47. Choices: []dto.OpenAITextResponseChoice{choice},
  48. Usage: response.Usage,
  49. }
  50. return &fullTextResponse
  51. }
  52. func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
  53. var choice dto.ChatCompletionsStreamResponseChoice
  54. choice.Delta.SetContentString(baiduResponse.Result)
  55. if baiduResponse.IsEnd {
  56. choice.FinishReason = &relaycommon.StopFinishReason
  57. }
  58. response := dto.ChatCompletionsStreamResponse{
  59. Id: baiduResponse.Id,
  60. Object: "chat.completion.chunk",
  61. Created: baiduResponse.Created,
  62. Model: "ernie-bot",
  63. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  64. }
  65. return &response
  66. }
  67. func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
  68. return &BaiduEmbeddingRequest{
  69. Input: request.ParseInput(),
  70. }
  71. }
  72. func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
  73. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  74. Object: "list",
  75. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
  76. Model: "baidu-embedding",
  77. Usage: response.Usage,
  78. }
  79. for _, item := range response.Data {
  80. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  81. Object: item.Object,
  82. Index: item.Index,
  83. Embedding: item.Embedding,
  84. })
  85. }
  86. return &openAIEmbeddingResponse
  87. }
  88. func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  89. var usage dto.Usage
  90. scanner := bufio.NewScanner(resp.Body)
  91. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  92. if atEOF && len(data) == 0 {
  93. return 0, nil, nil
  94. }
  95. if i := strings.Index(string(data), "\n"); i >= 0 {
  96. return i + 1, data[0:i], nil
  97. }
  98. if atEOF {
  99. return len(data), data, nil
  100. }
  101. return 0, nil, nil
  102. })
  103. dataChan := make(chan string)
  104. stopChan := make(chan bool)
  105. go func() {
  106. for scanner.Scan() {
  107. data := scanner.Text()
  108. if len(data) < 6 { // ignore blank line or wrong format
  109. continue
  110. }
  111. data = data[6:]
  112. dataChan <- data
  113. }
  114. stopChan <- true
  115. }()
  116. service.SetEventStreamHeaders(c)
  117. c.Stream(func(w io.Writer) bool {
  118. select {
  119. case data := <-dataChan:
  120. var baiduResponse BaiduChatStreamResponse
  121. err := json.Unmarshal([]byte(data), &baiduResponse)
  122. if err != nil {
  123. common.SysError("error unmarshalling stream response: " + err.Error())
  124. return true
  125. }
  126. if baiduResponse.Usage.TotalTokens != 0 {
  127. usage.TotalTokens = baiduResponse.Usage.TotalTokens
  128. usage.PromptTokens = baiduResponse.Usage.PromptTokens
  129. usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
  130. }
  131. response := streamResponseBaidu2OpenAI(&baiduResponse)
  132. jsonResponse, err := json.Marshal(response)
  133. if err != nil {
  134. common.SysError("error marshalling stream response: " + err.Error())
  135. return true
  136. }
  137. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  138. return true
  139. case <-stopChan:
  140. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  141. return false
  142. }
  143. })
  144. err := resp.Body.Close()
  145. if err != nil {
  146. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  147. }
  148. return nil, &usage
  149. }
  150. func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  151. var baiduResponse BaiduChatResponse
  152. responseBody, err := io.ReadAll(resp.Body)
  153. if err != nil {
  154. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  155. }
  156. err = resp.Body.Close()
  157. if err != nil {
  158. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  159. }
  160. err = json.Unmarshal(responseBody, &baiduResponse)
  161. if err != nil {
  162. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  163. }
  164. if baiduResponse.ErrorMsg != "" {
  165. return &dto.OpenAIErrorWithStatusCode{
  166. Error: dto.OpenAIError{
  167. Message: baiduResponse.ErrorMsg,
  168. Type: "baidu_error",
  169. Param: "",
  170. Code: baiduResponse.ErrorCode,
  171. },
  172. StatusCode: resp.StatusCode,
  173. }, nil
  174. }
  175. fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
  176. jsonResponse, err := json.Marshal(fullTextResponse)
  177. if err != nil {
  178. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  179. }
  180. c.Writer.Header().Set("Content-Type", "application/json")
  181. c.Writer.WriteHeader(resp.StatusCode)
  182. _, err = c.Writer.Write(jsonResponse)
  183. return nil, &fullTextResponse.Usage
  184. }
  185. func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  186. var baiduResponse BaiduEmbeddingResponse
  187. responseBody, err := io.ReadAll(resp.Body)
  188. if err != nil {
  189. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  190. }
  191. err = resp.Body.Close()
  192. if err != nil {
  193. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  194. }
  195. err = json.Unmarshal(responseBody, &baiduResponse)
  196. if err != nil {
  197. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  198. }
  199. if baiduResponse.ErrorMsg != "" {
  200. return &dto.OpenAIErrorWithStatusCode{
  201. Error: dto.OpenAIError{
  202. Message: baiduResponse.ErrorMsg,
  203. Type: "baidu_error",
  204. Param: "",
  205. Code: baiduResponse.ErrorCode,
  206. },
  207. StatusCode: resp.StatusCode,
  208. }, nil
  209. }
  210. fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
  211. jsonResponse, err := json.Marshal(fullTextResponse)
  212. if err != nil {
  213. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  214. }
  215. c.Writer.Header().Set("Content-Type", "application/json")
  216. c.Writer.WriteHeader(resp.StatusCode)
  217. _, err = c.Writer.Write(jsonResponse)
  218. return nil, &fullTextResponse.Usage
  219. }
  220. func getBaiduAccessToken(apiKey string) (string, error) {
  221. if val, ok := baiduTokenStore.Load(apiKey); ok {
  222. var accessToken BaiduAccessToken
  223. if accessToken, ok = val.(BaiduAccessToken); ok {
  224. // soon this will expire
  225. if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
  226. go func() {
  227. _, _ = getBaiduAccessTokenHelper(apiKey)
  228. }()
  229. }
  230. return accessToken.AccessToken, nil
  231. }
  232. }
  233. accessToken, err := getBaiduAccessTokenHelper(apiKey)
  234. if err != nil {
  235. return "", err
  236. }
  237. if accessToken == nil {
  238. return "", errors.New("getBaiduAccessToken return a nil token")
  239. }
  240. return (*accessToken).AccessToken, nil
  241. }
  242. func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
  243. parts := strings.Split(apiKey, "|")
  244. if len(parts) != 2 {
  245. return nil, errors.New("invalid baidu apikey")
  246. }
  247. req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
  248. parts[0], parts[1]), nil)
  249. if err != nil {
  250. return nil, err
  251. }
  252. req.Header.Add("Content-Type", "application/json")
  253. req.Header.Add("Accept", "application/json")
  254. res, err := service.GetImpatientHttpClient().Do(req)
  255. if err != nil {
  256. return nil, err
  257. }
  258. defer res.Body.Close()
  259. var accessToken BaiduAccessToken
  260. err = json.NewDecoder(res.Body).Decode(&accessToken)
  261. if err != nil {
  262. return nil, err
  263. }
  264. if accessToken.Error != "" {
  265. return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
  266. }
  267. if accessToken.AccessToken == "" {
  268. return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
  269. }
  270. accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
  271. baiduTokenStore.Store(apiKey, accessToken)
  272. return &accessToken, nil
  273. }