relay-baidu.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package baidu
import (
	"bufio"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"one-api/common"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/service"
)
  18. // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
  19. var baiduTokenStore sync.Map
  20. func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
  21. baiduRequest := BaiduChatRequest{
  22. Temperature: request.Temperature,
  23. TopP: request.TopP,
  24. PenaltyScore: request.FrequencyPenalty,
  25. Stream: request.Stream,
  26. DisableSearch: false,
  27. EnableCitation: false,
  28. UserId: request.User,
  29. }
  30. if request.MaxTokens != 0 {
  31. maxTokens := int(request.MaxTokens)
  32. if request.MaxTokens == 1 {
  33. maxTokens = 2
  34. }
  35. baiduRequest.MaxOutputTokens = &maxTokens
  36. }
  37. for _, message := range request.Messages {
  38. if message.Role == "system" {
  39. baiduRequest.System = message.StringContent()
  40. } else {
  41. baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
  42. Role: message.Role,
  43. Content: message.StringContent(),
  44. })
  45. }
  46. }
  47. return &baiduRequest
  48. }
  49. func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
  50. content, _ := json.Marshal(response.Result)
  51. choice := dto.OpenAITextResponseChoice{
  52. Index: 0,
  53. Message: dto.Message{
  54. Role: "assistant",
  55. Content: content,
  56. },
  57. FinishReason: "stop",
  58. }
  59. fullTextResponse := dto.OpenAITextResponse{
  60. Id: response.Id,
  61. Object: "chat.completion",
  62. Created: response.Created,
  63. Choices: []dto.OpenAITextResponseChoice{choice},
  64. Usage: response.Usage,
  65. }
  66. return &fullTextResponse
  67. }
  68. func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
  69. var choice dto.ChatCompletionsStreamResponseChoice
  70. choice.Delta.SetContentString(baiduResponse.Result)
  71. if baiduResponse.IsEnd {
  72. choice.FinishReason = &relaycommon.StopFinishReason
  73. }
  74. response := dto.ChatCompletionsStreamResponse{
  75. Id: baiduResponse.Id,
  76. Object: "chat.completion.chunk",
  77. Created: baiduResponse.Created,
  78. Model: "ernie-bot",
  79. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  80. }
  81. return &response
  82. }
  83. func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
  84. return &BaiduEmbeddingRequest{
  85. Input: request.ParseInput(),
  86. }
  87. }
  88. func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
  89. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  90. Object: "list",
  91. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
  92. Model: "baidu-embedding",
  93. Usage: response.Usage,
  94. }
  95. for _, item := range response.Data {
  96. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  97. Object: item.Object,
  98. Index: item.Index,
  99. Embedding: item.Embedding,
  100. })
  101. }
  102. return &openAIEmbeddingResponse
  103. }
  104. func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  105. var usage dto.Usage
  106. scanner := bufio.NewScanner(resp.Body)
  107. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  108. if atEOF && len(data) == 0 {
  109. return 0, nil, nil
  110. }
  111. if i := strings.Index(string(data), "\n"); i >= 0 {
  112. return i + 1, data[0:i], nil
  113. }
  114. if atEOF {
  115. return len(data), data, nil
  116. }
  117. return 0, nil, nil
  118. })
  119. dataChan := make(chan string)
  120. stopChan := make(chan bool)
  121. go func() {
  122. for scanner.Scan() {
  123. data := scanner.Text()
  124. if len(data) < 6 { // ignore blank line or wrong format
  125. continue
  126. }
  127. data = data[6:]
  128. dataChan <- data
  129. }
  130. stopChan <- true
  131. }()
  132. service.SetEventStreamHeaders(c)
  133. c.Stream(func(w io.Writer) bool {
  134. select {
  135. case data := <-dataChan:
  136. var baiduResponse BaiduChatStreamResponse
  137. err := json.Unmarshal([]byte(data), &baiduResponse)
  138. if err != nil {
  139. common.SysError("error unmarshalling stream response: " + err.Error())
  140. return true
  141. }
  142. if baiduResponse.Usage.TotalTokens != 0 {
  143. usage.TotalTokens = baiduResponse.Usage.TotalTokens
  144. usage.PromptTokens = baiduResponse.Usage.PromptTokens
  145. usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
  146. }
  147. response := streamResponseBaidu2OpenAI(&baiduResponse)
  148. jsonResponse, err := json.Marshal(response)
  149. if err != nil {
  150. common.SysError("error marshalling stream response: " + err.Error())
  151. return true
  152. }
  153. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  154. return true
  155. case <-stopChan:
  156. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  157. return false
  158. }
  159. })
  160. err := resp.Body.Close()
  161. if err != nil {
  162. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  163. }
  164. return nil, &usage
  165. }
  166. func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  167. var baiduResponse BaiduChatResponse
  168. responseBody, err := io.ReadAll(resp.Body)
  169. if err != nil {
  170. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  171. }
  172. err = resp.Body.Close()
  173. if err != nil {
  174. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  175. }
  176. err = json.Unmarshal(responseBody, &baiduResponse)
  177. if err != nil {
  178. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  179. }
  180. if baiduResponse.ErrorMsg != "" {
  181. return &dto.OpenAIErrorWithStatusCode{
  182. Error: dto.OpenAIError{
  183. Message: baiduResponse.ErrorMsg,
  184. Type: "baidu_error",
  185. Param: "",
  186. Code: baiduResponse.ErrorCode,
  187. },
  188. StatusCode: resp.StatusCode,
  189. }, nil
  190. }
  191. fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
  192. jsonResponse, err := json.Marshal(fullTextResponse)
  193. if err != nil {
  194. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  195. }
  196. c.Writer.Header().Set("Content-Type", "application/json")
  197. c.Writer.WriteHeader(resp.StatusCode)
  198. _, err = c.Writer.Write(jsonResponse)
  199. return nil, &fullTextResponse.Usage
  200. }
  201. func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  202. var baiduResponse BaiduEmbeddingResponse
  203. responseBody, err := io.ReadAll(resp.Body)
  204. if err != nil {
  205. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  206. }
  207. err = resp.Body.Close()
  208. if err != nil {
  209. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  210. }
  211. err = json.Unmarshal(responseBody, &baiduResponse)
  212. if err != nil {
  213. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  214. }
  215. if baiduResponse.ErrorMsg != "" {
  216. return &dto.OpenAIErrorWithStatusCode{
  217. Error: dto.OpenAIError{
  218. Message: baiduResponse.ErrorMsg,
  219. Type: "baidu_error",
  220. Param: "",
  221. Code: baiduResponse.ErrorCode,
  222. },
  223. StatusCode: resp.StatusCode,
  224. }, nil
  225. }
  226. fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
  227. jsonResponse, err := json.Marshal(fullTextResponse)
  228. if err != nil {
  229. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  230. }
  231. c.Writer.Header().Set("Content-Type", "application/json")
  232. c.Writer.WriteHeader(resp.StatusCode)
  233. _, err = c.Writer.Write(jsonResponse)
  234. return nil, &fullTextResponse.Usage
  235. }
  236. func getBaiduAccessToken(apiKey string) (string, error) {
  237. if val, ok := baiduTokenStore.Load(apiKey); ok {
  238. var accessToken BaiduAccessToken
  239. if accessToken, ok = val.(BaiduAccessToken); ok {
  240. // soon this will expire
  241. if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
  242. go func() {
  243. _, _ = getBaiduAccessTokenHelper(apiKey)
  244. }()
  245. }
  246. return accessToken.AccessToken, nil
  247. }
  248. }
  249. accessToken, err := getBaiduAccessTokenHelper(apiKey)
  250. if err != nil {
  251. return "", err
  252. }
  253. if accessToken == nil {
  254. return "", errors.New("getBaiduAccessToken return a nil token")
  255. }
  256. return (*accessToken).AccessToken, nil
  257. }
  258. func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
  259. parts := strings.Split(apiKey, "|")
  260. if len(parts) != 2 {
  261. return nil, errors.New("invalid baidu apikey")
  262. }
  263. req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
  264. parts[0], parts[1]), nil)
  265. if err != nil {
  266. return nil, err
  267. }
  268. req.Header.Add("Content-Type", "application/json")
  269. req.Header.Add("Accept", "application/json")
  270. res, err := service.GetImpatientHttpClient().Do(req)
  271. if err != nil {
  272. return nil, err
  273. }
  274. defer res.Body.Close()
  275. var accessToken BaiduAccessToken
  276. err = json.NewDecoder(res.Body).Decode(&accessToken)
  277. if err != nil {
  278. return nil, err
  279. }
  280. if accessToken.Error != "" {
  281. return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
  282. }
  283. if accessToken.AccessToken == "" {
  284. return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
  285. }
  286. accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
  287. baiduTokenStore.Store(apiKey, accessToken)
  288. return &accessToken, nil
  289. }