relay-cohere.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. package cohere
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. relaycommon "one-api/relay/common"
  11. "one-api/relay/helper"
  12. "one-api/service"
  13. "strings"
  14. "time"
  15. )
  16. func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
  17. cohereReq := CohereRequest{
  18. Model: textRequest.Model,
  19. ChatHistory: []ChatHistory{},
  20. Message: "",
  21. Stream: textRequest.Stream,
  22. MaxTokens: textRequest.GetMaxTokens(),
  23. }
  24. if common.CohereSafetySetting != "NONE" {
  25. cohereReq.SafetyMode = common.CohereSafetySetting
  26. }
  27. if cohereReq.MaxTokens == 0 {
  28. cohereReq.MaxTokens = 4000
  29. }
  30. for _, msg := range textRequest.Messages {
  31. if msg.Role == "user" {
  32. cohereReq.Message = msg.StringContent()
  33. } else {
  34. var role string
  35. if msg.Role == "assistant" {
  36. role = "CHATBOT"
  37. } else if msg.Role == "system" {
  38. role = "SYSTEM"
  39. } else {
  40. role = "USER"
  41. }
  42. cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
  43. Role: role,
  44. Message: msg.StringContent(),
  45. })
  46. }
  47. }
  48. return &cohereReq
  49. }
  50. func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
  51. if rerankRequest.TopN == 0 {
  52. rerankRequest.TopN = 1
  53. }
  54. cohereReq := CohereRerankRequest{
  55. Query: rerankRequest.Query,
  56. Documents: rerankRequest.Documents,
  57. Model: rerankRequest.Model,
  58. TopN: rerankRequest.TopN,
  59. ReturnDocuments: true,
  60. }
  61. return &cohereReq
  62. }
  63. func stopReasonCohere2OpenAI(reason string) string {
  64. switch reason {
  65. case "COMPLETE":
  66. return "stop"
  67. case "MAX_TOKENS":
  68. return "max_tokens"
  69. default:
  70. return reason
  71. }
  72. }
  73. func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  74. responseId := helper.GetResponseID(c)
  75. createdTime := common.GetTimestamp()
  76. usage := &dto.Usage{}
  77. responseText := ""
  78. scanner := bufio.NewScanner(resp.Body)
  79. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  80. if atEOF && len(data) == 0 {
  81. return 0, nil, nil
  82. }
  83. if i := strings.Index(string(data), "\n"); i >= 0 {
  84. return i + 1, data[0:i], nil
  85. }
  86. if atEOF {
  87. return len(data), data, nil
  88. }
  89. return 0, nil, nil
  90. })
  91. dataChan := make(chan string)
  92. stopChan := make(chan bool)
  93. go func() {
  94. for scanner.Scan() {
  95. data := scanner.Text()
  96. dataChan <- data
  97. }
  98. stopChan <- true
  99. }()
  100. helper.SetEventStreamHeaders(c)
  101. isFirst := true
  102. c.Stream(func(w io.Writer) bool {
  103. select {
  104. case data := <-dataChan:
  105. if isFirst {
  106. isFirst = false
  107. info.FirstResponseTime = time.Now()
  108. }
  109. data = strings.TrimSuffix(data, "\r")
  110. var cohereResp CohereResponse
  111. err := json.Unmarshal([]byte(data), &cohereResp)
  112. if err != nil {
  113. common.SysError("error unmarshalling stream response: " + err.Error())
  114. return true
  115. }
  116. var openaiResp dto.ChatCompletionsStreamResponse
  117. openaiResp.Id = responseId
  118. openaiResp.Created = createdTime
  119. openaiResp.Object = "chat.completion.chunk"
  120. openaiResp.Model = info.UpstreamModelName
  121. if cohereResp.IsFinished {
  122. finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
  123. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  124. {
  125. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
  126. Index: 0,
  127. FinishReason: &finishReason,
  128. },
  129. }
  130. if cohereResp.Response != nil {
  131. usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
  132. usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
  133. }
  134. } else {
  135. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  136. {
  137. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  138. Role: "assistant",
  139. Content: &cohereResp.Text,
  140. },
  141. Index: 0,
  142. },
  143. }
  144. responseText += cohereResp.Text
  145. }
  146. jsonStr, err := json.Marshal(openaiResp)
  147. if err != nil {
  148. common.SysError("error marshalling stream response: " + err.Error())
  149. return true
  150. }
  151. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  152. return true
  153. case <-stopChan:
  154. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  155. return false
  156. }
  157. })
  158. if usage.PromptTokens == 0 {
  159. usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
  160. }
  161. return nil, usage
  162. }
  163. func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  164. createdTime := common.GetTimestamp()
  165. responseBody, err := io.ReadAll(resp.Body)
  166. if err != nil {
  167. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  168. }
  169. common.CloseResponseBodyGracefully(resp)
  170. var cohereResp CohereResponseResult
  171. err = json.Unmarshal(responseBody, &cohereResp)
  172. if err != nil {
  173. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  174. }
  175. usage := dto.Usage{}
  176. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  177. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  178. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  179. var openaiResp dto.TextResponse
  180. openaiResp.Id = cohereResp.ResponseId
  181. openaiResp.Created = createdTime
  182. openaiResp.Object = "chat.completion"
  183. openaiResp.Model = modelName
  184. openaiResp.Usage = usage
  185. openaiResp.Choices = []dto.OpenAITextResponseChoice{
  186. {
  187. Index: 0,
  188. Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
  189. FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
  190. },
  191. }
  192. jsonResponse, err := json.Marshal(openaiResp)
  193. if err != nil {
  194. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  195. }
  196. c.Writer.Header().Set("Content-Type", "application/json")
  197. c.Writer.WriteHeader(resp.StatusCode)
  198. _, err = c.Writer.Write(jsonResponse)
  199. return nil, &usage
  200. }
  201. func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  202. responseBody, err := io.ReadAll(resp.Body)
  203. if err != nil {
  204. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  205. }
  206. common.CloseResponseBodyGracefully(resp)
  207. var cohereResp CohereRerankResponseResult
  208. err = json.Unmarshal(responseBody, &cohereResp)
  209. if err != nil {
  210. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  211. }
  212. usage := dto.Usage{}
  213. if cohereResp.Meta.BilledUnits.InputTokens == 0 {
  214. usage.PromptTokens = info.PromptTokens
  215. usage.CompletionTokens = 0
  216. usage.TotalTokens = info.PromptTokens
  217. } else {
  218. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  219. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  220. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  221. }
  222. var rerankResp dto.RerankResponse
  223. rerankResp.Results = cohereResp.Results
  224. rerankResp.Usage = usage
  225. jsonResponse, err := json.Marshal(rerankResp)
  226. if err != nil {
  227. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  228. }
  229. c.Writer.Header().Set("Content-Type", "application/json")
  230. c.Writer.WriteHeader(resp.StatusCode)
  231. _, err = c.Writer.Write(jsonResponse)
  232. return nil, &usage
  233. }