Browse Source

Merge pull request #103 from Calcium-Ion/dev

feat: support Claude 3
Calcium-Ion 2 years ago
parent
commit
eca48268b2

+ 4 - 3
common/image.go

@@ -12,7 +12,7 @@ import (
 	"strings"
 	"strings"
 )
 )
 
 
-func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
+func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
 	// 去除base64数据的URL前缀(如果有)
 	// 去除base64数据的URL前缀(如果有)
 	if idx := strings.Index(base64String, ","); idx != -1 {
 	if idx := strings.Index(base64String, ","); idx != -1 {
 		base64String = base64String[idx+1:]
 		base64String = base64String[idx+1:]
@@ -22,13 +22,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
 	decodedData, err := base64.StdEncoding.DecodeString(base64String)
 	decodedData, err := base64.StdEncoding.DecodeString(base64String)
 	if err != nil {
 	if err != nil {
 		fmt.Println("Error: Failed to decode base64 string")
 		fmt.Println("Error: Failed to decode base64 string")
-		return image.Config{}, "", err
+		return image.Config{}, "", "", err
 	}
 	}
 
 
 	// 创建一个bytes.Buffer用于存储解码后的数据
 	// 创建一个bytes.Buffer用于存储解码后的数据
 	reader := bytes.NewReader(decodedData)
 	reader := bytes.NewReader(decodedData)
 	config, format, err := getImageConfig(reader)
 	config, format, err := getImageConfig(reader)
-	return config, format, err
+	return config, format, base64String, err
 }
 }
 
 
 func IsImageUrl(url string) (bool, error) {
 func IsImageUrl(url string) (bool, error) {
@@ -42,6 +42,7 @@ func IsImageUrl(url string) (bool, error) {
 	return true, nil
 	return true, nil
 }
 }
 
 
+// GetImageFromUrl 获取图片的类型和base64编码的数据
 func GetImageFromUrl(url string) (mimeType string, data string, err error) {
 func GetImageFromUrl(url string) (mimeType string, data string, err error) {
 	isImage, err := IsImageUrl(url)
 	isImage, err := IsImageUrl(url)
 	if !isImage {
 	if !isImage {

+ 8 - 6
dto/text_request.go

@@ -82,6 +82,14 @@ func (m Message) StringContent() string {
 	return string(m.Content)
 	return string(m.Content)
 }
 }
 
 
+func (m Message) IsStringContent() bool {
+	var stringContent string
+	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		return true
+	}
+	return false
+}
+
 func (m Message) ParseContent() []MediaMessage {
 func (m Message) ParseContent() []MediaMessage {
 	var contentList []MediaMessage
 	var contentList []MediaMessage
 	var stringContent string
 	var stringContent string
@@ -130,9 +138,3 @@ func (m Message) ParseContent() []MediaMessage {
 
 
 	return nil
 	return nil
 }
 }
-
-type Usage struct {
-	PromptTokens     int `json:"prompt_tokens"`
-	CompletionTokens int `json:"completion_tokens"`
-	TotalTokens      int `json:"total_tokens"`
-}

+ 6 - 0
dto/text_response.go

@@ -61,3 +61,9 @@ type CompletionsStreamResponse struct {
 		FinishReason string `json:"finish_reason"`
 		FinishReason string `json:"finish_reason"`
 	} `json:"choices"`
 	} `json:"choices"`
 }
 }
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}

+ 9 - 11
relay/channel/claude/adaptor.go

@@ -6,10 +6,10 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 	"strings"
 	"strings"
 )
 )
 
 
@@ -50,15 +50,15 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 }
 }
 
 
 func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
 func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	common.SysLog(fmt.Sprintf("Request mode: %d", a.RequestMode))
 	if request == nil {
 	if request == nil {
 		return nil, errors.New("request is nil")
 		return nil, errors.New("request is nil")
 	}
 	}
