relay-baidu.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. package baidu
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "io"
  9. "net/http"
  10. "one-api/common"
  11. "one-api/constant"
  12. "one-api/dto"
  13. "one-api/relay/helper"
  14. "one-api/service"
  15. "strings"
  16. "sync"
  17. "time"
  18. )
  19. // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
  20. var baiduTokenStore sync.Map
  21. func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
  22. baiduRequest := BaiduChatRequest{
  23. Temperature: request.Temperature,
  24. TopP: request.TopP,
  25. PenaltyScore: request.FrequencyPenalty,
  26. Stream: request.Stream,
  27. DisableSearch: false,
  28. EnableCitation: false,
  29. UserId: request.User,
  30. }
  31. if request.MaxTokens != 0 {
  32. maxTokens := int(request.MaxTokens)
  33. if request.MaxTokens == 1 {
  34. maxTokens = 2
  35. }
  36. baiduRequest.MaxOutputTokens = &maxTokens
  37. }
  38. for _, message := range request.Messages {
  39. if message.Role == "system" {
  40. baiduRequest.System = message.StringContent()
  41. } else {
  42. baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
  43. Role: message.Role,
  44. Content: message.StringContent(),
  45. })
  46. }
  47. }
  48. return &baiduRequest
  49. }
  50. func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
  51. content, _ := json.Marshal(response.Result)
  52. choice := dto.OpenAITextResponseChoice{
  53. Index: 0,
  54. Message: dto.Message{
  55. Role: "assistant",
  56. Content: content,
  57. },
  58. FinishReason: "stop",
  59. }
  60. fullTextResponse := dto.OpenAITextResponse{
  61. Id: response.Id,
  62. Object: "chat.completion",
  63. Created: response.Created,
  64. Choices: []dto.OpenAITextResponseChoice{choice},
  65. Usage: response.Usage,
  66. }
  67. return &fullTextResponse
  68. }
  69. func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
  70. var choice dto.ChatCompletionsStreamResponseChoice
  71. choice.Delta.SetContentString(baiduResponse.Result)
  72. if baiduResponse.IsEnd {
  73. choice.FinishReason = &constant.FinishReasonStop
  74. }
  75. response := dto.ChatCompletionsStreamResponse{
  76. Id: baiduResponse.Id,
  77. Object: "chat.completion.chunk",
  78. Created: baiduResponse.Created,
  79. Model: "ernie-bot",
  80. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  81. }
  82. return &response
  83. }
  84. func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
  85. return &BaiduEmbeddingRequest{
  86. Input: request.ParseInput(),
  87. }
  88. }
  89. func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
  90. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  91. Object: "list",
  92. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
  93. Model: "baidu-embedding",
  94. Usage: response.Usage,
  95. }
  96. for _, item := range response.Data {
  97. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  98. Object: item.Object,
  99. Index: item.Index,
  100. Embedding: item.Embedding,
  101. })
  102. }
  103. return &openAIEmbeddingResponse
  104. }
  105. func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  106. var usage dto.Usage
  107. scanner := bufio.NewScanner(resp.Body)
  108. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  109. if atEOF && len(data) == 0 {
  110. return 0, nil, nil
  111. }
  112. if i := strings.Index(string(data), "\n"); i >= 0 {
  113. return i + 1, data[0:i], nil
  114. }
  115. if atEOF {
  116. return len(data), data, nil
  117. }
  118. return 0, nil, nil
  119. })
  120. dataChan := make(chan string)
  121. stopChan := make(chan bool)
  122. go func() {
  123. for scanner.Scan() {
  124. data := scanner.Text()
  125. if len(data) < 6 { // ignore blank line or wrong format
  126. continue
  127. }
  128. data = data[6:]
  129. dataChan <- data
  130. }
  131. stopChan <- true
  132. }()
  133. helper.SetEventStreamHeaders(c)
  134. c.Stream(func(w io.Writer) bool {
  135. select {
  136. case data := <-dataChan:
  137. var baiduResponse BaiduChatStreamResponse
  138. err := json.Unmarshal([]byte(data), &baiduResponse)
  139. if err != nil {
  140. common.SysError("error unmarshalling stream response: " + err.Error())
  141. return true
  142. }
  143. if baiduResponse.Usage.TotalTokens != 0 {
  144. usage.TotalTokens = baiduResponse.Usage.TotalTokens
  145. usage.PromptTokens = baiduResponse.Usage.PromptTokens
  146. usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
  147. }
  148. response := streamResponseBaidu2OpenAI(&baiduResponse)
  149. jsonResponse, err := json.Marshal(response)
  150. if err != nil {
  151. common.SysError("error marshalling stream response: " + err.Error())
  152. return true
  153. }
  154. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  155. return true
  156. case <-stopChan:
  157. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  158. return false
  159. }
  160. })
  161. err := resp.Body.Close()
  162. if err != nil {
  163. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  164. }
  165. return nil, &usage
  166. }
  167. func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  168. var baiduResponse BaiduChatResponse
  169. responseBody, err := io.ReadAll(resp.Body)
  170. if err != nil {
  171. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  172. }
  173. err = resp.Body.Close()
  174. if err != nil {
  175. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  176. }
  177. err = json.Unmarshal(responseBody, &baiduResponse)
  178. if err != nil {
  179. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  180. }
  181. if baiduResponse.ErrorMsg != "" {
  182. return &dto.OpenAIErrorWithStatusCode{
  183. Error: dto.OpenAIError{
  184. Message: baiduResponse.ErrorMsg,
  185. Type: "baidu_error",
  186. Param: "",
  187. Code: baiduResponse.ErrorCode,
  188. },
  189. StatusCode: resp.StatusCode,
  190. }, nil
  191. }
  192. fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
  193. jsonResponse, err := json.Marshal(fullTextResponse)
  194. if err != nil {
  195. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  196. }
  197. c.Writer.Header().Set("Content-Type", "application/json")
  198. c.Writer.WriteHeader(resp.StatusCode)
  199. _, err = c.Writer.Write(jsonResponse)
  200. return nil, &fullTextResponse.Usage
  201. }
  202. func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  203. var baiduResponse BaiduEmbeddingResponse
  204. responseBody, err := io.ReadAll(resp.Body)
  205. if err != nil {
  206. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  207. }
  208. err = resp.Body.Close()
  209. if err != nil {
  210. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  211. }
  212. err = json.Unmarshal(responseBody, &baiduResponse)
  213. if err != nil {
  214. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  215. }
  216. if baiduResponse.ErrorMsg != "" {
  217. return &dto.OpenAIErrorWithStatusCode{
  218. Error: dto.OpenAIError{
  219. Message: baiduResponse.ErrorMsg,
  220. Type: "baidu_error",
  221. Param: "",
  222. Code: baiduResponse.ErrorCode,
  223. },
  224. StatusCode: resp.StatusCode,
  225. }, nil
  226. }
  227. fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
  228. jsonResponse, err := json.Marshal(fullTextResponse)
  229. if err != nil {
  230. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  231. }
  232. c.Writer.Header().Set("Content-Type", "application/json")
  233. c.Writer.WriteHeader(resp.StatusCode)
  234. _, err = c.Writer.Write(jsonResponse)
  235. return nil, &fullTextResponse.Usage
  236. }
  237. func getBaiduAccessToken(apiKey string) (string, error) {
  238. if val, ok := baiduTokenStore.Load(apiKey); ok {
  239. var accessToken BaiduAccessToken
  240. if accessToken, ok = val.(BaiduAccessToken); ok {
  241. // soon this will expire
  242. if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
  243. go func() {
  244. _, _ = getBaiduAccessTokenHelper(apiKey)
  245. }()
  246. }
  247. return accessToken.AccessToken, nil
  248. }
  249. }
  250. accessToken, err := getBaiduAccessTokenHelper(apiKey)
  251. if err != nil {
  252. return "", err
  253. }
  254. if accessToken == nil {
  255. return "", errors.New("getBaiduAccessToken return a nil token")
  256. }
  257. return (*accessToken).AccessToken, nil
  258. }
  259. func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
  260. parts := strings.Split(apiKey, "|")
  261. if len(parts) != 2 {
  262. return nil, errors.New("invalid baidu apikey")
  263. }
  264. req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
  265. parts[0], parts[1]), nil)
  266. if err != nil {
  267. return nil, err
  268. }
  269. req.Header.Add("Content-Type", "application/json")
  270. req.Header.Add("Accept", "application/json")
  271. res, err := service.GetImpatientHttpClient().Do(req)
  272. if err != nil {
  273. return nil, err
  274. }
  275. defer res.Body.Close()
  276. var accessToken BaiduAccessToken
  277. err = json.NewDecoder(res.Body).Decode(&accessToken)
  278. if err != nil {
  279. return nil, err
  280. }
  281. if accessToken.Error != "" {
  282. return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
  283. }
  284. if accessToken.AccessToken == "" {
  285. return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
  286. }
  287. accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
  288. baiduTokenStore.Store(apiKey, accessToken)
  289. return &accessToken, nil
  290. }