Bladeren bron

feat: RequestOpenAI2ClaudeMessage add more parms map

creamlike1024 7 maanden geleden
bovenliggende
commit
77da33de4f
2 gewijzigde bestanden met toevoegingen van 211 en 2 verwijderingen
  1. 74 0
      dto/claude.go
  2. 137 2
      relay/channel/claude/relay-claude.go

+ 74 - 0
dto/claude.go

@@ -158,6 +158,27 @@ type InputSchema struct {
 	Required   any    `json:"required,omitempty"`
 }
 
+type ClaudeWebSearchTool struct {
+	Type         string                       `json:"type"`
+	Name         string                       `json:"name"`
+	MaxUses      int                          `json:"max_uses,omitempty"`
+	UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
+}
+
+type ClaudeWebSearchUserLocation struct {
+	Type     string `json:"type"`
+	Timezone string `json:"timezone,omitempty"`
+	Country  string `json:"country,omitempty"`
+	Region   string `json:"region,omitempty"`
+	City     string `json:"city,omitempty"`
+}
+
+type ClaudeToolChoice struct {
+	Type                   string `json:"type"`
+	Name                   string `json:"name,omitempty"`
+	DisableParallelToolUse bool   `json:"disable_parallel_tool_use,omitempty"`
+}
+
 type ClaudeRequest struct {
 	Model             string          `json:"model"`
 	Prompt            string          `json:"prompt,omitempty"`
@@ -176,6 +197,59 @@ type ClaudeRequest struct {
 	Thinking   *Thinking `json:"thinking,omitempty"`
 }
 
+// AddTool 添加工具到请求中
+func (c *ClaudeRequest) AddTool(tool any) {
+	if c.Tools == nil {
+		c.Tools = make([]any, 0)
+	}
+
+	switch tools := c.Tools.(type) {
+	case []any:
+		c.Tools = append(tools, tool)
+	default:
+		// 如果Tools不是[]any类型,重新初始化为[]any
+		c.Tools = []any{tool}
+	}
+}
+
+// GetTools 获取工具列表
+func (c *ClaudeRequest) GetTools() []any {
+	if c.Tools == nil {
+		return nil
+	}
+
+	switch tools := c.Tools.(type) {
+	case []any:
+		return tools
+	default:
+		return nil
+	}
+}
+
+// ProcessTools 处理工具列表,支持类型断言
+func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
+	var normalTools []*Tool
+	var webSearchTools []*ClaudeWebSearchTool
+
+	for _, tool := range tools {
+		switch t := tool.(type) {
+		case *Tool:
+			normalTools = append(normalTools, t)
+		case *ClaudeWebSearchTool:
+			webSearchTools = append(webSearchTools, t)
+		case Tool:
+			normalTools = append(normalTools, &t)
+		case ClaudeWebSearchTool:
+			webSearchTools = append(webSearchTools, &t)
+		default:
+			// 未知类型,跳过
+			continue
+		}
+	}
+
+	return normalTools, webSearchTools
+}
+
 type Thinking struct {
 	Type         string `json:"type"`
 	BudgetTokens *int   `json:"budget_tokens,omitempty"`

+ 137 - 2
relay/channel/claude/relay-claude.go

@@ -17,6 +17,12 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+const (
+	WebSearchMaxUsesLow    = 1
+	WebSearchMaxUsesMedium = 5
+	WebSearchMaxUsesHigh   = 10
+)
+
 func stopReasonClaude2OpenAI(reason string) string {
 	switch reason {
 	case "stop_sequence":
@@ -64,7 +70,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
 }
 
 func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
-	claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
+	claudeTools := make([]any, 0, len(textRequest.Tools))
 
 	for _, tool := range textRequest.Tools {
 		if params, ok := tool.Function.Parameters.(map[string]any); ok {
@@ -84,8 +90,60 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 				}
 				claudeTool.InputSchema[s] = a
 			}
-			claudeTools = append(claudeTools, claudeTool)
+			claudeTools = append(claudeTools, &claudeTool)
+		}
+	}
+
+	// Web search tool
+	// https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool
+	if textRequest.WebSearchOptions != nil {
+		webSearchTool := dto.ClaudeWebSearchTool{
+			Type: "web_search_20250305",
+			Name: "web_search",
+		}
+
+		// 处理 user_location
+		if textRequest.WebSearchOptions.UserLocation != nil {
+			anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{
+				Type: "approximate", // 固定为 "approximate"
+			}
+
+			// 解析 UserLocation JSON
+			var userLocationMap map[string]interface{}
+			if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
+				// 检查是否有 approximate 字段
+				if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
+					if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
+						anthropicUserLocation.Timezone = timezone
+					}
+					if country, ok := approximateData["country"].(string); ok && country != "" {
+						anthropicUserLocation.Country = country
+					}
+					if region, ok := approximateData["region"].(string); ok && region != "" {
+						anthropicUserLocation.Region = region
+					}
+					if city, ok := approximateData["city"].(string); ok && city != "" {
+						anthropicUserLocation.City = city
+					}
+				}
+			}
+
+			webSearchTool.UserLocation = anthropicUserLocation
 		}