-	//if a.RequestMode == RequestModeCompletion {
-	//	return requestOpenAI2ClaudeComplete(*request), nil
-	//} else {
-	//	return requestOpenAI2ClaudeMessage(*request), nil
-	//}
-	return request, nil
+	if a.RequestMode == RequestModeCompletion {
+		return requestOpenAI2ClaudeComplete(*request), nil
+	} else {
+		return requestOpenAI2ClaudeMessage(*request)
+	}
 }
 }
 
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -67,11 +67,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
-		var responseText string
-		err, responseText = claudeStreamHandler(c, resp)
-		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
 	} else {
 	} else {
-		err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	}
 	return
 	return
 }
 }

+ 1 - 1
relay/channel/claude/constants.go

@@ -1,7 +1,7 @@
 package claude
 package claude
 
 
 var ModelList = []string{
 var ModelList = []string{
-	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
+	"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229",
 }
 }
 
 
 var ChannelName = "claude"
 var ChannelName = "claude"

+ 50 - 11
relay/channel/claude/dto.go

@@ -4,14 +4,36 @@ type ClaudeMetadata struct {
 	UserId string `json:"user_id"`
 	UserId string `json:"user_id"`
 }
 }
 
 
+type ClaudeMediaMessage struct {
+	Type       string               `json:"type"`
+	Text       string               `json:"text,omitempty"`
+	Source     *ClaudeMessageSource `json:"source,omitempty"`
+	Usage      *ClaudeUsage         `json:"usage,omitempty"`
+	StopReason *string              `json:"stop_reason,omitempty"`
+}
+
+type ClaudeMessageSource struct {
+	Type      string `json:"type"`
+	MediaType string `json:"media_type"`
+	Data      string `json:"data"`
+}
+
+type ClaudeMessage struct {
+	Role    string `json:"role"`
+	Content any    `json:"content"`
+}
+
 type ClaudeRequest struct {
 type ClaudeRequest struct {
-	Model             string   `json:"model"`
-	Prompt            string   `json:"prompt"`
-	MaxTokensToSample uint     `json:"max_tokens_to_sample"`
-	StopSequences     []string `json:"stop_sequences,omitempty"`
-	Temperature       float64  `json:"temperature,omitempty"`
-	TopP              float64  `json:"top_p,omitempty"`
-	TopK              int      `json:"top_k,omitempty"`
+	Model             string          `json:"model"`
+	Prompt            string          `json:"prompt,omitempty"`
+	System            string          `json:"system,omitempty"`
+	Messages          []ClaudeMessage `json:"messages,omitempty"`
+	MaxTokensToSample uint            `json:"max_tokens_to_sample,omitempty"`
+	MaxTokens         uint            `json:"max_tokens,omitempty"`
+	StopSequences     []string        `json:"stop_sequences,omitempty"`
+	Temperature       float64         `json:"temperature,omitempty"`
+	TopP              float64         `json:"top_p,omitempty"`
+	TopK              int             `json:"top_k,omitempty"`
 	//ClaudeMetadata    `json:"metadata,omitempty"`
 	//ClaudeMetadata    `json:"metadata,omitempty"`
 	Stream bool `json:"stream,omitempty"`
 	Stream bool `json:"stream,omitempty"`
 }
 }
@@ -22,8 +44,25 @@ type ClaudeError struct {
 }
 }
 
 
 type ClaudeResponse struct {
 type ClaudeResponse struct {
-	Completion string      `json:"completion"`
-	StopReason string      `json:"stop_reason"`
-	Model      string      `json:"model"`
-	Error      ClaudeError `json:"error"`
+	Id         string               `json:"id"`
+	Type       string               `json:"type"`
+	Content    []ClaudeMediaMessage `json:"content"`
+	Completion string               `json:"completion"`
+	StopReason string               `json:"stop_reason"`
+	Model      string               `json:"model"`
+	Error      ClaudeError          `json:"error"`
+	Usage      ClaudeUsage          `json:"usage"`
+	Index      int                  `json:"index"`   // stream only
+	Delta      *ClaudeMediaMessage  `json:"delta"`   // stream only
+	Message    *ClaudeResponse      `json:"message"` // stream only: message_start
+}
+
+//type ClaudeResponseChoice struct {
+//	Index   int                `json:"index"`
+//	Type    string             `json:"type"`
+//}
+
+type ClaudeUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
 }
 }

