Browse Source

feat: support claude tool calling

CalciumIon 1 year ago
parent
commit
11fd993574
4 changed files with 128 additions and 39 deletions
  1. 6 1
      dto/text_request.go
  2. 4 2
      dto/text_response.go
  3. 39 22
      relay/channel/claude/dto.go
  4. 79 14
      relay/channel/claude/relay-claude.go

+ 6 - 1
dto/text_request.go

@@ -26,7 +26,7 @@ type GeneralOpenAIRequest struct {
 	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
 	Seed             float64         `json:"seed,omitempty"`
-	Tools            any             `json:"tools,omitempty"`
+	Tools            []ToolCall      `json:"tools,omitempty"`
 	ToolChoice       any             `json:"tool_choice,omitempty"`
 	User             string          `json:"user,omitempty"`
 	LogProbs         bool            `json:"logprobs,omitempty"`
@@ -104,6 +104,11 @@ func (m Message) StringContent() string {
 	return string(m.Content)
 }
 
+func (m *Message) SetStringContent(content string) {
+	jsonContent, _ := json.Marshal(content)
+	m.Content = jsonContent
+}
+
 func (m Message) IsStringContent() bool {
 	var stringContent string
 	if err := json.Unmarshal(m.Content, &stringContent); err == nil {

+ 4 - 2
dto/text_response.go

@@ -86,9 +86,11 @@ type ToolCall struct {
 }
 
 type FunctionCall struct {
-	Name string `json:"name,omitempty"`
+	Description string `json:"description,omitempty"`
+	Name        string `json:"name,omitempty"`
 	// call function with arguments in JSON format
-	Arguments string `json:"arguments,omitempty"`
+	Parameters any    `json:"parameters,omitempty"` // request
+	Arguments  string `json:"arguments,omitempty"`
 }
 
 type ChatCompletionsStreamResponse struct {

+ 39 - 22
relay/channel/claude/dto.go

@@ -5,11 +5,18 @@ type ClaudeMetadata struct {
 }
 
 type ClaudeMediaMessage struct {
-	Type       string               `json:"type"`
-	Text       string               `json:"text,omitempty"`
-	Source     *ClaudeMessageSource `json:"source,omitempty"`
-	Usage      *ClaudeUsage         `json:"usage,omitempty"`
-	StopReason *string              `json:"stop_reason,omitempty"`
+	Type        string               `json:"type"`
+	Text        string               `json:"text,omitempty"`
+	Source      *ClaudeMessageSource `json:"source,omitempty"`
+	Usage       *ClaudeUsage         `json:"usage,omitempty"`
+	StopReason  *string              `json:"stop_reason,omitempty"`
+	PartialJson string               `json:"partial_json,omitempty"`
+	// tool_calls
+	Id        string `json:"id,omitempty"`
+	Name      string `json:"name,omitempty"`
+	Input     any    `json:"input,omitempty"`
+	Content   string `json:"content,omitempty"`
+	ToolUseId string `json:"tool_use_id,omitempty"`
 }
 
 type ClaudeMessageSource struct {
@@ -23,6 +30,18 @@ type ClaudeMessage struct {
 	Content any    `json:"content"`
 }
 
+type Tool struct {
+	Name        string      `json:"name"`
+	Description string      `json:"description,omitempty"`
+	InputSchema InputSchema `json:"input_schema"`
+}
+
+type InputSchema struct {
+	Type       string `json:"type"`
+	Properties any    `json:"properties,omitempty"`
+	Required   any    `json:"required,omitempty"`
+}
+
 type ClaudeRequest struct {
 	Model             string          `json:"model"`
 	Prompt            string          `json:"prompt,omitempty"`
@@ -35,7 +54,9 @@ type ClaudeRequest struct {
 	TopP              float64         `json:"top_p,omitempty"`
 	TopK              int             `json:"top_k,omitempty"`
 	//ClaudeMetadata    `json:"metadata,omitempty"`
-	Stream bool `json:"stream,omitempty"`
+	Stream     bool   `json:"stream,omitempty"`
+	Tools      []Tool `json:"tools,omitempty"`
+	ToolChoice any    `json:"tool_choice,omitempty"`
 }
 
 type ClaudeError struct {
@@ -44,24 +65,20 @@ type ClaudeError struct {
 }
 
 type ClaudeResponse struct {
-	Id         string               `json:"id"`
-	Type       string               `json:"type"`
-	Content    []ClaudeMediaMessage `json:"content"`
-	Completion string               `json:"completion"`
-	StopReason string               `json:"stop_reason"`
-	Model      string               `json:"model"`
-	Error      ClaudeError          `json:"error"`
-	Usage      ClaudeUsage          `json:"usage"`
-	Index      int                  `json:"index"`   // stream only
-	Delta      *ClaudeMediaMessage  `json:"delta"`   // stream only
-	Message    *ClaudeResponse      `json:"message"` // stream only: message_start
+	Id           string               `json:"id"`
+	Type         string               `json:"type"`
+	Content      []ClaudeMediaMessage `json:"content"`
+	Completion   string               `json:"completion"`
+	StopReason   string               `json:"stop_reason"`
+	Model        string               `json:"model"`
+	Error        ClaudeError          `json:"error"`
+	Usage        ClaudeUsage          `json:"usage"`
+	Index        int                  `json:"index"` // stream only
+	ContentBlock *ClaudeMediaMessage  `json:"content_block"`
+	Delta        *ClaudeMediaMessage  `json:"delta"`   // stream only
+	Message      *ClaudeResponse      `json:"message"` // stream only: message_start
 }
 
-//type ClaudeResponseChoice struct {
-//	Index   int                `json:"index"`
-//	Type    string             `json:"type"`
-//}
-
 type ClaudeUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`

+ 79 - 14
relay/channel/claude/relay-claude.go

@@ -30,6 +30,7 @@ func stopReasonClaude2OpenAI(reason string) string {
 }
 
 func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
+
 	claudeRequest := ClaudeRequest{
 		Model:         textRequest.Model,
 		Prompt:        "",
@@ -60,6 +61,22 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
 }
 
 func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
+	claudeTools := make([]Tool, 0, len(textRequest.Tools))
+
+	for _, tool := range textRequest.Tools {
+		if params, ok := tool.Function.Parameters.(map[string]any); ok {
+			claudeTools = append(claudeTools, Tool{
+				Name:        tool.Function.Name,
+				Description: tool.Function.Description,
+				InputSchema: InputSchema{
+					Type:       params["type"].(string),
+					Properties: params["properties"],
+					Required:   params["required"],
+				},
+			})
+		}
+	}
+
 	claudeRequest := ClaudeRequest{
 		Model:         textRequest.Model,
 		MaxTokens:     textRequest.MaxTokens,
@@ -68,6 +85,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		TopP:          textRequest.TopP,
 		TopK:          textRequest.TopK,
 		Stream:        textRequest.Stream,
+		Tools:         claudeTools,
 	}
 	if claudeRequest.MaxTokens == 0 {
 		claudeRequest.MaxTokens = 4096
@@ -184,6 +202,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
 	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
+	tools := make([]dto.ToolCall, 0)
 	var choice dto.ChatCompletionsStreamResponseChoice
 	if reqMode == RequestModeCompletion {
 		choice.Delta.SetContentString(claudeResponse.Completion)
@@ -199,10 +218,33 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 			choice.Delta.SetContentString("")
 			choice.Delta.Role = "assistant"
 		} else if claudeResponse.Type == "content_block_start" {
-			return nil, nil
+			if claudeResponse.ContentBlock != nil {
+				//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
+				if claudeResponse.ContentBlock.Type == "tool_use" {
+					tools = append(tools, dto.ToolCall{
+						ID:   claudeResponse.ContentBlock.Id,
+						Type: "function",
+						Function: dto.FunctionCall{
+							Name:      claudeResponse.ContentBlock.Name,
+							Arguments: "",
+						},
+					})
+				}
+			} else {
+				return nil, nil
+			}
 		} else if claudeResponse.Type == "content_block_delta" {
-			choice.Index = claudeResponse.Index
-			choice.Delta.SetContentString(claudeResponse.Delta.Text)
+			if claudeResponse.Delta != nil {
+				choice.Index = claudeResponse.Index
+				choice.Delta.SetContentString(claudeResponse.Delta.Text)
+				if claudeResponse.Delta.Type == "input_json_delta" {
+					tools = append(tools, dto.ToolCall{
+						Function: dto.FunctionCall{
+							Arguments: claudeResponse.Delta.PartialJson,
+						},
+					})
+				}
+			}
 		} else if claudeResponse.Type == "message_delta" {
 			finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
 			if finishReason != "null" {
@@ -218,6 +260,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 	if claudeUsage == nil {
 		claudeUsage = &ClaudeUsage{}
 	}
+	if len(tools) > 0 {
+		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
+		choice.Delta.ToolCalls = tools
+	}
 	response.Choices = append(response.Choices, choice)
 
 	return &response, claudeUsage
@@ -230,6 +276,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 	}
+	var responseText string
+	if len(claudeResponse.Content) > 0 {
+		responseText = claudeResponse.Content[0].Text
+	}
+	tools := make([]dto.ToolCall, 0)
 	if reqMode == RequestModeCompletion {
 		content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
 		choice := dto.OpenAITextResponseChoice{
@@ -244,20 +295,32 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 		choices = append(choices, choice)
 	} else {
 		fullTextResponse.Id = claudeResponse.Id
-		for i, message := range claudeResponse.Content {
-			content, _ := json.Marshal(message.Text)
-			choice := dto.OpenAITextResponseChoice{
-				Index: i,
-				Message: dto.Message{
-					Role:    "assistant",
-					Content: content,
-				},
-				FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+		for _, message := range claudeResponse.Content {
+			if message.Type == "tool_use" {
+				args, _ := json.Marshal(message.Input)
+				tools = append(tools, dto.ToolCall{
+					ID:   message.Id,
+					Type: "function", // compatible with other OpenAI derivative applications
+					Function: dto.FunctionCall{
+						Name:      message.Name,
+						Arguments: string(args),
+					},
+				})
 			}
-			choices = append(choices, choice)
 		}
 	}
-
+	choice := dto.OpenAITextResponseChoice{
+		Index: 0,
+		Message: dto.Message{
+			Role: "assistant",
+		},
+		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+	}
+	choice.SetStringContent(responseText)
+	if len(tools) > 0 {
+		choice.Message.ToolCalls = tools
+	}
+	choices = append(choices, choice)
 	fullTextResponse.Choices = choices
 	return &fullTextResponse
 }
@@ -334,6 +397,8 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 				} else if claudeResponse.Type == "message_delta" {
 					usage.CompletionTokens = claudeUsage.OutputTokens
 					usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
+				} else if claudeResponse.Type == "content_block_start" {
+
 				} else {
 					return true
 				}