Parcourir la source

Revert "Fix/aws non empty text"

Seefs il y a 5 jours
Parent
commit
9be9943224
2 fichiers modifiés avec 24 ajouts et 131 suppressions
  1. 16 66
      relay/channel/claude/relay-claude.go
  2. 8 65
      relay/channel/claude/relay_claude_test.go

+ 16 - 66
relay/channel/claude/relay-claude.go

@@ -404,15 +404,12 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
 	return &claudeRequest, nil
 }
 
-func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ChatCompletionsStreamResponse {
+func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
 	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
 	tools := make([]dto.ToolCallResponse, 0)
-	if claudeInfo != nil && claudeInfo.ToolCallStreamStates == nil {
-		claudeInfo.ToolCallStreamStates = make(map[int]*ToolCallStreamState)
-	}
 	fcIdx := 0
 	if claudeResponse.Index != nil {
 		fcIdx = *claudeResponse.Index - 1
@@ -436,13 +433,6 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo
 				choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
 			}
 			if claudeResponse.ContentBlock.Type == "tool_use" {
-				if claudeInfo != nil {
-					claudeInfo.ToolCallStreamStates[fcIdx] = &ToolCallStreamState{
-						ID:   claudeResponse.ContentBlock.Id,
-						Name: claudeResponse.ContentBlock.Name,
-					}
-					return nil
-				}
 				tools = append(tools, dto.ToolCallResponse{
 					Index: common.GetPointer(fcIdx),
 					ID:    claudeResponse.ContentBlock.Id,
@@ -461,28 +451,19 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo
 			choice.Delta.Content = claudeResponse.Delta.Text
 			switch claudeResponse.Delta.Type {
 			case "input_json_delta":
-				if claudeResponse.Delta.PartialJson == nil {
-					return nil
-				}
-				arguments := *claudeResponse.Delta.PartialJson
-				if strings.TrimSpace(arguments) == "" {
-					return nil
+				arguments := "{}"
+				if claudeResponse.Delta.PartialJson != nil {
+					if partial := strings.TrimSpace(*claudeResponse.Delta.PartialJson); partial != "" {
+						arguments = partial
+					}
 				}
-				toolCall := dto.ToolCallResponse{
+				tools = append(tools, dto.ToolCallResponse{
 					Type:  "function",
 					Index: common.GetPointer(fcIdx),
 					Function: dto.FunctionResponse{
 						Arguments: arguments,
 					},
-				}
-				if claudeInfo != nil {
-					if state, ok := claudeInfo.ToolCallStreamStates[fcIdx]; ok {
-						state.Emitted = true
-						toolCall.ID = state.ID
-						toolCall.Function.Name = state.Name
-					}
-				}
-				tools = append(tools, toolCall)
+				})
 			case "signature_delta":
 				// 加密的不处理
 				signatureContent := "\n"
@@ -491,27 +472,6 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo
 				choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
 			}
 		}
-	} else if claudeResponse.Type == "content_block_stop" {
-		if claudeInfo == nil {
-			return nil
-		}
-		state, ok := claudeInfo.ToolCallStreamStates[fcIdx]
-		if !ok {
-			return nil
-		}
-		delete(claudeInfo.ToolCallStreamStates, fcIdx)
-		if state.Emitted {
-			return nil
-		}
-		tools = append(tools, dto.ToolCallResponse{
-			ID:    state.ID,
-			Type:  "function",
-			Index: common.GetPointer(fcIdx),
-			Function: dto.FunctionResponse{
-				Name:      state.Name,
-				Arguments: "{}",
-			},
-		})
 	} else if claudeResponse.Type == "message_delta" {
 		if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
 			finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
@@ -596,19 +556,12 @@ func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextRe
 }
 
 type ClaudeResponseInfo struct {
-	ResponseId           string
-	Created              int64
-	Model                string
-	ResponseText         strings.Builder
-	Usage                *dto.Usage
-	Done                 bool
-	ToolCallStreamStates map[int]*ToolCallStreamState
-}
-
-type ToolCallStreamState struct {
-	ID      string
-	Name    string
-	Emitted bool
+	ResponseId   string
+	Created      int64
+	Model        string
+	ResponseText strings.Builder
+	Usage        *dto.Usage
+	Done         bool
 }
 
 func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
@@ -741,7 +694,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
 
 		// 判断是否完整
 		claudeInfo.Done = true
-	} else if claudeResponse.Type == "content_block_start" || claudeResponse.Type == "content_block_stop" {
+	} else if claudeResponse.Type == "content_block_start" {
 	} else {
 		return false
 	}
@@ -786,10 +739,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		}
 		helper.ClaudeChunkData(c, claudeResponse, data)
 	} else if info.RelayFormat == types.RelayFormatOpenAI {
-		response := StreamResponseClaude2OpenAI(&claudeResponse, claudeInfo)
-		if response == nil {
-			return nil
-		}
+		response := StreamResponseClaude2OpenAI(&claudeResponse)
 
 		if !FormatClaudeResponseInfo(&claudeResponse, response, claudeInfo) {
 			return nil

+ 8 - 65
relay/channel/claude/relay_claude_test.go

@@ -216,7 +216,7 @@ func TestRequestOpenAI2ClaudeMessage_AssistantToolCallWithMalformedArguments(t *
 	assert.Empty(t, inputObj)
 }
 
-func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaIgnored(t *testing.T) {
+func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaFallback(t *testing.T) {
 	empty := ""
 	resp := &dto.ClaudeResponse{
 		Type:  "content_block_delta",
@@ -227,8 +227,12 @@ func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaIgnored(t *testing.T) {
 		},
 	}
 
-	chunk := StreamResponseClaude2OpenAI(resp, &ClaudeResponseInfo{})
-	require.Nil(t, chunk)
+	chunk := StreamResponseClaude2OpenAI(resp)
+	require.NotNil(t, chunk)
+	require.Len(t, chunk.Choices, 1)
+	require.NotNil(t, chunk.Choices[0].Delta.ToolCalls)
+	require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1)
+	assert.Equal(t, "{}", chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments)
 }
 
 func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing.T) {
@@ -242,71 +246,10 @@ func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing.
 		},
 	}
 
