relay-cohere.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. package cohere
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "strings"
  8. "time"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/dto"
  11. relaycommon "github.com/QuantumNous/new-api/relay/common"
  12. "github.com/QuantumNous/new-api/relay/helper"
  13. "github.com/QuantumNous/new-api/service"
  14. "github.com/QuantumNous/new-api/types"
  15. "github.com/gin-gonic/gin"
  16. "github.com/samber/lo"
  17. )
  18. func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
  19. cohereReq := CohereRequest{
  20. Model: textRequest.Model,
  21. ChatHistory: []ChatHistory{},
  22. Message: "",
  23. Stream: lo.FromPtrOr(textRequest.Stream, false),
  24. MaxTokens: textRequest.GetMaxTokens(),
  25. }
  26. if common.CohereSafetySetting != "NONE" {
  27. cohereReq.SafetyMode = common.CohereSafetySetting
  28. }
  29. if cohereReq.MaxTokens == 0 {
  30. cohereReq.MaxTokens = 4000
  31. }
  32. for _, msg := range textRequest.Messages {
  33. if msg.Role == "user" {
  34. cohereReq.Message = msg.StringContent()
  35. } else {
  36. var role string
  37. if msg.Role == "assistant" {
  38. role = "CHATBOT"
  39. } else if msg.Role == "system" {
  40. role = "SYSTEM"
  41. } else {
  42. role = "USER"
  43. }
  44. cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
  45. Role: role,
  46. Message: msg.StringContent(),
  47. })
  48. }
  49. }
  50. return &cohereReq
  51. }
  52. func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
  53. topN := lo.FromPtrOr(rerankRequest.TopN, 1)
  54. if topN <= 0 {
  55. topN = 1
  56. }
  57. cohereReq := CohereRerankRequest{
  58. Query: rerankRequest.Query,
  59. Documents: rerankRequest.Documents,
  60. Model: rerankRequest.Model,
  61. TopN: topN,
  62. ReturnDocuments: true,
  63. }
  64. return &cohereReq
  65. }
  66. func stopReasonCohere2OpenAI(reason string) string {
  67. switch reason {
  68. case "COMPLETE":
  69. return "stop"
  70. case "MAX_TOKENS":
  71. return "max_tokens"
  72. default:
  73. return reason
  74. }
  75. }
  76. func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  77. responseId := helper.GetResponseID(c)
  78. createdTime := common.GetTimestamp()
  79. usage := &dto.Usage{}
  80. responseText := ""
  81. scanner := bufio.NewScanner(resp.Body)
  82. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  83. if atEOF && len(data) == 0 {
  84. return 0, nil, nil
  85. }
  86. if i := strings.Index(string(data), "\n"); i >= 0 {
  87. return i + 1, data[0:i], nil
  88. }
  89. if atEOF {
  90. return len(data), data, nil
  91. }
  92. return 0, nil, nil
  93. })
  94. dataChan := make(chan string)
  95. stopChan := make(chan bool)
  96. go func() {
  97. for scanner.Scan() {
  98. data := scanner.Text()
  99. dataChan <- data
  100. }
  101. stopChan <- true
  102. }()
  103. helper.SetEventStreamHeaders(c)
  104. isFirst := true
  105. c.Stream(func(w io.Writer) bool {
  106. select {
  107. case data := <-dataChan:
  108. if isFirst {
  109. isFirst = false
  110. info.FirstResponseTime = time.Now()
  111. }
  112. data = strings.TrimSuffix(data, "\r")
  113. var cohereResp CohereResponse
  114. err := json.Unmarshal([]byte(data), &cohereResp)
  115. if err != nil {
  116. common.SysLog("error unmarshalling stream response: " + err.Error())
  117. return true
  118. }
  119. var openaiResp dto.ChatCompletionsStreamResponse
  120. openaiResp.Id = responseId
  121. openaiResp.Created = createdTime
  122. openaiResp.Object = "chat.completion.chunk"
  123. openaiResp.Model = info.UpstreamModelName
  124. if cohereResp.IsFinished {
  125. finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
  126. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  127. {
  128. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
  129. Index: 0,
  130. FinishReason: &finishReason,
  131. },
  132. }
  133. if cohereResp.Response != nil {
  134. usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
  135. usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
  136. }
  137. } else {
  138. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  139. {
  140. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  141. Role: "assistant",
  142. Content: &cohereResp.Text,
  143. },
  144. Index: 0,
  145. },
  146. }
  147. responseText += cohereResp.Text
  148. }
  149. jsonStr, err := json.Marshal(openaiResp)
  150. if err != nil {
  151. common.SysLog("error marshalling stream response: " + err.Error())
  152. return true
  153. }
  154. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  155. return true
  156. case <-stopChan:
  157. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  158. return false
  159. }
  160. })
  161. if usage.PromptTokens == 0 {
  162. usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
  163. }
  164. return usage, nil
  165. }
  166. func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  167. createdTime := common.GetTimestamp()
  168. responseBody, err := io.ReadAll(resp.Body)
  169. if err != nil {
  170. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  171. }
  172. service.CloseResponseBodyGracefully(resp)
  173. var cohereResp CohereResponseResult
  174. err = json.Unmarshal(responseBody, &cohereResp)
  175. if err != nil {
  176. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  177. }
  178. usage := dto.Usage{}
  179. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  180. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  181. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  182. var openaiResp dto.TextResponse
  183. openaiResp.Id = cohereResp.ResponseId
  184. openaiResp.Created = createdTime
  185. openaiResp.Object = "chat.completion"
  186. openaiResp.Model = info.UpstreamModelName
  187. openaiResp.Usage = usage
  188. openaiResp.Choices = []dto.OpenAITextResponseChoice{
  189. {
  190. Index: 0,
  191. Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
  192. FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
  193. },
  194. }
  195. jsonResponse, err := json.Marshal(openaiResp)
  196. if err != nil {
  197. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  198. }
  199. c.Writer.Header().Set("Content-Type", "application/json")
  200. c.Writer.WriteHeader(resp.StatusCode)
  201. _, _ = c.Writer.Write(jsonResponse)
  202. return &usage, nil
  203. }
  204. func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
  205. responseBody, err := io.ReadAll(resp.Body)
  206. if err != nil {
  207. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  208. }
  209. service.CloseResponseBodyGracefully(resp)
  210. var cohereResp CohereRerankResponseResult
  211. err = json.Unmarshal(responseBody, &cohereResp)
  212. if err != nil {
  213. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  214. }
  215. usage := dto.Usage{}
  216. if cohereResp.Meta.BilledUnits.InputTokens == 0 {
  217. usage.PromptTokens = info.GetEstimatePromptTokens()
  218. usage.CompletionTokens = 0
  219. usage.TotalTokens = info.GetEstimatePromptTokens()
  220. } else {
  221. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  222. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  223. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  224. }
  225. var rerankResp dto.RerankResponse
  226. rerankResp.Results = cohereResp.Results
  227. rerankResp.Usage = usage
  228. jsonResponse, err := json.Marshal(rerankResp)
  229. if err != nil {
  230. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  231. }
  232. c.Writer.Header().Set("Content-Type", "application/json")
  233. c.Writer.WriteHeader(resp.StatusCode)
  234. _, err = c.Writer.Write(jsonResponse)
  235. return &usage, nil
  236. }