Преглед изворни кода

Merge pull request #2397 from seefs001/fix/tool-call-claude

fix: try to fix tool call issues
Calcium-Ion пре 2 месеци
родитељ
комит
fca015c6c4
1 измењених фајлова са 198 додато и 103 уклоњено
  1. 198 103
      service/convert.go

+ 198 - 103
service/convert.go

@@ -201,6 +201,10 @@ func generateStopBlock(index int) *dto.ClaudeResponse {
 }
 
 func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
+	if info.ClaudeConvertInfo.Done {
+		return nil
+	}
+
 	var claudeResponses []*dto.ClaudeResponse
 	if info.SendResponseCount == 1 {
 		msg := &dto.ClaudeMediaMessage{
@@ -218,45 +222,117 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 			Type:    "message_start",
 			Message: msg,
 		})
-		claudeResponses = append(claudeResponses)
 		//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
 		//	Type: "ping",
 		//})
 		if openAIResponse.IsToolCall() {
 			info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
+			var toolCall dto.ToolCallResponse
+			if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.ToolCalls) > 0 {
+				toolCall = openAIResponse.Choices[0].Delta.ToolCalls[0]
+			} else {
+				first := openAIResponse.GetFirstToolCall()
+				if first != nil {
+					toolCall = *first
+				} else {
+					toolCall = dto.ToolCallResponse{}
+				}
+			}
 			resp := &dto.ClaudeResponse{
 				Type: "content_block_start",
 				ContentBlock: &dto.ClaudeMediaMessage{
-					Id:    openAIResponse.GetFirstToolCall().ID,
+					Id:    toolCall.ID,
 					Type:  "tool_use",
-					Name:  openAIResponse.GetFirstToolCall().Function.Name,
+					Name:  toolCall.Function.Name,
 					Input: map[string]interface{}{},
 				},
 			}
 			resp.SetIndex(0)
 			claudeResponses = append(claudeResponses, resp)
+			// 首块包含工具 delta,则追加 input_json_delta
+			if toolCall.Function.Arguments != "" {
+				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+					Index: &info.ClaudeConvertInfo.Index,
+					Type:  "content_block_delta",
+					Delta: &dto.ClaudeMediaMessage{
+						Type:        "input_json_delta",
+						PartialJson: &toolCall.Function.Arguments,
+					},
+				})
+			}
 		} else {
 
 		}
 		// 判断首个响应是否存在内容(非标准的 OpenAI 响应)
-		if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
-			claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
-				Index: &info.ClaudeConvertInfo.Index,
-				Type:  "content_block_start",
-				ContentBlock: &dto.ClaudeMediaMessage{
-					Type: "text",
-					Text: common.GetPointer[string](""),
-				},
-			})
+		if len(openAIResponse.Choices) > 0 {
+			reasoning := openAIResponse.Choices[0].Delta.GetReasoningContent()
+			content := openAIResponse.Choices[0].Delta.GetContentString()
+
+			if reasoning != "" {
+				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+					Index: &info.ClaudeConvertInfo.Index,
+					Type:  "content_block_start",
+					ContentBlock: &dto.ClaudeMediaMessage{
+						Type:     "thinking",
+						Thinking: common.GetPointer[string](""),
+					},
+				})
+				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+					Index: &info.ClaudeConvertInfo.Index,
+					Type:  "content_block_delta",
+					Delta: &dto.ClaudeMediaMessage{
+						Type:     "thinking_delta",
+						Thinking: &reasoning,
+					},
+				})
+				info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
+			} else if content != "" {
+				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+					Index: &info.ClaudeConvertInfo.Index,
+					Type:  "content_block_start",
+					ContentBlock: &dto.ClaudeMediaMessage{
+						Type: "text",
+						Text: common.GetPointer[string](""),
+					},
+				})
+				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+					Index: &info.ClaudeConvertInfo.Index,
+					Type:  "content_block_delta",
+					Delta: &dto.ClaudeMediaMessage{
+						Type: "text_delta",
+						Text: common.GetPointer[string](content),
+					},
+				})
+				info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
+			}
+		}
+
+		// 如果首块就带 finish_reason,需要立即发送停止块
+		if len(openAIResponse.Choices) > 0 && openAIResponse.Choices[0].FinishReason != nil && *openAIResponse.Choices[0].FinishReason != "" {
+			info.FinishReason = *openAIResponse.Choices[0].FinishReason
+			claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
+			oaiUsage := openAIResponse.Usage
+			if oaiUsage == nil {
+				oaiUsage = info.ClaudeConvertInfo.Usage
+			}
+			if oaiUsage != nil {
+				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+					Type: "message_delta",
+					Usage: &dto.ClaudeUsage{
+						InputTokens:              oaiUsage.PromptTokens,
+						OutputTokens:             oaiUsage.CompletionTokens,
+						CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
+						CacheReadInputTokens:     oaiUsage.PromptTokensDetails.CachedTokens,
+					},
+					Delta: &dto.ClaudeMediaMessage{
+						StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
+					},
+				})
+			}
 			claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
