فهرست منبع

Merge pull request #464 from Yan-Zero/main

fix: tool use in claude and add gemini mapping
Calcium-Ion 1 سال پیش
والد
کامیت
1675679be9
4فایلهای تغییر یافته به همراه78 افزوده شده و 20 حذف شده
  1. 3 1
      common/model-ratio.go
  2. 10 8
      constant/env.go
  3. 64 10
      relay/channel/claude/relay-claude.go
  4. 1 1
      relay/channel/gemini/constant.go

+ 3 - 1
common/model-ratio.go

@@ -106,8 +106,10 @@ var defaultModelRatio = map[string]float64{
 	"gemini-pro-vision":              1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"gemini-pro-vision":              1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"gemini-1.0-pro-vision-001":      1,
 	"gemini-1.0-pro-vision-001":      1,
 	"gemini-1.0-pro-001":             1,
 	"gemini-1.0-pro-001":             1,
-	"gemini-1.5-pro-latest":          1,
+	"gemini-1.5-pro-latest":          1.75, // $3.5 / 1M tokens
+	"gemini-1.5-pro-exp-0827":        1.75, // $3.5 / 1M tokens
 	"gemini-1.5-flash-latest":        1,
 	"gemini-1.5-flash-latest":        1,
+	"gemini-1.5-flash-exp-0827":      1,
 	"gemini-1.0-pro-latest":          1,
 	"gemini-1.0-pro-latest":          1,
 	"gemini-1.0-pro-vision-latest":   1,
 	"gemini-1.0-pro-vision-latest":   1,
 	"gemini-ultra":                   1,
 	"gemini-ultra":                   1,

+ 10 - 8
constant/env.go

@@ -20,14 +20,16 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
 var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
 var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
 
 
 var GeminiModelMap = map[string]string{
 var GeminiModelMap = map[string]string{
-	"gemini-1.5-pro-latest":   "v1beta",
-	"gemini-1.5-pro-001":      "v1beta",
-	"gemini-1.5-pro":          "v1beta",
-	"gemini-1.5-pro-exp-0801": "v1beta",
-	"gemini-1.5-flash-latest": "v1beta",
-	"gemini-1.5-flash-001":    "v1beta",
-	"gemini-1.5-flash":        "v1beta",
-	"gemini-ultra":            "v1beta",
+	"gemini-1.5-pro-latest":     "v1beta",
+	"gemini-1.5-pro-001":        "v1beta",
+	"gemini-1.5-pro":            "v1beta",
+	"gemini-1.5-pro-exp-0801":   "v1beta",
+	"gemini-1.5-pro-exp-0827":   "v1beta",
+	"gemini-1.5-flash-latest":   "v1beta",
+	"gemini-1.5-flash-exp-0827": "v1beta",
+	"gemini-1.5-flash-001":      "v1beta",
+	"gemini-1.5-flash":          "v1beta",
+	"gemini-ultra":              "v1beta",
 }
 }
 
 
 func InitEnv() {
 func InitEnv() {

+ 64 - 10
relay/channel/claude/relay-claude.go

@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"bufio"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
@@ -12,6 +11,8 @@ import (
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"one-api/service"
 	"strings"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 func stopReasonClaude2OpenAI(reason string) string {
 func stopReasonClaude2OpenAI(reason string) string {
@@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		}
 		}
 	}
 	}
 	formatMessages := make([]dto.Message, 0)
 	formatMessages := make([]dto.Message, 0)
-	var lastMessage *dto.Message
+	lastMessage := dto.Message{
+		Role: "tool",
+	}
 	for i, message := range textRequest.Messages {
 	for i, message := range textRequest.Messages {
-		//if message.Role == "system" {
-		//	if i != 0 {
-		//		message.Role = "user"
-		//	}
-		//}
 		if message.Role == "" {
 		if message.Role == "" {
 			textRequest.Messages[i].Role = "user"
 			textRequest.Messages[i].Role = "user"
 		}
 		}
@@ -122,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			Role:    message.Role,
 			Role:    message.Role,
 			Content: message.Content,
 			Content: message.Content,
 		}
 		}
-		if lastMessage != nil && lastMessage.Role == message.Role {
+		if message.Role == "tool" {
+			fmtMessage.ToolCallId = message.ToolCallId
+		}
+		if message.Role == "assistant" && message.ToolCalls != nil {
+			fmtMessage.ToolCalls = message.ToolCalls
+		}
+		if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
 			if lastMessage.IsStringContent() && message.IsStringContent() {
 			if lastMessage.IsStringContent() && message.IsStringContent() {
 				content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
 				content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
 				fmtMessage.Content = content
 				fmtMessage.Content = content
@@ -135,7 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			fmtMessage.Content = content
 			fmtMessage.Content = content
 		}
 		}
 		formatMessages = append(formatMessages, fmtMessage)
 		formatMessages = append(formatMessages, fmtMessage)
-		lastMessage = &textRequest.Messages[i]
+		lastMessage = fmtMessage
 	}
 	}
 
 
 	claudeMessages := make([]ClaudeMessage, 0)
 	claudeMessages := make([]ClaudeMessage, 0)
@@ -174,7 +178,35 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			claudeMessage := ClaudeMessage{
 			claudeMessage := ClaudeMessage{
 				Role: message.Role,
 				Role: message.Role,
 			}
 			}
-			if message.IsStringContent() {
+			if message.Role == "tool" {
+				if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
+					lastMessage := claudeMessages[len(claudeMessages)-1]
+					if content, ok := lastMessage.Content.(string); ok {
+						lastMessage.Content = []ClaudeMediaMessage{
+							{
+								Type: "text",
+								Text: content,
+							},
+						}
+					}
+					lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
+						Type:      "tool_result",
+						ToolUseId: message.ToolCallId,
+						Content:   message.StringContent(),
+					})
+					claudeMessages[len(claudeMessages)-1] = lastMessage
+					continue
+				} else {
+					claudeMessage.Role = "user"
+					claudeMessage.Content = []ClaudeMediaMessage{
+						{
+							Type:      "tool_result",
+							ToolUseId: message.ToolCallId,
+							Content:   message.StringContent(),
+						},
+					}
+				}
+			} else if message.IsStringContent() && message.ToolCalls == nil {
 				claudeMessage.Content = message.StringContent()
 				claudeMessage.Content = message.StringContent()
 			} else {
 			} else {
 				claudeMediaMessages := make([]ClaudeMediaMessage, 0)
 				claudeMediaMessages := make([]ClaudeMediaMessage, 0)
@@ -207,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 					}
 					}
 					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
 					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
 				}
 				}
+				if message.ToolCalls != nil {
+					for _, tc := range message.ToolCalls.([]interface{}) {
+						toolCallJSON, _ := json.Marshal(tc)
+						var toolCall dto.ToolCall
+						err := json.Unmarshal(toolCallJSON, &toolCall)
+						if err != nil {
+							common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
+							continue
+						}
+						inputObj := make(map[string]any)
+						if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
+							common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
+							continue
+						}
+						claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
+							Type:  "tool_use",
+							Id:    toolCall.ID,
+							Name:  toolCall.Function.Name,
+							Input: inputObj,
+						})
+					}
+				}
 				claudeMessage.Content = claudeMediaMessages
 				claudeMessage.Content = claudeMediaMessages
 			}
 			}
 			claudeMessages = append(claudeMessages, claudeMessage)
 			claudeMessages = append(claudeMessages, claudeMessage)

+ 1 - 1
relay/channel/gemini/constant.go

@@ -6,7 +6,7 @@ const (
 
 
 var ModelList = []string{
 var ModelList = []string{
 	"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
 	"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
-	"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
+	"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
 }
 }
 
 
 var ChannelName = "google gemini"
 var ChannelName = "google gemini"