Quellcode durchsuchen

✨ feat: Refactor Gemini tools handling to support JSON raw message format

CaIon vor 6 Monaten
Ursprung
Commit
d3170310ff
3 geänderte Dateien mit 44 neuen und 9 gelöschten Zeilen
  1. 37 1
      dto/gemini.go
  2. 5 6
      relay/channel/gemini/relay-gemini.go
  3. 2 2
      service/convert.go

+ 37 - 1
dto/gemini.go

@@ -3,16 +3,52 @@ package dto
 import (
 	"encoding/json"
 	"one-api/common"
+	"strings"
 )
 
 type GeminiChatRequest struct {
 	Contents           []GeminiChatContent        `json:"contents"`
 	SafetySettings     []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
 	GenerationConfig   GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
-	Tools              []GeminiChatTool           `json:"tools,omitempty"`
+	Tools              json.RawMessage            `json:"tools,omitempty"`
 	SystemInstructions *GeminiChatContent         `json:"systemInstruction,omitempty"`
 }
 
+func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
+	var tools []GeminiChatTool
+	if strings.HasSuffix(string(r.Tools), "[") {
+		// is array
+		if err := common.Unmarshal(r.Tools, &tools); err != nil {
+			common.LogError(nil, "error_unmarshalling_tools: "+err.Error())
+			return nil
+		}
+	} else if strings.HasPrefix(string(r.Tools), "{") {
+		// is object
+		singleTool := GeminiChatTool{}
+		if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
+			common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
+			return nil
+		}
+		tools = []GeminiChatTool{singleTool}
+	}
+	return tools
+}
+
+func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
+	if len(tools) == 0 {
+		r.Tools = json.RawMessage("[]")
+		return
+	}
+
+	// Marshal the tools to JSON
+	data, err := common.Marshal(tools)
+	if err != nil {
+		common.LogError(nil, "error_marshalling_tools: "+err.Error())
+		return
+	}
+	r.Tools = data
+}
+
 type GeminiThinkingConfig struct {
 	IncludeThoughts bool `json:"includeThoughts,omitempty"`
 	ThinkingBudget  *int `json:"thinkingBudget,omitempty"`

+ 5 - 6
relay/channel/gemini/relay-gemini.go

@@ -267,24 +267,23 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 			tool.Function.Parameters = cleanedParams
 			functions = append(functions, tool.Function)
 		}
+		geminiTools := geminiRequest.GetTools()
 		if codeExecution {
-			geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
+			geminiTools = append(geminiTools, dto.GeminiChatTool{
 				CodeExecution: make(map[string]string),
 			})
 		}
 		if googleSearch {
-			geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
+			geminiTools = append(geminiTools, dto.GeminiChatTool{
 				GoogleSearch: make(map[string]string),
 			})
 		}
 		if len(functions) > 0 {
-			geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
+			geminiTools = append(geminiTools, dto.GeminiChatTool{
 				FunctionDeclarations: functions,
 			})
 		}
-		// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
-		// json_data, _ := json.Marshal(geminiRequest.Tools)
-		// common.SysLog("tools_json: " + string(json_data))
+		geminiRequest.SetTools(geminiTools)
 	}
 
 	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {

+ 2 - 2
service/convert.go

@@ -569,9 +569,9 @@ func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycomm
 	}
 
 	// 转换工具调用
-	if len(geminiRequest.Tools) > 0 {
+	if len(geminiRequest.GetTools()) > 0 {
 		var tools []dto.ToolCallRequest
-		for _, tool := range geminiRequest.Tools {
+		for _, tool := range geminiRequest.GetTools() {
 			if tool.FunctionDeclarations != nil {
 				// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
 				functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)