Browse Source

feat: support Claude 3

CaIon 2 years ago
parent
commit
4a0af1ea3c

+ 1 - 4
relay/channel/claude/adaptor.go

@@ -10,7 +10,6 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 	"strings"
 )
 
@@ -68,9 +67,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		err, responseText = claudeStreamHandler(c, resp)
-		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
 	} else {
 		err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 1 - 1
relay/channel/claude/constants.go

@@ -1,7 +1,7 @@
 package claude
 
 var ModelList = []string{
-	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
+	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229",
 }
 
 var ChannelName = "claude"

+ 13 - 3
relay/channel/claude/dto.go

@@ -5,9 +5,11 @@ type ClaudeMetadata struct {
 }
 
 type ClaudeMediaMessage struct {
-	Type   string               `json:"type"`
-	Text   string               `json:"text,omitempty"`
-	Source *ClaudeMessageSource `json:"source,omitempty"`
+	Type       string               `json:"type"`
+	Text       string               `json:"text,omitempty"`
+	Source     *ClaudeMessageSource `json:"source,omitempty"`
+	Usage      *ClaudeUsage         `json:"usage,omitempty"`
+	StopReason *string              `json:"stop_reason,omitempty"`
 }
 
 type ClaudeMessageSource struct {
@@ -50,8 +52,16 @@ type ClaudeResponse struct {
 	Model      string               `json:"model"`
 	Error      ClaudeError          `json:"error"`
 	Usage      ClaudeUsage          `json:"usage"`
+	Index      int                  `json:"index"`   // stream only
+	Delta      *ClaudeMediaMessage  `json:"delta"`   // stream only
+	Message    *ClaudeResponse      `json:"message"` // stream only: message_start
 }
 
+//type ClaudeResponseChoice struct {
+//	Index   int                `json:"index"`
+//	Type    string             `json:"type"`
+//}
+
 type ClaudeUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`

+ 65 - 24
relay/channel/claude/relay-claude.go

@@ -17,6 +17,8 @@ func stopReasonClaude2OpenAI(reason string) string {
 	switch reason {
 	case "stop_sequence":
 		return "stop"
+	case "end_turn":
+		return "stop"
 	case "max_tokens":
 		return "length"
 	default:
@@ -111,24 +113,41 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 	}
 	claudeRequest.Prompt = ""
 	claudeRequest.Messages = claudeMessages
-	reqJson, _ := json.Marshal(claudeRequest)
-	common.SysLog(fmt.Sprintf("claude request: %s", reqJson))
 
 	return &claudeRequest, nil
 }
 
-func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
-	var choice dto.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = claudeResponse.Completion
-	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
-	if finishReason != "null" {
-		choice.FinishReason = &finishReason
-	}
+func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
 	var response dto.ChatCompletionsStreamResponse
+	var claudeUsage *ClaudeUsage
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
-	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
-	return &response
+	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
+	var choice dto.ChatCompletionsStreamResponseChoice
+	if reqMode == RequestModeCompletion {
+		choice.Delta.Content = claudeResponse.Completion
+		finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
+		if finishReason != "null" {
+			choice.FinishReason = &finishReason
+		}
+	} else {
+		if claudeResponse.Type == "message_start" {
+			response.Id = claudeResponse.Message.Id
+			response.Model = claudeResponse.Message.Model
+			claudeUsage = &claudeResponse.Message.Usage
+		} else if claudeResponse.Type == "content_block_delta" {
+			choice.Index = claudeResponse.Index
+			choice.Delta.Content = claudeResponse.Delta.Text
+		} else if claudeResponse.Type == "message_delta" {
+			// Delta/StopReason are pointers and may be nil on intermediate events; guard before dereferencing
+			if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
+				finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
+				if finishReason != "null" {
+					choice.FinishReason = &finishReason
+				}
+			}
+			claudeUsage = &claudeResponse.Usage
+		}
+	}
+	response.Choices = append(response.Choices, choice)
+	return &response, claudeUsage
 }
 
 func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
@@ -170,17 +189,18 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 	return &fullTextResponse
 }
 
-func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
-	responseText := ""
+func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	var usage dto.Usage
+	responseText := ""
 	createdTime := common.GetTimestamp()
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
 			return 0, nil, nil
 		}
-		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
-			return i + 4, data[0:i], nil
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
 		}
 		if atEOF {
 			return len(data), data, nil
@@ -192,10 +212,10 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 	go func() {
 		for scanner.Scan() {
 			data := scanner.Text()
-			if !strings.HasPrefix(data, "event: completion") {
+			if !strings.HasPrefix(data, "data: ") {
 				continue
 			}
-			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
+			data = strings.TrimPrefix(data, "data: ")
 			dataChan <- data
 		}
 		stopChan <- true
@@ -212,10 +232,31 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 			}
-			responseText += claudeResponse.Completion
-			response := streamResponseClaude2OpenAI(&claudeResponse)
+
+			response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
+			if requestMode == RequestModeCompletion {
+				responseText += claudeResponse.Completion
+				responseId = response.Id
+			} else {
+				if claudeResponse.Type == "message_start" {
+					// message_start, 获取usage
+					responseId = claudeResponse.Message.Id
+					modelName = claudeResponse.Message.Model
+					usage.PromptTokens = claudeUsage.InputTokens
+				} else if claudeResponse.Type == "content_block_delta" {
+					responseText += claudeResponse.Delta.Text
+				} else if claudeResponse.Type == "message_delta" {
+					usage.CompletionTokens = claudeUsage.OutputTokens
+					// message_delta usage carries only output_tokens; input tokens were recorded at message_start
+					usage.TotalTokens = usage.PromptTokens + claudeUsage.OutputTokens
+				} else {
+					return true
+				}
+			}
 			response.Id = responseId
 			response.Created = createdTime
+			response.Model = modelName
+
 			jsonStr, err := json.Marshal(response)
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())
@@ -230,9 +271,12 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	return nil, responseText
+	if requestMode == RequestModeCompletion {
+		usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
+	}
+	return nil, &usage
 }
 
 func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -245,13 +289,10 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	var claudeResponse ClaudeResponse
-	common.SysLog(fmt.Sprintf("claude response: %s", responseBody))
 	err = json.Unmarshal(responseBody, &claudeResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	respJson, _ := json.Marshal(claudeResponse)
-	common.SysLog(fmt.Sprintf("claude response json: %s", respJson))
 	if claudeResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{

+ 1 - 1
web/src/pages/Channel/EditChannel.js

@@ -63,7 +63,7 @@ const EditChannel = (props) => {
             let localModels = [];
             switch (value) {
                 case 14:
-                    localModels = ['claude-instant-1', 'claude-2'];
+                    localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229'];
                     break;
                 case 11:
                     localModels = ['PaLM-2'];