Просмотр исходного кода

feat: support claude3 not stream

CaIon 2 лет назад
Родитель
Сommit
c2965eb835

+ 4 - 3
common/image.go

@@ -12,7 +12,7 @@ import (
 	"strings"
 )
 
-func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
+func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
 	// 去除base64数据的URL前缀(如果有)
 	if idx := strings.Index(base64String, ","); idx != -1 {
 		base64String = base64String[idx+1:]
@@ -22,13 +22,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
 	decodedData, err := base64.StdEncoding.DecodeString(base64String)
 	if err != nil {
 		fmt.Println("Error: Failed to decode base64 string")
-		return image.Config{}, "", err
+		return image.Config{}, "", "", err
 	}
 
 	// 创建一个bytes.Buffer用于存储解码后的数据
 	reader := bytes.NewReader(decodedData)
 	config, format, err := getImageConfig(reader)
-	return config, format, err
+	return config, format, base64String, err
 }
 
 func IsImageUrl(url string) (bool, error) {
@@ -42,6 +42,7 @@ func IsImageUrl(url string) (bool, error) {
 	return true, nil
 }
 
+// GetImageFromUrl 获取图片的类型和base64编码的数据
 func GetImageFromUrl(url string) (mimeType string, data string, err error) {
 	isImage, err := IsImageUrl(url)
 	if !isImage {

+ 8 - 6
dto/text_request.go

@@ -82,6 +82,14 @@ func (m Message) StringContent() string {
 	return string(m.Content)
 }
 
+func (m Message) IsStringContent() bool {
+	var stringContent string
+	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		return true
+	}
+	return false
+}
+
 func (m Message) ParseContent() []MediaMessage {
 	var contentList []MediaMessage
 	var stringContent string
@@ -130,9 +138,3 @@ func (m Message) ParseContent() []MediaMessage {
 
 	return nil
 }
-
-type Usage struct {
-	PromptTokens     int `json:"prompt_tokens"`
-	CompletionTokens int `json:"completion_tokens"`
-	TotalTokens      int `json:"total_tokens"`
-}

+ 6 - 0
dto/text_response.go

@@ -61,3 +61,9 @@ type CompletionsStreamResponse struct {
 		FinishReason string `json:"finish_reason"`
 	} `json:"choices"`
 }
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}

+ 8 - 7
relay/channel/claude/adaptor.go

@@ -6,6 +6,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
@@ -50,15 +51,15 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 }
 
 func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	common.SysLog(fmt.Sprintf("Request mode: %d", a.RequestMode))
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	//if a.RequestMode == RequestModeCompletion {
-	//	return requestOpenAI2ClaudeComplete(*request), nil
-	//} else {
-	//	return requestOpenAI2ClaudeMessage(*request), nil
-	//}
-	return request, nil
+	if a.RequestMode == RequestModeCompletion {
+		return requestOpenAI2ClaudeComplete(*request), nil
+	} else {
+		return requestOpenAI2ClaudeMessage(*request)
+	}
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -71,7 +72,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = claudeStreamHandler(c, resp)
 		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	return
 }

+ 40 - 11
relay/channel/claude/dto.go

@@ -4,14 +4,34 @@ type ClaudeMetadata struct {
 	UserId string `json:"user_id"`
 }
 
+type ClaudeMediaMessage struct {
+	Type   string               `json:"type"`
+	Text   string               `json:"text,omitempty"`
+	Source *ClaudeMessageSource `json:"source,omitempty"`
+}
+
+type ClaudeMessageSource struct {
+	Type      string `json:"type"`
+	MediaType string `json:"media_type"`
+	Data      string `json:"data"`
+}
+
+type ClaudeMessage struct {
+	Role    string `json:"role"`
+	Content any    `json:"content"`
+}
+
 type ClaudeRequest struct {
-	Model             string   `json:"model"`
-	Prompt            string   `json:"prompt"`
-	MaxTokensToSample uint     `json:"max_tokens_to_sample"`
-	StopSequences     []string `json:"stop_sequences,omitempty"`
-	Temperature       float64  `json:"temperature,omitempty"`
-	TopP              float64  `json:"top_p,omitempty"`
-	TopK              int      `json:"top_k,omitempty"`
+	Model             string          `json:"model"`
+	Prompt            string          `json:"prompt,omitempty"`
+	System            string          `json:"system,omitempty"`
+	Messages          []ClaudeMessage `json:"messages,omitempty"`
+	MaxTokensToSample uint            `json:"max_tokens_to_sample,omitempty"`
+	MaxTokens         uint            `json:"max_tokens,omitempty"`
+	StopSequences     []string        `json:"stop_sequences,omitempty"`
+	Temperature       float64         `json:"temperature,omitempty"`
+	TopP              float64         `json:"top_p,omitempty"`
+	TopK              int             `json:"top_k,omitempty"`
 	//ClaudeMetadata    `json:"metadata,omitempty"`
 	Stream bool `json:"stream,omitempty"`
 }
@@ -22,8 +42,17 @@ type ClaudeError struct {
 }
 
 type ClaudeResponse struct {
-	Completion string      `json:"completion"`
-	StopReason string      `json:"stop_reason"`
-	Model      string      `json:"model"`
-	Error      ClaudeError `json:"error"`
+	Id         string               `json:"id"`
+	Type       string               `json:"type"`
+	Content    []ClaudeMediaMessage `json:"content"`
+	Completion string               `json:"completion"`
+	StopReason string               `json:"stop_reason"`
+	Model      string               `json:"model"`
+	Error      ClaudeError          `json:"error"`
+	Usage      ClaudeUsage          `json:"usage"`
+}
+
+type ClaudeUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
 }