+ 167 - 40
relay/channel/claude/relay-claude.go

@@ -17,6 +17,8 @@ func stopReasonClaude2OpenAI(reason string) string {
 	switch reason {
 	switch reason {
 	case "stop_sequence":
 	case "stop_sequence":
 		return "stop"
 		return "stop"
+	case "end_turn":
+		return "stop"
 	case "max_tokens":
 	case "max_tokens":
 		return "length"
 		return "length"
 	default:
 	default:
@@ -54,55 +56,151 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
 	return &claudeRequest
 	return &claudeRequest
 }
 }
 
 
-//func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
-//
-//}
-
-func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
-	var choice dto.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = claudeResponse.Completion
-	finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
-	if finishReason != "null" {
-		choice.FinishReason = &finishReason
+func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
+	claudeRequest := ClaudeRequest{
+		Model:         textRequest.Model,
+		MaxTokens:     textRequest.MaxTokens,
+		StopSequences: nil,
+		Temperature:   textRequest.Temperature,
+		TopP:          textRequest.TopP,
+		Stream:        textRequest.Stream,
+	}
+	claudeMessages := make([]ClaudeMessage, 0)
+	for _, message := range textRequest.Messages {
+		if message.Role == "system" {
+			claudeRequest.System = message.StringContent()
+		} else {
+			claudeMessage := ClaudeMessage{
+				Role: message.Role,
+			}
+			if message.IsStringContent() {
+				claudeMessage.Content = message.StringContent()
+			} else {
+				claudeMediaMessages := make([]ClaudeMediaMessage, 0)
+				for _, mediaMessage := range message.ParseContent() {
+					claudeMediaMessage := ClaudeMediaMessage{
+						Type: mediaMessage.Type,
+					}
+					if mediaMessage.Type == "text" {
+						claudeMediaMessage.Text = mediaMessage.Text
+					} else {
+						imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
+						claudeMediaMessage.Type = "image"
+						claudeMediaMessage.Source = &ClaudeMessageSource{
+							Type: "base64",
+						}
+						// 判断是否是url
+						if strings.HasPrefix(imageUrl.Url, "http") {
+							// 是url,获取图片的类型和base64编码的数据
+							mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
+							claudeMediaMessage.Source.MediaType = mimeType
+							claudeMediaMessage.Source.Data = data
+						} else {
+							_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
+							if err != nil {
+								return nil, err
+							}
+							claudeMediaMessage.Source.MediaType = "image/" + format
+							claudeMediaMessage.Source.Data = base64String
+						}
+					}
+					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
+				}
+				claudeMessage.Content = claudeMediaMessages
+			}
+			claudeMessages = append(claudeMessages, claudeMessage)
+		}
 	}
 	}
+	claudeRequest.Prompt = ""
+	claudeRequest.Messages = claudeMessages
+
+	return &claudeRequest, nil
+}
+
+func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
 	var response dto.ChatCompletionsStreamResponse
 	var response dto.ChatCompletionsStreamResponse
+	var claudeUsage *ClaudeUsage
 	response.Object = "chat.completion.chunk"
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
 	response.Model = claudeResponse.Model
