// relay-cohere.go — Cohere ⇄ OpenAI relay adapters (chat, streaming chat, rerank).
  1. package cohere
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/dto"
  11. relaycommon "one-api/relay/common"
  12. "one-api/service"
  13. "strings"
  14. "time"
  15. )
  16. func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
  17. cohereReq := CohereRequest{
  18. Model: textRequest.Model,
  19. ChatHistory: []ChatHistory{},
  20. Message: "",
  21. Stream: textRequest.Stream,
  22. MaxTokens: textRequest.GetMaxTokens(),
  23. }
  24. if cohereReq.MaxTokens == 0 {
  25. cohereReq.MaxTokens = 4000
  26. }
  27. for _, msg := range textRequest.Messages {
  28. if msg.Role == "user" {
  29. cohereReq.Message = msg.StringContent()
  30. } else {
  31. var role string
  32. if msg.Role == "assistant" {
  33. role = "CHATBOT"
  34. } else if msg.Role == "system" {
  35. role = "SYSTEM"
  36. } else {
  37. role = "USER"
  38. }
  39. cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
  40. Role: role,
  41. Message: msg.StringContent(),
  42. })
  43. }
  44. }
  45. return &cohereReq
  46. }
  47. func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
  48. cohereReq := CohereRerankRequest{
  49. Query: rerankRequest.Query,
  50. Documents: rerankRequest.Documents,
  51. Model: rerankRequest.Model,
  52. TopN: rerankRequest.TopN,
  53. ReturnDocuments: true,
  54. }
  55. for _, doc := range rerankRequest.Documents {
  56. cohereReq.Documents = append(cohereReq.Documents, doc)
  57. }
  58. return &cohereReq
  59. }
  60. func stopReasonCohere2OpenAI(reason string) string {
  61. switch reason {
  62. case "COMPLETE":
  63. return "stop"
  64. case "MAX_TOKENS":
  65. return "max_tokens"
  66. default:
  67. return reason
  68. }
  69. }
  70. func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  71. responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
  72. createdTime := common.GetTimestamp()
  73. usage := &dto.Usage{}
  74. responseText := ""
  75. scanner := bufio.NewScanner(resp.Body)
  76. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  77. if atEOF && len(data) == 0 {
  78. return 0, nil, nil
  79. }
  80. if i := strings.Index(string(data), "\n"); i >= 0 {
  81. return i + 1, data[0:i], nil
  82. }
  83. if atEOF {
  84. return len(data), data, nil
  85. }
  86. return 0, nil, nil
  87. })
  88. dataChan := make(chan string)
  89. stopChan := make(chan bool)
  90. go func() {
  91. for scanner.Scan() {
  92. data := scanner.Text()
  93. dataChan <- data
  94. }
  95. stopChan <- true
  96. }()
  97. service.SetEventStreamHeaders(c)
  98. isFirst := true
  99. c.Stream(func(w io.Writer) bool {
  100. select {
  101. case data := <-dataChan:
  102. if isFirst {
  103. isFirst = false
  104. info.FirstResponseTime = time.Now()
  105. }
  106. data = strings.TrimSuffix(data, "\r")
  107. var cohereResp CohereResponse
  108. err := json.Unmarshal([]byte(data), &cohereResp)
  109. if err != nil {
  110. common.SysError("error unmarshalling stream response: " + err.Error())
  111. return true
  112. }
  113. var openaiResp dto.ChatCompletionsStreamResponse
  114. openaiResp.Id = responseId
  115. openaiResp.Created = createdTime
  116. openaiResp.Object = "chat.completion.chunk"
  117. openaiResp.Model = info.UpstreamModelName
  118. if cohereResp.IsFinished {
  119. finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
  120. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  121. {
  122. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
  123. Index: 0,
  124. FinishReason: &finishReason,
  125. },
  126. }
  127. if cohereResp.Response != nil {
  128. usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
  129. usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
  130. }
  131. } else {
  132. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  133. {
  134. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  135. Role: "assistant",
  136. Content: &cohereResp.Text,
  137. },
  138. Index: 0,
  139. },
  140. }
  141. responseText += cohereResp.Text
  142. }
  143. jsonStr, err := json.Marshal(openaiResp)
  144. if err != nil {
  145. common.SysError("error marshalling stream response: " + err.Error())
  146. return true
  147. }
  148. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  149. return true
  150. case <-stopChan:
  151. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  152. return false
  153. }
  154. })
  155. if usage.PromptTokens == 0 {
  156. usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
  157. }
  158. return nil, usage
  159. }
  160. func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  161. createdTime := common.GetTimestamp()
  162. responseBody, err := io.ReadAll(resp.Body)
  163. if err != nil {
  164. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  165. }
  166. err = resp.Body.Close()
  167. if err != nil {
  168. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  169. }
  170. var cohereResp CohereResponseResult
  171. err = json.Unmarshal(responseBody, &cohereResp)
  172. if err != nil {
  173. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  174. }
  175. usage := dto.Usage{}
  176. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  177. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  178. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  179. var openaiResp dto.TextResponse
  180. openaiResp.Id = cohereResp.ResponseId
  181. openaiResp.Created = createdTime
  182. openaiResp.Object = "chat.completion"
  183. openaiResp.Model = modelName
  184. openaiResp.Usage = usage
  185. content, _ := json.Marshal(cohereResp.Text)
  186. openaiResp.Choices = []dto.OpenAITextResponseChoice{
  187. {
  188. Index: 0,
  189. Message: dto.Message{Content: content, Role: "assistant"},
  190. FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
  191. },
  192. }
  193. jsonResponse, err := json.Marshal(openaiResp)
  194. if err != nil {
  195. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  196. }
  197. c.Writer.Header().Set("Content-Type", "application/json")
  198. c.Writer.WriteHeader(resp.StatusCode)
  199. _, err = c.Writer.Write(jsonResponse)
  200. return nil, &usage
  201. }
  202. func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  203. responseBody, err := io.ReadAll(resp.Body)
  204. if err != nil {
  205. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  206. }
  207. err = resp.Body.Close()
  208. if err != nil {
  209. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  210. }
  211. var cohereResp CohereRerankResponseResult
  212. err = json.Unmarshal(responseBody, &cohereResp)
  213. if err != nil {
  214. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  215. }
  216. usage := dto.Usage{}
  217. if cohereResp.Meta.BilledUnits.InputTokens == 0 {
  218. usage.PromptTokens = info.PromptTokens
  219. usage.CompletionTokens = 0
  220. usage.TotalTokens = info.PromptTokens
  221. } else {
  222. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  223. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  224. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  225. }
  226. var rerankResp dto.RerankResponse
  227. rerankResp.Results = cohereResp.Results
  228. rerankResp.Usage = usage
  229. jsonResponse, err := json.Marshal(rerankResp)
  230. if err != nil {
  231. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  232. }
  233. c.Writer.Header().Set("Content-Type", "application/json")
  234. c.Writer.WriteHeader(resp.StatusCode)
  235. _, err = c.Writer.Write(jsonResponse)
  236. return nil, &usage
  237. }