Bladeren bron

fix: tool responses (#3080)

Calcium-Ion 4 dagen geleden
bovenliggende
commit
5dcbcd9cad
2 gewijzigde bestanden met toevoegingen van 131 en 24 verwijderingen
  1. 66 16
      relay/channel/claude/relay-claude.go
  2. 65 8
      relay/channel/claude/relay_claude_test.go

+ 66 - 16
relay/channel/claude/relay-claude.go

@@ -404,12 +404,15 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
 	return &claudeRequest, nil
 }
 
-func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
+func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ChatCompletionsStreamResponse {
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
 	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
 	tools := make([]dto.ToolCallResponse, 0)
+	if claudeInfo != nil && claudeInfo.ToolCallStreamStates == nil {
+		claudeInfo.ToolCallStreamStates = make(map[int]*ToolCallStreamState)
+	}
 	fcIdx := 0
 	if claudeResponse.Index != nil {
 		fcIdx = *claudeResponse.Index - 1
@@ -433,6 +436,13 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
 				choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
 			}
 			if claudeResponse.ContentBlock.Type == "tool_use" {
+				if claudeInfo != nil {
+					claudeInfo.ToolCallStreamStates[fcIdx] = &ToolCallStreamState{
+						ID:   claudeResponse.ContentBlock.Id,
+						Name: claudeResponse.ContentBlock.Name,
+					}
+					return nil
+				}
 				tools = append(tools, dto.ToolCallResponse{
 					Index: common.GetPointer(fcIdx),
 					ID:    claudeResponse.ContentBlock.Id,
@@ -451,19 +461,28 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
 			choice.Delta.Content = claudeResponse.Delta.Text
 			switch claudeResponse.Delta.Type {
 			case "input_json_delta":
-				arguments := "{}"
-				if claudeResponse.Delta.PartialJson != nil {
-					if partial := strings.TrimSpace(*claudeResponse.Delta.PartialJson); partial != "" {
-						arguments = partial
-					}
+				if claudeResponse.Delta.PartialJson == nil {
+					return nil
 				}
-				tools = append(tools, dto.ToolCallResponse{
+				arguments := *claudeResponse.Delta.PartialJson
+				if strings.TrimSpace(arguments) == "" {
+					return nil
+				}
+				toolCall := dto.ToolCallResponse{
 					Type:  "function",
 					Index: common.GetPointer(fcIdx),
 					Function: dto.FunctionResponse{
 						Arguments: arguments,
 					},
-				})
+				}
+				if claudeInfo != nil {
+					if state, ok := claudeInfo.ToolCallStreamStates[fcIdx]; ok {
+						state.Emitted = true
+						toolCall.ID = state.ID
+						toolCall.Function.Name = state.Name
+					}
+				}
+				tools = append(tools, toolCall)
 			case "signature_delta":
 				// 加密的不处理
 				signatureContent := "\n"
@@ -472,6 +491,27 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
 				choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
 			}
 		}
+	} else if claudeResponse.Type == "content_block_stop" {
+		if claudeInfo == nil {
+			return nil
+		}
+		state, ok := claudeInfo.ToolCallStreamStates[fcIdx]
+		if !ok {
+			return nil
+		}
+		delete(claudeInfo.ToolCallStreamStates, fcIdx)
+		if state.Emitted {
+			return nil
+		}
+		tools = append(tools, dto.ToolCallResponse{
+			ID:    state.ID,
+			Type:  "function",
+			Index: common.GetPointer(fcIdx),
+			Function: dto.FunctionResponse{
+				Name:      state.Name,
+				Arguments: "{}",
+			},
+		})
 	} else if claudeResponse.Type == "message_delta" {
 		if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
 			finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
@@ -556,12 +596,19 @@ func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextRe
 }
 
 type ClaudeResponseInfo struct {
-	ResponseId   string
-	Created      int64
-	Model        string
-	ResponseText strings.Builder
-	Usage        *dto.Usage
-	Done         bool
+	ResponseId           string
+	Created              int64
+	Model                string
+	ResponseText         strings.Builder
+	Usage                *dto.Usage
+	Done                 bool
+	ToolCallStreamStates map[int]*ToolCallStreamState
+}
+
+type ToolCallStreamState struct {
+	ID      string
+	Name    string
+	Emitted bool
 }
 
 func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
@@ -694,7 +741,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
 
 		// 判断是否完整
 		claudeInfo.Done = true
-	} else if claudeResponse.Type == "content_block_start" {
+	} else if claudeResponse.Type == "content_block_start" || claudeResponse.Type == "content_block_stop" {
 	} else {
 		return false
 	}
@@ -739,7 +786,10 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		}
 		helper.ClaudeChunkData(c, claudeResponse, data)
 	} else if info.RelayFormat == types.RelayFormatOpenAI {
-		response := StreamResponseClaude2OpenAI(&claudeResponse)
+		response := StreamResponseClaude2OpenAI(&claudeResponse, claudeInfo)
+		if response == nil {
+			return nil
+		}
 
 		if !FormatClaudeResponseInfo(&claudeResponse, response, claudeInfo) {
 			return nil

+ 65 - 8
relay/channel/claude/relay_claude_test.go

@@ -216,7 +216,7 @@ func TestRequestOpenAI2ClaudeMessage_AssistantToolCallWithMalformedArguments(t *
 	assert.Empty(t, inputObj)
 }
 
-func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaFallback(t *testing.T) {
+func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaIgnored(t *testing.T) {
 	empty := ""
 	resp := &dto.ClaudeResponse{
 		Type:  "content_block_delta",
@@ -227,12 +227,8 @@ func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaFallback(t *testing.T) {
 		},
 	}
 
-	chunk := StreamResponseClaude2OpenAI(resp)
-	require.NotNil(t, chunk)
-	require.Len(t, chunk.Choices, 1)
-	require.NotNil(t, chunk.Choices[0].Delta.ToolCalls)
-	require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1)
-	assert.Equal(t, "{}", chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments)
+	chunk := StreamResponseClaude2OpenAI(resp, &ClaudeResponseInfo{})
+	require.Nil(t, chunk)
 }
 
 func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing.T) {
@@ -246,10 +242,71 @@ func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing.
 		},
 	}
 
