Yan 1 год назад
Родитель
Сommit
0ada2371b6
1 измененных файлов с 64 добавлено и 10 удалено
  1. 64 10
      relay/channel/claude/relay-claude.go

+ 64 - 10
relay/channel/claude/relay-claude.go

@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"encoding/json"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -12,6 +11,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 func stopReasonClaude2OpenAI(reason string) string {
@@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		}
 	}
 	formatMessages := make([]dto.Message, 0)
-	var lastMessage *dto.Message
+	lastMessage := dto.Message{
+		Role: "tool",
+	}
 	for i, message := range textRequest.Messages {
-		//if message.Role == "system" {
-		//	if i != 0 {
-		//		message.Role = "user"
-		//	}
-		//}
 		if message.Role == "" {
 			textRequest.Messages[i].Role = "user"
 		}
@@ -122,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			Role:    message.Role,
 			Content: message.Content,
 		}
-		if lastMessage != nil && lastMessage.Role == message.Role {
+		if message.Role == "tool" {
+			fmtMessage.ToolCallId = message.ToolCallId
+		}
+		if message.Role == "assistant" && message.ToolCalls != nil {
+			fmtMessage.ToolCalls = message.ToolCalls
+		}
+		if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
 			if lastMessage.IsStringContent() && message.IsStringContent() {
 				content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
 				fmtMessage.Content = content
@@ -135,7 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			fmtMessage.Content = content
 		}
 		formatMessages = append(formatMessages, fmtMessage)
-		lastMessage = &textRequest.Messages[i]
+		lastMessage = fmtMessage
 	}
 
 	claudeMessages := make([]ClaudeMessage, 0)
@@ -174,7 +178,35 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			claudeMessage := ClaudeMessage{
 				Role: message.Role,
 			}
-			if message.IsStringContent() {
+			if message.Role == "tool" {
+				if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
+					lastMessage := claudeMessages[len(claudeMessages)-1]
+					if content, ok := lastMessage.Content.(string); ok {
+						lastMessage.Content = []ClaudeMediaMessage{
+							{
+								Type: "text",
+								Text: content,
+							},
+						}
+					}
+					lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
+						Type:      "tool_result",
+						ToolUseId: message.ToolCallId,
+						Content:   message.StringContent(),
+					})
+					claudeMessages[len(claudeMessages)-1] = lastMessage
+					continue
+				} else {
+					claudeMessage.Role = "user"
+					claudeMessage.Content = []ClaudeMediaMessage{
+						{
+							Type:      "tool_result",
+							ToolUseId: message.ToolCallId,
+							Content:   message.StringContent(),
+						},
+					}
+				}
+			} else if message.IsStringContent() && message.ToolCalls == nil {
 				claudeMessage.Content = message.StringContent()
 			} else {
 				claudeMediaMessages := make([]ClaudeMediaMessage, 0)
@@ -207,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 					}
 					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
 				}
+				if message.ToolCalls != nil {
+					for _, tc := range message.ToolCalls.([]interface{}) {
+						toolCallJSON, _ := json.Marshal(tc)
+						var toolCall dto.ToolCall
+						err := json.Unmarshal(toolCallJSON, &toolCall)
+						if err != nil {
+							common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
+							continue
+						}
+						inputObj := make(map[string]any)
+						if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
+							common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
+							continue
+						}
+						claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
+							Type:  "tool_use",
+							Id:    toolCall.ID,
+							Name:  toolCall.Function.Name,
+							Input: inputObj,
+						})
+					}
+				}
 				claudeMessage.Content = claudeMediaMessages
 			}
 			claudeMessages = append(claudeMessages, claudeMessage)