-				Index: &info.ClaudeConvertInfo.Index,
-				Type:  "content_block_delta",
-				Delta: &dto.ClaudeMediaMessage{
-					Type: "text_delta",
-					Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
-				},
+				Type: "message_stop",
 			})
-			info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
+			info.ClaudeConvertInfo.Done = true
 		}
 		return claudeResponses
 	}
@@ -264,7 +340,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 	if len(openAIResponse.Choices) == 0 {
 		// no choices
 		// 可能为非标准的 OpenAI 响应,判断是否已经完成
-		if info.Done {
+		if info.ClaudeConvertInfo.Done {
 			claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
 			oaiUsage := info.ClaudeConvertInfo.Usage
 			if oaiUsage != nil {
@@ -288,16 +364,110 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 		return claudeResponses
 	} else {
 		chosenChoice := openAIResponse.Choices[0]
-		if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
-			// should be done
+		doneChunk := chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != ""
+		if doneChunk {
 			info.FinishReason = *chosenChoice.FinishReason
-			if !info.Done {
-				return claudeResponses
+		}
+
+		var claudeResponse dto.ClaudeResponse
+		var isEmpty bool
+		claudeResponse.Type = "content_block_delta"
+		if len(chosenChoice.Delta.ToolCalls) > 0 {
+			toolCalls := chosenChoice.Delta.ToolCalls
+			if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
+				claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
+				info.ClaudeConvertInfo.Index++
+			}
+			info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
+
+			for i, toolCall := range toolCalls {
+				blockIndex := info.ClaudeConvertInfo.Index
+				if toolCall.Index != nil {
+					blockIndex = *toolCall.Index
+				} else if len(toolCalls) > 1 {
+					blockIndex = info.ClaudeConvertInfo.Index + i
+				}
+
+				idx := blockIndex
+				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+					Index: &idx,
+					Type:  "content_block_start",
+					ContentBlock: &dto.ClaudeMediaMessage{
+						Id:    toolCall.ID,
+						Type:  "tool_use",
+						Name:  toolCall.Function.Name,
+						Input: map[string]interface{}{},
+					},
+				})
+
+				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+					Index: &idx,
+					Type:  "content_block_delta",
+					Delta: &dto.ClaudeMediaMessage{
+						Type:        "input_json_delta",
+						PartialJson: &toolCall.Function.Arguments,
+					},
+				})
+
+				info.ClaudeConvertInfo.Index = blockIndex
+			}
+		} else {
+			reasoning := chosenChoice.Delta.GetReasoningContent()
+			textContent := chosenChoice.Delta.GetContentString()
+			if reasoning != "" || textContent != "" {
+				if reasoning != "" {
+					if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
+						claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+							Index: &info.ClaudeConvertInfo.Index,
+							Type:  "content_block_start",
+							ContentBlock: &dto.ClaudeMediaMessage{
+								Type:     "thinking",
+								Thinking: common.GetPointer[string](""),
+							},
+						})
+					}
+					info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
+					claudeResponse.Delta = &dto.ClaudeMediaMessage{
+						Type:     "thinking_delta",
+						Thinking: &reasoning,
+					}
+				} else {
+					if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
+						if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeThinking || info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeTools {
+							claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
+							info.ClaudeConvertInfo.Index++
+						}
+						claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+							Index: &info.ClaudeConvertInfo.Index,
+							Type:  "content_block_start",
+							ContentBlock: &dto.ClaudeMediaMessage{
+								Type: "text",
+								Text: common.GetPointer[string](""),
+							},
+						})
+					}
+					info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
+					claudeResponse.Delta = &dto.ClaudeMediaMessage{
+						Type: "text_delta",
+						Text: common.GetPointer[string](textContent),
+					}
+				}
+			} else {
+				isEmpty = true
 			}
 		}
