relay-cohere.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. package cohere
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. relaycommon "one-api/relay/common"
  11. "one-api/relay/helper"
  12. "one-api/service"
  13. "strings"
  14. "time"
  15. )
  16. func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
  17. cohereReq := CohereRequest{
  18. Model: textRequest.Model,
  19. ChatHistory: []ChatHistory{},
  20. Message: "",
  21. Stream: textRequest.Stream,
  22. MaxTokens: textRequest.GetMaxTokens(),
  23. }
  24. if common.CohereSafetySetting != "NONE" {
  25. cohereReq.SafetyMode = common.CohereSafetySetting
  26. }
  27. if cohereReq.MaxTokens == 0 {
  28. cohereReq.MaxTokens = 4000
  29. }
  30. for _, msg := range textRequest.Messages {
  31. if msg.Role == "user" {
  32. cohereReq.Message = msg.StringContent()
  33. } else {
  34. var role string
  35. if msg.Role == "assistant" {
  36. role = "CHATBOT"
  37. } else if msg.Role == "system" {
  38. role = "SYSTEM"
  39. } else {
  40. role = "USER"
  41. }
  42. cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
  43. Role: role,
  44. Message: msg.StringContent(),
  45. })
  46. }
  47. }
  48. return &cohereReq
  49. }
  50. func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
  51. if rerankRequest.TopN == 0 {
  52. rerankRequest.TopN = 1
  53. }
  54. cohereReq := CohereRerankRequest{
  55. Query: rerankRequest.Query,
  56. Documents: rerankRequest.Documents,
  57. Model: rerankRequest.Model,
  58. TopN: rerankRequest.TopN,
  59. ReturnDocuments: true,
  60. }
  61. return &cohereReq
  62. }
  63. func stopReasonCohere2OpenAI(reason string) string {
  64. switch reason {
  65. case "COMPLETE":
  66. return "stop"
  67. case "MAX_TOKENS":
  68. return "max_tokens"
  69. default:
  70. return reason
  71. }
  72. }
// cohereStreamHandler relays a streamed Cohere chat response to the client as
// OpenAI-style SSE "chat.completion.chunk" events, and returns the token
// usage (from Cohere's billed units when present, otherwise estimated from
// the accumulated response text).
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	responseId := helper.GetResponseID(c)
	createdTime := common.GetTimestamp()
	usage := &dto.Usage{}
	responseText := ""
	scanner := bufio.NewScanner(resp.Body)
	// Custom split function: one token per "\n"-terminated line; at EOF any
	// trailing bytes without a newline are emitted as a final token.
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		// Not enough data for a full line yet; ask for more.
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	// Pump goroutine: forward each upstream line to dataChan, then signal
	// completion. NOTE(review): scanner.Err() is never inspected, so upstream
	// read errors end the stream silently as if it finished — confirm intended.
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			dataChan <- data
		}
		stopChan <- true
	}()
	helper.SetEventStreamHeaders(c)
	isFirst := true
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			if isFirst {
				// Record time-to-first-response for relay metrics.
				isFirst = false
				info.FirstResponseTime = time.Now()
			}
			data = strings.TrimSuffix(data, "\r")
			var cohereResp CohereResponse
			err := json.Unmarshal([]byte(data), &cohereResp)
			if err != nil {
				// Malformed chunk: log and keep the stream alive.
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			var openaiResp dto.ChatCompletionsStreamResponse
			openaiResp.Id = responseId
			openaiResp.Created = createdTime
			openaiResp.Object = "chat.completion.chunk"
			openaiResp.Model = info.UpstreamModelName
			if cohereResp.IsFinished {
				// Final chunk: empty delta carrying the mapped finish_reason.
				finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
				openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
					{
						Delta:        dto.ChatCompletionsStreamResponseChoiceDelta{},
						Index:        0,
						FinishReason: &finishReason,
					},
				}
				if cohereResp.Response != nil {
					// Prefer the upstream-billed token counts when supplied.
					// NOTE(review): TotalTokens is not set on this path —
					// presumably the caller derives it; verify.
					usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
					usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
				}
			} else {
				openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
					{
						Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
							Role:    "assistant",
							Content: &cohereResp.Text,
						},
						Index: 0,
					},
				}
				// Accumulate text so usage can be estimated if Cohere never
				// reports billed units.
				responseText += cohereResp.Text
			}
			jsonStr, err := json.Marshal(openaiResp)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
			return true
		case <-stopChan:
			// Upstream exhausted: emit the SSE terminator and stop streaming.
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	// Fallback: no billed usage arrived, so estimate from the collected text.
	if usage.PromptTokens == 0 {
		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
	}
	return nil, usage
}
  163. func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  164. createdTime := common.GetTimestamp()
  165. responseBody, err := io.ReadAll(resp.Body)
  166. if err != nil {
  167. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  168. }
  169. err = resp.Body.Close()
  170. if err != nil {
  171. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  172. }
  173. var cohereResp CohereResponseResult
  174. err = json.Unmarshal(responseBody, &cohereResp)
  175. if err != nil {
  176. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  177. }
  178. usage := dto.Usage{}
  179. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  180. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  181. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  182. var openaiResp dto.TextResponse
  183. openaiResp.Id = cohereResp.ResponseId
  184. openaiResp.Created = createdTime
  185. openaiResp.Object = "chat.completion"
  186. openaiResp.Model = modelName
  187. openaiResp.Usage = usage
  188. openaiResp.Choices = []dto.OpenAITextResponseChoice{
  189. {
  190. Index: 0,
  191. Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
  192. FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
  193. },
  194. }
  195. jsonResponse, err := json.Marshal(openaiResp)
  196. if err != nil {
  197. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  198. }
  199. c.Writer.Header().Set("Content-Type", "application/json")
  200. c.Writer.WriteHeader(resp.StatusCode)
  201. _, err = c.Writer.Write(jsonResponse)
  202. return nil, &usage
  203. }
  204. func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  205. responseBody, err := io.ReadAll(resp.Body)
  206. if err != nil {
  207. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  208. }
  209. err = resp.Body.Close()
  210. if err != nil {
  211. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  212. }
  213. var cohereResp CohereRerankResponseResult
  214. err = json.Unmarshal(responseBody, &cohereResp)
  215. if err != nil {
  216. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  217. }
  218. usage := dto.Usage{}
  219. if cohereResp.Meta.BilledUnits.InputTokens == 0 {
  220. usage.PromptTokens = info.PromptTokens
  221. usage.CompletionTokens = 0
  222. usage.TotalTokens = info.PromptTokens
  223. } else {
  224. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  225. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  226. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  227. }
  228. var rerankResp dto.RerankResponse
  229. rerankResp.Results = cohereResp.Results
  230. rerankResp.Usage = usage
  231. jsonResponse, err := json.Marshal(rerankResp)
  232. if err != nil {
  233. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  234. }
  235. c.Writer.Header().Set("Content-Type", "application/json")
  236. c.Writer.WriteHeader(resp.StatusCode)
  237. _, err = c.Writer.Write(jsonResponse)
  238. return nil, &usage
  239. }