relay-claude.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strings"
  11. )
  12. type ClaudeMetadata struct {
  13. UserId string `json:"user_id"`
  14. }
  15. type ClaudeRequest struct {
  16. Model string `json:"model"`
  17. Prompt string `json:"prompt"`
  18. MaxTokensToSample int `json:"max_tokens_to_sample"`
  19. StopSequences []string `json:"stop_sequences,omitempty"`
  20. Temperature float64 `json:"temperature,omitempty"`
  21. TopP float64 `json:"top_p,omitempty"`
  22. TopK int `json:"top_k,omitempty"`
  23. //ClaudeMetadata `json:"metadata,omitempty"`
  24. Stream bool `json:"stream,omitempty"`
  25. }
  26. type ClaudeError struct {
  27. Type string `json:"type"`
  28. Message string `json:"message"`
  29. }
  30. type ClaudeResponse struct {
  31. Completion string `json:"completion"`
  32. StopReason string `json:"stop_reason"`
  33. Model string `json:"model"`
  34. Error ClaudeError `json:"error"`
  35. }
  36. func stopReasonClaude2OpenAI(reason string) string {
  37. switch reason {
  38. case "stop_sequence":
  39. return "stop"
  40. case "max_tokens":
  41. return "length"
  42. default:
  43. return reason
  44. }
  45. }
  46. func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
  47. claudeRequest := ClaudeRequest{
  48. Model: textRequest.Model,
  49. Prompt: "",
  50. MaxTokensToSample: textRequest.MaxTokens,
  51. StopSequences: nil,
  52. Temperature: textRequest.Temperature,
  53. TopP: textRequest.TopP,
  54. Stream: textRequest.Stream,
  55. }
  56. if claudeRequest.MaxTokensToSample == 0 {
  57. claudeRequest.MaxTokensToSample = 1000000
  58. }
  59. prompt := ""
  60. for _, message := range textRequest.Messages {
  61. if message.Role == "user" {
  62. prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
  63. } else if message.Role == "assistant" {
  64. prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
  65. } else if message.Role == "system" {
  66. prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
  67. }
  68. }
  69. prompt += "\n\nAssistant:"
  70. claudeRequest.Prompt = prompt
  71. return &claudeRequest
  72. }
  73. func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
  74. var choice ChatCompletionsStreamResponseChoice
  75. choice.Delta.Content = claudeResponse.Completion
  76. finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
  77. if finishReason != "null" {
  78. choice.FinishReason = &finishReason
  79. }
  80. var response ChatCompletionsStreamResponse
  81. response.Object = "chat.completion.chunk"
  82. response.Model = claudeResponse.Model
  83. response.Choices = []ChatCompletionsStreamResponseChoice{choice}
  84. return &response
  85. }
  86. func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
  87. choice := OpenAITextResponseChoice{
  88. Index: 0,
  89. Message: Message{
  90. Role: "assistant",
  91. Content: strings.TrimPrefix(claudeResponse.Completion, " "),
  92. Name: nil,
  93. },
  94. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  95. }
  96. fullTextResponse := OpenAITextResponse{
  97. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  98. Object: "chat.completion",
  99. Created: common.GetTimestamp(),
  100. Choices: []OpenAITextResponseChoice{choice},
  101. }
  102. return &fullTextResponse
  103. }
  104. func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
  105. responseText := ""
  106. responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
  107. createdTime := common.GetTimestamp()
  108. scanner := bufio.NewScanner(resp.Body)
  109. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  110. if atEOF && len(data) == 0 {
  111. return 0, nil, nil
  112. }
  113. if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
  114. return i + 4, data[0:i], nil
  115. }
  116. if atEOF {
  117. return len(data), data, nil
  118. }
  119. return 0, nil, nil
  120. })
  121. dataChan := make(chan string)
  122. stopChan := make(chan bool)
  123. go func() {
  124. for scanner.Scan() {
  125. data := scanner.Text()
  126. if !strings.HasPrefix(data, "event: completion") {
  127. continue
  128. }
  129. data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
  130. dataChan <- data
  131. }
  132. stopChan <- true
  133. }()
  134. c.Writer.Header().Set("Content-Type", "text/event-stream")
  135. c.Writer.Header().Set("Cache-Control", "no-cache")
  136. c.Writer.Header().Set("Connection", "keep-alive")
  137. c.Writer.Header().Set("Transfer-Encoding", "chunked")
  138. c.Writer.Header().Set("X-Accel-Buffering", "no")
  139. c.Stream(func(w io.Writer) bool {
  140. select {
  141. case data := <-dataChan:
  142. // some implementations may add \r at the end of data
  143. data = strings.TrimSuffix(data, "\r")
  144. var claudeResponse ClaudeResponse
  145. err := json.Unmarshal([]byte(data), &claudeResponse)
  146. if err != nil {
  147. common.SysError("error unmarshalling stream response: " + err.Error())
  148. return true
  149. }
  150. responseText += claudeResponse.Completion
  151. response := streamResponseClaude2OpenAI(&claudeResponse)
  152. response.Id = responseId
  153. response.Created = createdTime
  154. jsonStr, err := json.Marshal(response)
  155. if err != nil {
  156. common.SysError("error marshalling stream response: " + err.Error())
  157. return true
  158. }
  159. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  160. return true
  161. case <-stopChan:
  162. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  163. return false
  164. }
  165. })
  166. err := resp.Body.Close()
  167. if err != nil {
  168. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
  169. }
  170. return nil, responseText
  171. }
  172. func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
  173. responseBody, err := io.ReadAll(resp.Body)
  174. if err != nil {
  175. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  176. }
  177. err = resp.Body.Close()
  178. if err != nil {
  179. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  180. }
  181. var claudeResponse ClaudeResponse
  182. err = json.Unmarshal(responseBody, &claudeResponse)
  183. if err != nil {
  184. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  185. }
  186. if claudeResponse.Error.Type != "" {
  187. return &OpenAIErrorWithStatusCode{
  188. OpenAIError: OpenAIError{
  189. Message: claudeResponse.Error.Message,
  190. Type: claudeResponse.Error.Type,
  191. Param: "",
  192. Code: claudeResponse.Error.Type,
  193. },
  194. StatusCode: resp.StatusCode,
  195. }, nil
  196. }
  197. fullTextResponse := responseClaude2OpenAI(&claudeResponse)
  198. completionTokens := countTokenText(claudeResponse.Completion, model)
  199. usage := Usage{
  200. PromptTokens: promptTokens,
  201. CompletionTokens: completionTokens,
  202. TotalTokens: promptTokens + completionTokens,
  203. }
  204. fullTextResponse.Usage = usage
  205. jsonResponse, err := json.Marshal(fullTextResponse)
  206. if err != nil {
  207. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  208. }
  209. c.Writer.Header().Set("Content-Type", "application/json")
  210. c.Writer.WriteHeader(resp.StatusCode)
  211. _, err = c.Writer.Write(jsonResponse)
  212. return nil, &usage
  213. }