// relay-claude.go
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strings"
  11. )
// ClaudeMetadata carries optional request metadata for the Claude API,
// identifying the end user on whose behalf the request is made.
type ClaudeMetadata struct {
	UserId string `json:"user_id"`
}
// ClaudeRequest is the request body for the Anthropic Claude text
// completion endpoint. Prompt must follow Claude's "\n\nHuman: ...
// \n\nAssistant:" transcript format (see requestOpenAI2Claude).
type ClaudeRequest struct {
	Model string `json:"model"`
	Prompt string `json:"prompt"`
	// Required by the API; a zero value is replaced with a large default
	// before sending (see requestOpenAI2Claude).
	MaxTokensToSample uint `json:"max_tokens_to_sample"`
	StopSequences []string `json:"stop_sequences,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP float64 `json:"top_p,omitempty"`
	TopK int `json:"top_k,omitempty"`
	//ClaudeMetadata `json:"metadata,omitempty"`
	Stream bool `json:"stream,omitempty"`
}
// ClaudeError is the error object Claude returns in-band inside a
// response body; a non-empty Type indicates a failed request.
type ClaudeError struct {
	Type string `json:"type"`
	Message string `json:"message"`
}
// ClaudeResponse is the (possibly partial, when streaming) completion
// response from Claude. On failure the Error field is populated and the
// other fields are zero.
type ClaudeResponse struct {
	Completion string `json:"completion"`
	// StopReason is empty/null on intermediate streaming chunks and set
	// (e.g. "stop_sequence", "max_tokens") on the final one.
	StopReason string `json:"stop_reason"`
	Model string `json:"model"`
	Error ClaudeError `json:"error"`
}
  36. func stopReasonClaude2OpenAI(reason string) string {
  37. switch reason {
  38. case "stop_sequence":
  39. return "stop"
  40. case "max_tokens":
  41. return "length"
  42. default:
  43. return reason
  44. }
  45. }
  46. func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
  47. claudeRequest := ClaudeRequest{
  48. Model: textRequest.Model,
  49. Prompt: "",
  50. MaxTokensToSample: textRequest.MaxTokens,
  51. StopSequences: nil,
  52. Temperature: textRequest.Temperature,
  53. TopP: textRequest.TopP,
  54. Stream: textRequest.Stream,
  55. }
  56. if claudeRequest.MaxTokensToSample == 0 {
  57. claudeRequest.MaxTokensToSample = 1000000
  58. }
  59. prompt := ""
  60. for _, message := range textRequest.Messages {
  61. if message.Role == "user" {
  62. prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
  63. } else if message.Role == "assistant" {
  64. prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
  65. } else if message.Role == "system" {
  66. prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
  67. }
  68. }
  69. prompt += "\n\nAssistant:"
  70. claudeRequest.Prompt = prompt
  71. return &claudeRequest
  72. }
  73. func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
  74. var choice ChatCompletionsStreamResponseChoice
  75. choice.Delta.Content = claudeResponse.Completion
  76. finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
  77. if finishReason != "null" {
  78. choice.FinishReason = &finishReason
  79. }
  80. var response ChatCompletionsStreamResponse
  81. response.Object = "chat.completion.chunk"
  82. response.Model = claudeResponse.Model
  83. response.Choices = []ChatCompletionsStreamResponseChoice{choice}
  84. return &response
  85. }
  86. func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
  87. content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
  88. choice := OpenAITextResponseChoice{
  89. Index: 0,
  90. Message: Message{
  91. Role: "assistant",
  92. Content: content,
  93. Name: nil,
  94. },
  95. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  96. }
  97. fullTextResponse := OpenAITextResponse{
  98. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  99. Object: "chat.completion",
  100. Created: common.GetTimestamp(),
  101. Choices: []OpenAITextResponseChoice{choice},
  102. }
  103. return &fullTextResponse
  104. }
// claudeStreamHandler relays a streaming Claude response to the client as
// OpenAI-style server-sent events. It returns the accumulated completion
// text (for token accounting by the caller) on success, or a wrapped
// error with an HTTP status code.
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
	responseText := ""
	// All emitted chunks share one id and created timestamp, matching
	// OpenAI streaming semantics.
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
	createdTime := common.GetTimestamp()
	scanner := bufio.NewScanner(resp.Body)
	// Custom split function: Claude SSE events are separated by a blank
	// line (\r\n\r\n); yield one complete event per token.
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
			return i + 4, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	dataChan := make(chan string)
	stopChan := make(chan bool)
	// Reader goroutine: forward only "completion" events, reduced to their
	// JSON payload; signal stopChan once the upstream stream ends.
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			if !strings.HasPrefix(data, "event: completion") {
				continue
			}
			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
			dataChan <- data
		}
		stopChan <- true
	}()
	setEventStreamHeaders(c)
	// c.Stream keeps invoking this callback until it returns false.
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			// some implementations may add \r at the end of data
			data = strings.TrimSuffix(data, "\r")
			var claudeResponse ClaudeResponse
			err := json.Unmarshal([]byte(data), &claudeResponse)
			if err != nil {
				// Best-effort: log and keep the stream alive for the
				// next chunk rather than aborting the whole response.
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			responseText += claudeResponse.Completion
			response := streamResponseClaude2OpenAI(&claudeResponse)
			response.Id = responseId
			response.Created = createdTime
			jsonStr, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
			return true
		case <-stopChan:
			// Upstream finished; emit the OpenAI stream terminator.
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}
	return nil, responseText
}
  169. func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
  170. responseBody, err := io.ReadAll(resp.Body)
  171. if err != nil {
  172. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  173. }
  174. err = resp.Body.Close()
  175. if err != nil {
  176. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  177. }
  178. var claudeResponse ClaudeResponse
  179. err = json.Unmarshal(responseBody, &claudeResponse)
  180. if err != nil {
  181. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  182. }
  183. if claudeResponse.Error.Type != "" {
  184. return &OpenAIErrorWithStatusCode{
  185. OpenAIError: OpenAIError{
  186. Message: claudeResponse.Error.Message,
  187. Type: claudeResponse.Error.Type,
  188. Param: "",
  189. Code: claudeResponse.Error.Type,
  190. },
  191. StatusCode: resp.StatusCode,
  192. }, nil
  193. }
  194. fullTextResponse := responseClaude2OpenAI(&claudeResponse)
  195. completionTokens := countTokenText(claudeResponse.Completion, model)
  196. usage := Usage{
  197. PromptTokens: promptTokens,
  198. CompletionTokens: completionTokens,
  199. TotalTokens: promptTokens + completionTokens,
  200. }
  201. fullTextResponse.Usage = usage
  202. jsonResponse, err := json.Marshal(fullTextResponse)
  203. if err != nil {
  204. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  205. }
  206. c.Writer.Header().Set("Content-Type", "application/json")
  207. c.Writer.WriteHeader(resp.StatusCode)
  208. _, err = c.Writer.Write(jsonResponse)
  209. return nil, &usage
  210. }