+ 107 - 21
relay/channel/claude/relay-claude.go

@@ -54,9 +54,68 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
 	return &claudeRequest
 }
 
-//func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
-//
-//}
+func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
+	claudeRequest := ClaudeRequest{
+		Model:         textRequest.Model,
+		MaxTokens:     textRequest.MaxTokens,
+		StopSequences: nil,
+		Temperature:   textRequest.Temperature,
+		TopP:          textRequest.TopP,
+		Stream:        textRequest.Stream,
+	}
+	claudeMessages := make([]ClaudeMessage, 0)
+	for _, message := range textRequest.Messages {
+		if message.Role == "system" {
+			claudeRequest.System = message.StringContent()
+		} else {
+			claudeMessage := ClaudeMessage{
+				Role: message.Role,
+			}
+			if message.IsStringContent() {
+				claudeMessage.Content = message.StringContent()
+			} else {
+				claudeMediaMessages := make([]ClaudeMediaMessage, 0)
+				for _, mediaMessage := range message.ParseContent() {
+					claudeMediaMessage := ClaudeMediaMessage{
+						Type: mediaMessage.Type,
+					}
+					if mediaMessage.Type == "text" {
+						claudeMediaMessage.Text = mediaMessage.Text
+					} else {
+						imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
+						claudeMediaMessage.Type = "image"
+						claudeMediaMessage.Source = &ClaudeMessageSource{
+							Type: "base64",
+						}
+						// 判断是否是url
+						if strings.HasPrefix(imageUrl.Url, "http") {
+							// 是url,获取图片的类型和base64编码的数据
+							mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
+							claudeMediaMessage.Source.MediaType = mimeType
+							claudeMediaMessage.Source.Data = data
+						} else {
+							_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
+							if err != nil {
+								return nil, err
+							}
+							claudeMediaMessage.Source.MediaType = "image/" + format
+							claudeMediaMessage.Source.Data = base64String
+						}
+					}
+					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
+				}
+				claudeMessage.Content = claudeMediaMessages
+			}
+			claudeMessages = append(claudeMessages, claudeMessage)
+		}
+	}
+	claudeRequest.Prompt = ""
+	claudeRequest.Messages = claudeMessages
+	reqJson, _ := json.Marshal(claudeRequest)
+	common.SysLog(fmt.Sprintf("claude request: %s", reqJson))
+
+	return &claudeRequest, nil
+}
 
 func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
@@ -72,23 +131,42 @@ func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatComple
 	return &response
 }
 
-func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
-	content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
-	choice := dto.OpenAITextResponseChoice{
-		Index: 0,
-		Message: dto.Message{
-			Role:    "assistant",
-			Content: content,
-			Name:    nil,
-		},
-		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
-	}
+func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
+	choices := make([]dto.OpenAITextResponseChoice, 0)
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
-		Choices: []dto.OpenAITextResponseChoice{choice},
 	}
+	if reqMode == RequestModeCompletion {
+		content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
+		choice := dto.OpenAITextResponseChoice{
+			Index: 0,
+			Message: dto.Message{
+				Role:    "assistant",
+				Content: content,
+				Name:    nil,
+			},
+			FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+		}
+		choices = append(choices, choice)
+	} else {
+		fullTextResponse.Id = claudeResponse.Id
+		for i, message := range claudeResponse.Content {
+			content, _ := json.Marshal(message.Text)
+			choice := dto.OpenAITextResponseChoice{
+				Index: i,
+				Message: dto.Message{
+					Role:    "assistant",
+					Content: content,
+				},
+				FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+			}
+			choices = append(choices, choice)
+		}
+	}
+
+	fullTextResponse.Choices = choices
 	return &fullTextResponse
 }
 
@@ -157,7 +235,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 	return nil, responseText
 }
 
-func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -167,10 +245,13 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	var claudeResponse ClaudeResponse
+	common.SysLog(fmt.Sprintf("claude response: %s", responseBody))
 	err = json.Unmarshal(responseBody, &claudeResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
+	respJson, _ := json.Marshal(claudeResponse)
+	common.SysLog(fmt.Sprintf("claude response json: %s", respJson))
 	if claudeResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
@@ -182,12 +263,17 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
-	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
+	fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
 	completionTokens := service.CountTokenText(claudeResponse.Completion, model)
-	usage := dto.Usage{
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      promptTokens + completionTokens,
+	usage := dto.Usage{}
+	if requestMode == RequestModeCompletion {
+		usage.PromptTokens = promptTokens
+		usage.CompletionTokens = completionTokens
+		usage.TotalTokens = promptTokens + completionTokens
+	} else {
+		usage.PromptTokens = claudeResponse.Usage.InputTokens
+		usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+		usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
 	}
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)

+ 1 - 1
service/token_counter.go

@@ -74,7 +74,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 		config, format, err = common.DecodeUrlImageData(imageUrl.Url)
 	} else {
 		common.SysLog(fmt.Sprintf("decoding image"))
-		config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
+		config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
 	}
 	if err != nil {
 		return 0, err