-	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
-	return &response
+	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
+	var choice dto.ChatCompletionsStreamResponseChoice
+	if reqMode == RequestModeCompletion {
+		choice.Delta.Content = claudeResponse.Completion
+		finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
+		if finishReason != "null" {
+			choice.FinishReason = &finishReason
+		}
+	} else {
+		if claudeResponse.Type == "message_start" {
+			response.Id = claudeResponse.Message.Id
+			response.Model = claudeResponse.Message.Model
+			claudeUsage = &claudeResponse.Message.Usage
+		} else if claudeResponse.Type == "content_block_delta" {
+			choice.Index = claudeResponse.Index
+			choice.Delta.Content = claudeResponse.Delta.Text
+		} else if claudeResponse.Type == "message_delta" {
+			finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
+			if finishReason != "null" {
+				choice.FinishReason = &finishReason
+			}
+			claudeUsage = &claudeResponse.Usage
+		}
+	}
+	response.Choices = append(response.Choices, choice)
+	return &response, claudeUsage
 }
 }
 
 
-func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
-	content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
-	choice := dto.OpenAITextResponseChoice{
-		Index: 0,
-		Message: dto.Message{
-			Role:    "assistant",
-			Content: content,
-			Name:    nil,
-		},
-		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
-	}
+func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
+	choices := make([]dto.OpenAITextResponseChoice, 0)
 	fullTextResponse := dto.OpenAITextResponse{
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 		Created: common.GetTimestamp(),
-		Choices: []dto.OpenAITextResponseChoice{choice},
 	}
 	}
+	if reqMode == RequestModeCompletion {
+		content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
+		choice := dto.OpenAITextResponseChoice{
+			Index: 0,
+			Message: dto.Message{
+				Role:    "assistant",
+				Content: content,
+				Name:    nil,
+			},
+			FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+		}
+		choices = append(choices, choice)
+	} else {
+		fullTextResponse.Id = claudeResponse.Id
+		for i, message := range claudeResponse.Content {
+			content, _ := json.Marshal(message.Text)
+			choice := dto.OpenAITextResponseChoice{
+				Index: i,
+				Message: dto.Message{
+					Role:    "assistant",
+					Content: content,
+				},
+				FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+			}
+			choices = append(choices, choice)
+		}
+	}
+
+	fullTextResponse.Choices = choices
 	return &fullTextResponse
 	return &fullTextResponse
 }
 }
 
 
-func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
-	responseText := ""
+func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	var usage dto.Usage
+	responseText := ""
 	createdTime := common.GetTimestamp()
 	createdTime := common.GetTimestamp()
 	scanner := bufio.NewScanner(resp.Body)
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
 		if atEOF && len(data) == 0 {
 			return 0, nil, nil
 			return 0, nil, nil
 		}
 		}
-		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
-			return i + 4, data[0:i], nil
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
 		}
 		}
 		if atEOF {
 		if atEOF {
 			return len(data), data, nil
 			return len(data), data, nil
@@ -114,10 +212,10 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 	go func() {
 	go func() {
 		for scanner.Scan() {
 		for scanner.Scan() {
 			data := scanner.Text()
 			data := scanner.Text()
-			if !strings.HasPrefix(data, "event: completion") {
+			if !strings.HasPrefix(data, "data: ") {
 				continue
 				continue
 			}
 			}
-			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
+			data = strings.TrimPrefix(data, "data: ")
 			dataChan <- data
 			dataChan <- data
 		}
 		}
 		stopChan <- true
 		stopChan <- true
@@ -134,10 +232,31 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
-			responseText += claudeResponse.Completion
-			response := streamResponseClaude2OpenAI(&claudeResponse)
+
+			response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
+			if requestMode == RequestModeCompletion {
+				responseText += claudeResponse.Completion
+				responseId = response.Id
+			} else {
+				if claudeResponse.Type == "message_start" {
+					// message_start, 获取usage
+					responseId = claudeResponse.Message.Id
+					modelName = claudeResponse.Message.Model
+					usage.PromptTokens = claudeUsage.InputTokens
+				} else if claudeResponse.Type == "content_block_delta" {
+					responseText += claudeResponse.Delta.Text
+				} else if claudeResponse.Type == "message_delta" {
+					usage.CompletionTokens = claudeUsage.OutputTokens
+					usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
+				} else {
+					return true
+				}
+			}
+			//response.Id = responseId
 			response.Id = responseId
 			response.Id = responseId
 			response.Created = createdTime
 			response.Created = createdTime
+			response.Model = modelName
+
 			jsonStr, err := json.Marshal(response)
 			jsonStr, err := json.Marshal(response)
 			if err != nil {
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())
 				common.SysError("error marshalling stream response: " + err.Error())
@@ -152,12 +271,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 	})
 	})
 	err := resp.Body.Close()
 	err := resp.Body.Close()
 	if err != nil {
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
-	return nil, responseText
+	if requestMode == RequestModeCompletion {
+		usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
+	}
+	return nil, &usage
 }
 }
 
 