-		if info.Done {
+
+		claudeResponse.Index = &info.ClaudeConvertInfo.Index
+		if !isEmpty && claudeResponse.Delta != nil {
+			claudeResponses = append(claudeResponses, &claudeResponse)
+		}
+
+		if doneChunk || info.ClaudeConvertInfo.Done {
 			claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
-			oaiUsage := info.ClaudeConvertInfo.Usage
+			oaiUsage := openAIResponse.Usage
+			if oaiUsage == nil {
+				oaiUsage = info.ClaudeConvertInfo.Usage
+			}
 			if oaiUsage != nil {
 				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
 					Type: "message_delta",
@@ -315,83 +485,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 			claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
 				Type: "message_stop",
 			})
-		} else {
-			var claudeResponse dto.ClaudeResponse
-			var isEmpty bool
-			claudeResponse.Type = "content_block_delta"
-			if len(chosenChoice.Delta.ToolCalls) > 0 {
-				if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
-					claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
-					info.ClaudeConvertInfo.Index++
-					claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
-						Index: &info.ClaudeConvertInfo.Index,
-						Type:  "content_block_start",
-						ContentBlock: &dto.ClaudeMediaMessage{
-							Id:    openAIResponse.GetFirstToolCall().ID,
-							Type:  "tool_use",
-							Name:  openAIResponse.GetFirstToolCall().Function.Name,
-							Input: map[string]interface{}{},
-						},
-					})
-				}
-				info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
-				// tools delta
-				claudeResponse.Delta = &dto.ClaudeMediaMessage{
-					Type:        "input_json_delta",
-					PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
-				}
-			} else {
-				reasoning := chosenChoice.Delta.GetReasoningContent()
-				textContent := chosenChoice.Delta.GetContentString()
-				if reasoning != "" || textContent != "" {
-					if reasoning != "" {
-						if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
-							//info.ClaudeConvertInfo.Index++
-							claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
-								Index: &info.ClaudeConvertInfo.Index,
-								Type:  "content_block_start",
-								ContentBlock: &dto.ClaudeMediaMessage{
-									Type:     "thinking",
-									Thinking: common.GetPointer[string](""),
-								},
-							})
-						}
-						info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
-						// text delta
-						claudeResponse.Delta = &dto.ClaudeMediaMessage{
-							Type:     "thinking_delta",
-							Thinking: &reasoning,
-						}
-					} else {
-						if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
-							if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
-								claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
-								info.ClaudeConvertInfo.Index++
-							}
-							claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
-								Index: &info.ClaudeConvertInfo.Index,
-								Type:  "content_block_start",
-								ContentBlock: &dto.ClaudeMediaMessage{
-									Type: "text",
-									Text: common.GetPointer[string](""),
-								},
-							})
-						}
-						info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
-						// text delta
-						claudeResponse.Delta = &dto.ClaudeMediaMessage{
-							Type: "text_delta",
-							Text: common.GetPointer[string](textContent),
-						}
-					}
-				} else {
-					isEmpty = true
-				}
-			}
-			claudeResponse.Index = &info.ClaudeConvertInfo.Index
-			if !isEmpty {
-				claudeResponses = append(claudeResponses, &claudeResponse)
-			}
+			info.ClaudeConvertInfo.Done = true
+			return claudeResponses
 		}
 	}