// relay-claude.go
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strings"
  11. )
// ClaudeMetadata is the optional metadata object of a Claude completion
// request, identifying the end user on whose behalf the request is made.
type ClaudeMetadata struct {
	UserId string `json:"user_id"`
}
// ClaudeRequest is the JSON request body sent to Anthropic's text
// completion endpoint. Fields mirror the upstream API; zero-valued
// optional fields are omitted from the serialized payload.
type ClaudeRequest struct {
	Model             string   `json:"model"`
	Prompt            string   `json:"prompt"`
	MaxTokensToSample int      `json:"max_tokens_to_sample"`
	StopSequences     []string `json:"stop_sequences,omitempty"`
	Temperature       float64  `json:"temperature,omitempty"`
	TopP              float64  `json:"top_p,omitempty"`
	TopK              int      `json:"top_k,omitempty"`
	// Metadata embedding is currently disabled; kept for reference.
	//ClaudeMetadata    `json:"metadata,omitempty"`
	Stream bool `json:"stream,omitempty"`
}
// ClaudeError is the error object embedded in a Claude API response.
// An empty Type means no error occurred.
type ClaudeError struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
// ClaudeResponse is the JSON body returned by Anthropic's text completion
// endpoint, for both full responses and individual stream events.
type ClaudeResponse struct {
	Completion string `json:"completion"`
	// StopReason is e.g. "stop_sequence" or "max_tokens"; a JSON null
	// (mid-stream chunks) decodes to the empty string.
	StopReason string      `json:"stop_reason"`
	Model      string      `json:"model"`
	Error      ClaudeError `json:"error"`
}
  36. func stopReasonClaude2OpenAI(reason string) string {
  37. switch reason {
  38. case "stop_sequence":
  39. return "stop"
  40. case "max_tokens":
  41. return "length"
  42. default:
  43. return reason
  44. }
  45. }
  46. func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
  47. claudeRequest := ClaudeRequest{
  48. Model: textRequest.Model,
  49. Prompt: "",
  50. MaxTokensToSample: textRequest.MaxTokens,
  51. StopSequences: nil,
  52. Temperature: textRequest.Temperature,
  53. TopP: textRequest.TopP,
  54. Stream: textRequest.Stream,
  55. }
  56. if claudeRequest.MaxTokensToSample == 0 {
  57. claudeRequest.MaxTokensToSample = 1000000
  58. }
  59. prompt := ""
  60. for _, message := range textRequest.Messages {
  61. if message.Role == "user" {
  62. prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
  63. } else if message.Role == "assistant" {
  64. prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
  65. } else if message.Role == "system" {
  66. prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
  67. }
  68. }
  69. prompt += "\n\nAssistant:"
  70. claudeRequest.Prompt = prompt
  71. return &claudeRequest
  72. }
  73. func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
  74. var choice ChatCompletionsStreamResponseChoice
  75. choice.Delta.Content = claudeResponse.Completion
  76. finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
  77. if finishReason != "null" {
  78. choice.FinishReason = &finishReason
  79. }
  80. var response ChatCompletionsStreamResponse
  81. response.Object = "chat.completion.chunk"
  82. response.Model = claudeResponse.Model
  83. response.Choices = []ChatCompletionsStreamResponseChoice{choice}
  84. return &response
  85. }
  86. func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
  87. choice := OpenAITextResponseChoice{
  88. Index: 0,
  89. Message: Message{
  90. Role: "assistant",
  91. Content: strings.TrimPrefix(claudeResponse.Completion, " "),
  92. Name: nil,
  93. },
  94. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  95. }
  96. fullTextResponse := OpenAITextResponse{
  97. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  98. Object: "chat.completion",
  99. Created: common.GetTimestamp(),
  100. Choices: []OpenAITextResponseChoice{choice},
  101. }
  102. return &fullTextResponse
  103. }
// claudeStreamHandler relays a streaming (SSE) Claude response to the client
// as OpenAI-style chat.completion.chunk events. It returns the accumulated
// completion text (used by the caller for token accounting) or an error
// wrapped with an HTTP status code.
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
	responseText := ""
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
	createdTime := common.GetTimestamp()
	scanner := bufio.NewScanner(resp.Body)
	// Split the upstream stream on blank-line (\r\n\r\n) boundaries so each
	// token is one complete SSE event block rather than a single line.
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
			// Consume the event plus its trailing delimiter; the token
			// excludes the delimiter itself.
			return i + 4, data[0:i], nil
		}
		if atEOF {
			// Flush whatever remains at end of stream.
			return len(data), data, nil
		}
		// Ask for more data until a full event is buffered.
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	// Producer goroutine: forward each "completion" event's JSON payload to
	// dataChan; other event types (e.g. ping) are skipped. Signals stopChan
	// when the upstream body is exhausted.
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if !strings.HasPrefix(data, "event: completion") {
				continue
			}
			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
			dataChan <- data
		}
		stopChan <- true
	}()
	setEventStreamHeaders(c)
	// Consumer: translate each Claude chunk into an OpenAI chunk and flush
	// it to the client. Returning true keeps the stream open; false ends it.
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			// some implementations may add \r at the end of data
			data = strings.TrimSuffix(data, "\r")
			var claudeResponse ClaudeResponse
			err := json.Unmarshal([]byte(data), &claudeResponse)
			if err != nil {
				// Skip malformed chunks but keep the stream alive.
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			responseText += claudeResponse.Completion
			response := streamResponseClaude2OpenAI(&claudeResponse)
			// All chunks of one completion share the same id and timestamp.
			response.Id = responseId
			response.Created = createdTime
			jsonStr, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
			return true
		case <-stopChan:
			// Upstream finished; emit the OpenAI stream terminator.
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}
	return nil, responseText
}
  168. func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
  169. responseBody, err := io.ReadAll(resp.Body)
  170. if err != nil {
  171. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  172. }
  173. err = resp.Body.Close()
  174. if err != nil {
  175. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  176. }
  177. var claudeResponse ClaudeResponse
  178. err = json.Unmarshal(responseBody, &claudeResponse)
  179. if err != nil {
  180. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  181. }
  182. if claudeResponse.Error.Type != "" {
  183. return &OpenAIErrorWithStatusCode{
  184. OpenAIError: OpenAIError{
  185. Message: claudeResponse.Error.Message,
  186. Type: claudeResponse.Error.Type,
  187. Param: "",
  188. Code: claudeResponse.Error.Type,
  189. },
  190. StatusCode: resp.StatusCode,
  191. }, nil
  192. }
  193. fullTextResponse := responseClaude2OpenAI(&claudeResponse)
  194. completionTokens := countTokenText(claudeResponse.Completion, model)
  195. usage := Usage{
  196. PromptTokens: promptTokens,
  197. CompletionTokens: completionTokens,
  198. TotalTokens: promptTokens + completionTokens,
  199. }
  200. fullTextResponse.Usage = usage
  201. jsonResponse, err := json.Marshal(fullTextResponse)
  202. if err != nil {
  203. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  204. }
  205. c.Writer.Header().Set("Content-Type", "application/json")
  206. c.Writer.WriteHeader(resp.StatusCode)
  207. _, err = c.Writer.Write(jsonResponse)
  208. return nil, &usage
  209. }