-func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -182,12 +304,17 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 			StatusCode: resp.StatusCode,
 			StatusCode: resp.StatusCode,
 		}, nil
 		}, nil
 	}
 	}
-	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
+	fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
 	completionTokens := service.CountTokenText(claudeResponse.Completion, model)
 	completionTokens := service.CountTokenText(claudeResponse.Completion, model)
-	usage := dto.Usage{
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      promptTokens + completionTokens,
+	usage := dto.Usage{}
+	if requestMode == RequestModeCompletion {
+		usage.PromptTokens = promptTokens
+		usage.CompletionTokens = completionTokens
+		usage.TotalTokens = promptTokens + completionTokens
+	} else {
+		usage.PromptTokens = claudeResponse.Usage.InputTokens
+		usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+		usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
 	}
 	}
 	fullTextResponse.Usage = usage
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)

+ 1 - 1
service/token_counter.go

@@ -74,7 +74,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 		config, format, err = common.DecodeUrlImageData(imageUrl.Url)
 		config, format, err = common.DecodeUrlImageData(imageUrl.Url)
 	} else {
 	} else {
 		common.SysLog(fmt.Sprintf("decoding image"))
 		common.SysLog(fmt.Sprintf("decoding image"))
-		config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
+		config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err

+ 1 - 1
web/src/pages/Channel/EditChannel.js

@@ -63,7 +63,7 @@ const EditChannel = (props) => {
             let localModels = [];
             let localModels = [];
             switch (value) {
             switch (value) {
                 case 14:
                 case 14:
-                    localModels = ['claude-instant-1', 'claude-2'];
+                    localModels = ["claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229"];
                     break;
                     break;
                 case 11:
                 case 11:
                     localModels = ['PaLM-2'];
                     localModels = ['PaLM-2'];

+ 11 - 0
web/src/pages/TopUp/index.js

@@ -3,6 +3,8 @@ import {API, isMobile, showError, showInfo, showSuccess} from '../../helpers';
 import {renderNumber, renderQuota} from '../../helpers/render';
 import {renderNumber, renderQuota} from '../../helpers/render';
 import {Col, Layout, Row, Typography, Card, Button, Form, Divider, Space, Modal} from "@douyinfe/semi-ui";
 import {Col, Layout, Row, Typography, Card, Button, Form, Divider, Space, Modal} from "@douyinfe/semi-ui";
 import Title from "@douyinfe/semi-ui/lib/es/typography/title";
 import Title from "@douyinfe/semi-ui/lib/es/typography/title";
+import Text from '@douyinfe/semi-ui/lib/es/typography/text';
+import { Link } from 'react-router-dom';
 
 
 const TopUp = () => {
 const TopUp = () => {
     const [redemptionCode, setRedemptionCode] = useState('');
     const [redemptionCode, setRedemptionCode] = useState('');
@@ -290,6 +292,15 @@ const TopUp = () => {
                                     </Space>
                                     </Space>
                                 </Form>
                                 </Form>
                             </div>
                             </div>
+                            <div style={{ display: 'flex', justifyContent: 'right' }}>
+                                <Text>
+                                    <Link onClick={
+                                        async () => {
+                                            window.location.href = '/topup/history'
+                                        }
+                                    }>充值记录</Link>
+                                </Text>
+                            </div>
                         </Card>
                         </Card>
                     </div>
                     </div>