relay-palm.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package palm
  2. import (
  3. "encoding/json"
  4. "io"
  5. "net/http"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/constant"
  8. "github.com/QuantumNous/new-api/dto"
  9. relaycommon "github.com/QuantumNous/new-api/relay/common"
  10. "github.com/QuantumNous/new-api/relay/helper"
  11. "github.com/QuantumNous/new-api/service"
  12. "github.com/QuantumNous/new-api/types"
  13. "github.com/gin-gonic/gin"
  14. )
  15. // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
  16. // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
  17. func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
  18. fullTextResponse := dto.OpenAITextResponse{
  19. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  20. }
  21. for i, candidate := range response.Candidates {
  22. choice := dto.OpenAITextResponseChoice{
  23. Index: i,
  24. Message: dto.Message{
  25. Role: "assistant",
  26. Content: candidate.Content,
  27. },
  28. FinishReason: "stop",
  29. }
  30. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  31. }
  32. return &fullTextResponse
  33. }
  34. func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
  35. var choice dto.ChatCompletionsStreamResponseChoice
  36. if len(palmResponse.Candidates) > 0 {
  37. choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
  38. }
  39. choice.FinishReason = &constant.FinishReasonStop
  40. var response dto.ChatCompletionsStreamResponse
  41. response.Object = "chat.completion.chunk"
  42. response.Model = "palm2"
  43. response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
  44. return &response
  45. }
// palmStreamHandler adapts PaLM's non-streaming response into an SSE stream
// for the client. PaLM's generateMessage endpoint returns one JSON body, so
// a goroutine reads the whole body once, converts it to a single OpenAI-style
// chunk, emits it as one "data:" event, and the stream is then closed with
// "data: [DONE]".
//
// Returns a NewAPIError (always nil here — upstream failures are only logged,
// the client still receives [DONE]) and the text of the first candidate for
// downstream usage accounting.
func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
	responseText := ""
	responseId := helper.GetResponseID(c)
	createdTime := common.GetTimestamp()
	// Unbuffered channels: each send in the goroutine blocks until the
	// c.Stream callback below receives. The stopChan handshake also gives
	// the goroutine's write to responseText a happens-before edge with the
	// read at the bottom of this function, so there is no data race.
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		responseBody, err := io.ReadAll(resp.Body)
		if err != nil {
			// NOTE: this path skips the close below; the deferred-style
			// close after c.Stream covers it.
			common.SysLog("error reading stream response: " + err.Error())
			stopChan <- true
			return
		}
		service.CloseResponseBodyGracefully(resp)
		var palmResponse PaLMChatResponse
		err = json.Unmarshal(responseBody, &palmResponse)
		if err != nil {
			common.SysLog("error unmarshalling stream response: " + err.Error())
			stopChan <- true
			return
		}
		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
		fullTextResponse.Id = responseId
		fullTextResponse.Created = createdTime
		if len(palmResponse.Candidates) > 0 {
			// Captured for the usage/billing value returned to the caller.
			responseText = palmResponse.Candidates[0].Content
		}
		jsonResponse, err := json.Marshal(fullTextResponse)
		if err != nil {
			common.SysLog("error marshalling stream response: " + err.Error())
			stopChan <- true
			return
		}
		dataChan <- string(jsonResponse)
		stopChan <- true
	}()
	helper.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			c.Render(-1, common.CustomEvent{Data: "data: " + data})
			return true // keep the stream open; the goroutine signals stop next
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false // terminate the SSE stream
		}
	})
	// Second close on the success path — presumably CloseResponseBodyGracefully
	// tolerates an already-closed body (TODO confirm); it is required for the
	// ReadAll-error path above, which never reached the first close.
	service.CloseResponseBodyGracefully(resp)
	return nil, responseText
}
  96. func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  97. responseBody, err := io.ReadAll(resp.Body)
  98. if err != nil {
  99. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  100. }
  101. service.CloseResponseBodyGracefully(resp)
  102. var palmResponse PaLMChatResponse
  103. err = json.Unmarshal(responseBody, &palmResponse)
  104. if err != nil {
  105. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  106. }
  107. if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
  108. return nil, types.WithOpenAIError(types.OpenAIError{
  109. Message: palmResponse.Error.Message,
  110. Type: palmResponse.Error.Status,
  111. Param: "",
  112. Code: palmResponse.Error.Code,
  113. }, resp.StatusCode)
  114. }
  115. fullTextResponse := responsePaLM2OpenAI(&palmResponse)
  116. usage := service.ResponseText2Usage(c, palmResponse.Candidates[0].Content, info.UpstreamModelName, info.GetEstimatePromptTokens())
  117. fullTextResponse.Usage = *usage
  118. jsonResponse, err := common.Marshal(fullTextResponse)
  119. if err != nil {
  120. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  121. }
  122. c.Writer.Header().Set("Content-Type", "application/json")
  123. c.Writer.WriteHeader(resp.StatusCode)
  124. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  125. return usage, nil
  126. }