+
+		// 处理 search_context_size 转换为 max_uses
+		if textRequest.WebSearchOptions.SearchContextSize != "" {
+			switch textRequest.WebSearchOptions.SearchContextSize {
+			case "low":
+				webSearchTool.MaxUses = WebSearchMaxUsesLow
+			case "medium":
+				webSearchTool.MaxUses = WebSearchMaxUsesMedium
+			case "high":
+				webSearchTool.MaxUses = WebSearchMaxUsesHigh
+			}
+		}
+
+		claudeTools = append(claudeTools, &webSearchTool)
 	}
 
 	claudeRequest := dto.ClaudeRequest{
@@ -99,6 +157,14 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 		Tools:         claudeTools,
 	}
 
+	// 处理 tool_choice 和 parallel_tool_calls
+	if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil {
+		claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls)
+		if claudeToolChoice != nil {
+			claudeRequest.ToolChoice = claudeToolChoice
+		}
+	}
+
 	if claudeRequest.MaxTokens == 0 {
 		claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
 	}
@@ -123,6 +189,27 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 		claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
 	}
 
+	if textRequest.ReasoningEffort != "" {
+		switch textRequest.ReasoningEffort {
+		case "low":
+			claudeRequest.Thinking = &dto.Thinking{
+				Type:         "enabled",
+				BudgetTokens: common.GetPointer[int](1280),
+			}
+		case "medium":
+			claudeRequest.Thinking = &dto.Thinking{
+				Type:         "enabled",
+				BudgetTokens: common.GetPointer[int](2048),
+			}
+		case "high":
+			claudeRequest.Thinking = &dto.Thinking{
+				Type:         "enabled",
+				BudgetTokens: common.GetPointer[int](4096),
+			}
+		}
+	}
+
+	// 指定了 reasoning 参数,覆盖 budgetTokens
 	if textRequest.Reasoning != nil {
 		var reasoning openrouter.RequestReasoning
 		if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
@@ -685,3 +772,51 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
 	}
 	return nil, claudeInfo.Usage
 }
+
+func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
+	var claudeToolChoice *dto.ClaudeToolChoice
+
+	// 处理 tool_choice 字符串值
+	if toolChoiceStr, ok := toolChoice.(string); ok {
+		switch toolChoiceStr {
+		case "auto":
+			claudeToolChoice = &dto.ClaudeToolChoice{
+				Type: "auto",
+			}
+		case "required":
+			claudeToolChoice = &dto.ClaudeToolChoice{
+				Type: "any",
+			}
+		case "none":
+			claudeToolChoice = &dto.ClaudeToolChoice{
+				Type: "none",
+			}
+		}
+	} else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
+		// 处理 tool_choice 对象值
+		if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
+			if toolName, ok := function["name"].(string); ok {
+				claudeToolChoice = &dto.ClaudeToolChoice{
+					Type: "tool",
+					Name: toolName,
+				}
+			}
+		}
+	}
+
+	// 处理 parallel_tool_calls
+	if parallelToolCalls != nil {
+		if claudeToolChoice == nil {
+			// 如果没有 tool_choice,但有 parallel_tool_calls,创建默认的 auto 类型
+			claudeToolChoice = &dto.ClaudeToolChoice{
+				Type: "auto",
+			}
+		}
+
+		// 设置 disable_parallel_tool_use
+		// 如果 parallel_tool_calls 为 true,则 disable_parallel_tool_use 为 false
+		claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls
+	}
+
+	return claudeToolChoice
+}