-	chunk := StreamResponseClaude2OpenAI(resp, &ClaudeResponseInfo{})
+	chunk := StreamResponseClaude2OpenAI(resp)
 	require.NotNil(t, chunk)
 	require.Len(t, chunk.Choices, 1)
 	require.NotNil(t, chunk.Choices[0].Delta.ToolCalls)
 	require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1)
 	assert.Equal(t, partial, chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments)
 }
-
-func TestStreamResponseClaude2OpenAI_NoArgToolEmitsObjectAtStop(t *testing.T) {
-	claudeInfo := &ClaudeResponseInfo{}
-	start := &dto.ClaudeResponse{
-		Type:  "content_block_start",
-		Index: func() *int { v := 1; return &v }(),
-		ContentBlock: &dto.ClaudeMediaMessage{
-			Type: "tool_use",
-			Id:   "toolu_1",
-			Name: "get_current_time",
-		},
-	}
-	stop := &dto.ClaudeResponse{
-		Type:  "content_block_stop",
-		Index: func() *int { v := 1; return &v }(),
-	}
-
-	startChunk := StreamResponseClaude2OpenAI(start, claudeInfo)
-	require.Nil(t, startChunk)
-
-	stopChunk := StreamResponseClaude2OpenAI(stop, claudeInfo)
-	require.NotNil(t, stopChunk)
-	require.Len(t, stopChunk.Choices, 1)
-	require.Len(t, stopChunk.Choices[0].Delta.ToolCalls, 1)
-	assert.Equal(t, "toolu_1", stopChunk.Choices[0].Delta.ToolCalls[0].ID)
-	assert.Equal(t, "get_current_time", stopChunk.Choices[0].Delta.ToolCalls[0].Function.Name)
-	assert.Equal(t, "{}", stopChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments)
-}
-
-func TestStreamResponseClaude2OpenAI_ArgToolKeepsIDNameOnDelta(t *testing.T) {
-	claudeInfo := &ClaudeResponseInfo{}
-	start := &dto.ClaudeResponse{
-		Type:  "content_block_start",
-		Index: func() *int { v := 1; return &v }(),
-		ContentBlock: &dto.ClaudeMediaMessage{
-			Type: "tool_use",
-			Id:   "toolu_2",
-			Name: "search_notes",
-		},
-	}
-	partial := `{"query":"today"}`
-	delta := &dto.ClaudeResponse{
-		Type:  "content_block_delta",
-		Index: func() *int { v := 1; return &v }(),
-		Delta: &dto.ClaudeMediaMessage{
-			Type:        "input_json_delta",
-			PartialJson: &partial,
-		},
-	}
-
-	startChunk := StreamResponseClaude2OpenAI(start, claudeInfo)
-	require.Nil(t, startChunk)
-
-	deltaChunk := StreamResponseClaude2OpenAI(delta, claudeInfo)
-	require.NotNil(t, deltaChunk)
-	require.Len(t, deltaChunk.Choices, 1)
-	require.Len(t, deltaChunk.Choices[0].Delta.ToolCalls, 1)
-	assert.Equal(t, "toolu_2", deltaChunk.Choices[0].Delta.ToolCalls[0].ID)
-	assert.Equal(t, "search_notes", deltaChunk.Choices[0].Delta.ToolCalls[0].Function.Name)
-	assert.Equal(t, partial, deltaChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments)
-}