relay-cohere.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package cohere
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/dto"
  11. "one-api/service"
  12. "strings"
  13. )
  14. func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
  15. cohereReq := CohereRequest{
  16. Model: textRequest.Model,
  17. ChatHistory: []ChatHistory{},
  18. Message: "",
  19. Stream: textRequest.Stream,
  20. MaxTokens: textRequest.GetMaxTokens(),
  21. }
  22. if cohereReq.MaxTokens == 0 {
  23. cohereReq.MaxTokens = 4000
  24. }
  25. for _, msg := range textRequest.Messages {
  26. if msg.Role == "user" {
  27. cohereReq.Message = msg.StringContent()
  28. } else {
  29. var role string
  30. if msg.Role == "assistant" {
  31. role = "CHATBOT"
  32. } else if msg.Role == "system" {
  33. role = "SYSTEM"
  34. } else {
  35. role = "USER"
  36. }
  37. cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
  38. Role: role,
  39. Message: msg.StringContent(),
  40. })
  41. }
  42. }
  43. return &cohereReq
  44. }
  45. func stopReasonCohere2OpenAI(reason string) string {
  46. switch reason {
  47. case "COMPLETE":
  48. return "stop"
  49. case "MAX_TOKENS":
  50. return "max_tokens"
  51. default:
  52. return reason
  53. }
  54. }
  55. func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  56. responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
  57. createdTime := common.GetTimestamp()
  58. usage := &dto.Usage{}
  59. responseText := ""
  60. scanner := bufio.NewScanner(resp.Body)
  61. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  62. if atEOF && len(data) == 0 {
  63. return 0, nil, nil
  64. }
  65. if i := strings.Index(string(data), "\n"); i >= 0 {
  66. return i + 1, data[0:i], nil
  67. }
  68. if atEOF {
  69. return len(data), data, nil
  70. }
  71. return 0, nil, nil
  72. })
  73. dataChan := make(chan string)
  74. stopChan := make(chan bool)
  75. go func() {
  76. for scanner.Scan() {
  77. data := scanner.Text()
  78. dataChan <- data
  79. }
  80. stopChan <- true
  81. }()
  82. service.SetEventStreamHeaders(c)
  83. c.Stream(func(w io.Writer) bool {
  84. select {
  85. case data := <-dataChan:
  86. data = strings.TrimSuffix(data, "\r")
  87. var cohereResp CohereResponse
  88. err := json.Unmarshal([]byte(data), &cohereResp)
  89. if err != nil {
  90. common.SysError("error unmarshalling stream response: " + err.Error())
  91. return true
  92. }
  93. var openaiResp dto.ChatCompletionsStreamResponse
  94. openaiResp.Id = responseId
  95. openaiResp.Created = createdTime
  96. openaiResp.Object = "chat.completion.chunk"
  97. openaiResp.Model = modelName
  98. if cohereResp.IsFinished {
  99. finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
  100. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  101. {
  102. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
  103. Index: 0,
  104. FinishReason: &finishReason,
  105. },
  106. }
  107. if cohereResp.Response != nil {
  108. usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
  109. usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
  110. }
  111. } else {
  112. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  113. {
  114. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  115. Role: "assistant",
  116. Content: &cohereResp.Text,
  117. },
  118. Index: 0,
  119. },
  120. }
  121. responseText += cohereResp.Text
  122. }
  123. jsonStr, err := json.Marshal(openaiResp)
  124. if err != nil {
  125. common.SysError("error marshalling stream response: " + err.Error())
  126. return true
  127. }
  128. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  129. return true
  130. case <-stopChan:
  131. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  132. return false
  133. }
  134. })
  135. if usage.PromptTokens == 0 {
  136. usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
  137. }
  138. return nil, usage
  139. }
  140. func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  141. createdTime := common.GetTimestamp()
  142. responseBody, err := io.ReadAll(resp.Body)
  143. if err != nil {
  144. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  145. }
  146. err = resp.Body.Close()
  147. if err != nil {
  148. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  149. }
  150. var cohereResp CohereResponseResult
  151. err = json.Unmarshal(responseBody, &cohereResp)
  152. if err != nil {
  153. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  154. }
  155. usage := dto.Usage{}
  156. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  157. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  158. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  159. var openaiResp dto.TextResponse
  160. openaiResp.Id = cohereResp.ResponseId
  161. openaiResp.Created = createdTime
  162. openaiResp.Object = "chat.completion"
  163. openaiResp.Model = modelName
  164. openaiResp.Usage = usage
  165. content, _ := json.Marshal(cohereResp.Text)
  166. openaiResp.Choices = []dto.OpenAITextResponseChoice{
  167. {
  168. Index: 0,
  169. Message: dto.Message{Content: content, Role: "assistant"},
  170. FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
  171. },
  172. }
  173. jsonResponse, err := json.Marshal(openaiResp)
  174. if err != nil {
  175. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  176. }
  177. c.Writer.Header().Set("Content-Type", "application/json")
  178. c.Writer.WriteHeader(resp.StatusCode)
  179. _, err = c.Writer.Write(jsonResponse)
  180. return nil, &usage
  181. }