-	chunk := StreamResponseClaude2OpenAI(resp)
+	chunk := StreamResponseClaude2OpenAI(resp, &ClaudeResponseInfo{})
 	require.NotNil(t, chunk)
 	require.Len(t, chunk.Choices, 1)
 	require.NotNil(t, chunk.Choices[0].Delta.ToolCalls)
 	require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1)
 	assert.Equal(t, partial, chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments)
 }
+
+func TestStreamResponseClaude2OpenAI_NoArgToolEmitsObjectAtStop(t *testing.T) {
+	claudeInfo := &ClaudeResponseInfo{}
+	start := &dto.ClaudeResponse{
+		Type:  "content_block_start",
+		Index: func() *int { v := 1; return &v }(),
+		ContentBlock: &dto.ClaudeMediaMessage{
+			Type: "tool_use",
+			Id:   "toolu_1",
+			Name: "get_current_time",
+		},
+	}
+	stop := &dto.ClaudeResponse{
+		Type:  "content_block_stop",
+		Index: func() *int { v := 1; return &v }(),
+	}
+
+	startChunk := StreamResponseClaude2OpenAI(start, claudeInfo)
+	require.Nil(t, startChunk)
+
+	stopChunk := StreamResponseClaude2OpenAI(stop, claudeInfo)
+	require.NotNil(t, stopChunk)
+	require.Len(t, stopChunk.Choices, 1)
+	require.Len(t, stopChunk.Choices[0].Delta.ToolCalls, 1)
+	assert.Equal(t, "toolu_1", stopChunk.Choices[0].Delta.ToolCalls[0].ID)
+	assert.Equal(t, "get_current_time", stopChunk.Choices[0].Delta.ToolCalls[0].Function.Name)
+	assert.Equal(t, "{}", stopChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments)
+}
+
+func TestStreamResponseClaude2OpenAI_ArgToolKeepsIDNameOnDelta(t *testing.T) {
+	claudeInfo := &ClaudeResponseInfo{}
+	start := &dto.ClaudeResponse{
+		Type:  "content_block_start",
+		Index: func() *int { v := 1; return &v }(),
+		ContentBlock: &dto.ClaudeMediaMessage{
+			Type: "tool_use",
+			Id:   "toolu_2",
+			Name: "search_notes",
+		},
+	}
+	partial := `{"query":"today"}`
+	delta := &dto.ClaudeResponse{
+		Type:  "content_block_delta",
+		Index: func() *int { v := 1; return &v }(),
+		Delta: &dto.ClaudeMediaMessage{
+			Type:        "input_json_delta",
+			PartialJson: &partial,
+		},
+	}
+
+	startChunk := StreamResponseClaude2OpenAI(start, claudeInfo)
+	require.Nil(t, startChunk)
+
+	deltaChunk := StreamResponseClaude2OpenAI(delta, claudeInfo)
+	require.NotNil(t, deltaChunk)
+	require.Len(t, deltaChunk.Choices, 1)
+	require.Len(t, deltaChunk.Choices[0].Delta.ToolCalls, 1)
+	assert.Equal(t, "toolu_2", deltaChunk.Choices[0].Delta.ToolCalls[0].ID)
+	assert.Equal(t, "search_notes", deltaChunk.Choices[0].Delta.ToolCalls[0].Function.Name)
+	assert.Equal(t, partial, deltaChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments)
+}