Просмотр исходного кода

Merge pull request #3066 from seefs001/fix/aws-header-override

Fix/aws header override
Seefs 5 дней назад
Родитель
Сommit
0689600103

+ 1 - 1
dto/claude.go

@@ -434,7 +434,7 @@ func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
 }
 
 type Thinking struct {
-	Type         string `json:"type"`
+	Type         string `json:"type,omitempty"`
 	BudgetTokens *int   `json:"budget_tokens,omitempty"`
 }
 

+ 7 - 7
go.mod

@@ -8,10 +8,10 @@ require (
 	github.com/abema/go-mp4 v1.4.1
 	github.com/andybalholm/brotli v1.1.1
 	github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
-	github.com/aws/aws-sdk-go-v2 v1.37.2
-	github.com/aws/aws-sdk-go-v2/credentials v1.17.11
-	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
-	github.com/aws/smithy-go v1.22.5
+	github.com/aws/aws-sdk-go-v2 v1.41.2
+	github.com/aws/aws-sdk-go-v2/credentials v1.19.10
+	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0
+	github.com/aws/smithy-go v1.24.2
 	github.com/bytedance/gopkg v0.1.3
 	github.com/gin-contrib/cors v1.7.2
 	github.com/gin-contrib/gzip v0.0.6
@@ -62,9 +62,9 @@ require (
 require (
 	github.com/DmitriyVTitov/size v1.5.0 // indirect
 	github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
-	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
+	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/boombuler/barcode v1.1.0 // indirect
 	github.com/bytedance/sonic v1.14.1 // indirect

+ 16 - 0
go.sum

@@ -12,18 +12,34 @@ github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63q
 github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
 github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
 github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
+github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
+github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
 github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
+github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8=
+github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug=
 github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
 github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k=
 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0 h1:TDKR8ACRw7G+GFaQlhoy6biu+8q6ZtSddQCy9avMdMI=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0/go.mod h1:XlhOh5Ax/lesqN4aZCUgj9vVJed5VoXYHHFYGAlJEwU=
 github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
 github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
+github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
+github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
+github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
+github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=

+ 4 - 0
relay/channel/api_request.go

@@ -267,6 +267,10 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
 	return headerOverride, nil
 }
 
+func ResolveHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
+	return processHeaderOverride(info, c)
+}
+
 func applyHeaderOverrideToRequest(req *http.Request, headerOverride map[string]string) {
 	if req == nil {
 		return

+ 1 - 0
relay/channel/aws/dto.go

@@ -27,6 +27,7 @@ type AwsClaudeRequest struct {
 	ToolChoice       any                 `json:"tool_choice,omitempty"`
 	Thinking         *dto.Thinking       `json:"thinking,omitempty"`
 	OutputConfig     json.RawMessage     `json:"output_config,omitempty"`
+	//Metadata         json.RawMessage     `json:"metadata,omitempty"`
 }
 
 func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {

+ 8 - 0
relay/channel/aws/relay-aws.go

@@ -11,6 +11,7 @@ import (
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/relay/channel"
 	"github.com/QuantumNous/new-api/relay/channel/claude"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/relay/helper"
@@ -106,6 +107,13 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 	// init empty request.header
 	requestHeader := http.Header{}
 	a.SetupRequestHeader(c, &requestHeader, info)
+	headerOverride, err := channel.ResolveHeaderOverride(info, c)
+	if err != nil {
+		return nil, err
+	}
+	for key, value := range headerOverride {
+		requestHeader.Set(key, value)
+	}
 
 	if isNovaModel(awsModelId) {
 		var novaReq *NovaRequest

+ 55 - 0
relay/channel/aws/relay_aws_test.go

@@ -0,0 +1,55 @@
+package aws
+
+import (
+	"bytes"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/QuantumNous/new-api/common"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestDoAwsClientRequest_AppliesRuntimeHeaderOverrideToAnthropicBeta(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName:           "claude-3-5-sonnet-20240620",
+		IsStream:                  false,
+		UseRuntimeHeadersOverride: true,
+		RuntimeHeadersOverride: map[string]any{
+			"anthropic-beta": "computer-use-2025-01-24",
+		},
+		ChannelMeta: &relaycommon.ChannelMeta{
+			ApiKey:            "access-key|secret-key|us-east-1",
+			UpstreamModelName: "claude-3-5-sonnet-20240620",
+		},
+	}
+
+	requestBody := bytes.NewBufferString(`{"messages":[{"role":"user","content":"hello"}],"max_tokens":128}`)
+	adaptor := &Adaptor{}
+
+	_, err := doAwsClientRequest(ctx, info, adaptor, requestBody)
+	require.NoError(t, err)
+
+	awsReq, ok := adaptor.AwsReq.(*bedrockruntime.InvokeModelInput)
+	require.True(t, ok)
+
+	var payload map[string]any
+	require.NoError(t, common.Unmarshal(awsReq.Body, &payload))
+
+	anthropicBeta, exists := payload["anthropic_beta"]
+	require.True(t, exists)
+
+	values, ok := anthropicBeta.([]any)
+	require.True(t, ok)
+	require.Equal(t, []any{"computer-use-2025-01-24"}, values)
+}

+ 1 - 0
relay/channel/vertex/dto.go

@@ -20,6 +20,7 @@ type VertexAIClaudeRequest struct {
 	ToolChoice       any                 `json:"tool_choice,omitempty"`
 	Thinking         *dto.Thinking       `json:"thinking,omitempty"`
 	OutputConfig     json.RawMessage     `json:"output_config,omitempty"`
+	//Metadata         json.RawMessage     `json:"metadata,omitempty"`
 }
 
 func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {

+ 25 - 1
relay/common/override.go

@@ -120,8 +120,18 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
 
 	// 尝试断言为操作格式
 	if operations, ok := tryParseOperations(paramOverride); ok {
+		legacyOverride := buildLegacyParamOverride(paramOverride)
+		workingJSON := jsonData
+		var err error
+		if len(legacyOverride) > 0 {
+			workingJSON, err = applyOperationsLegacy(workingJSON, legacyOverride)
+			if err != nil {
+				return nil, err
+			}
+		}
+
 		// 使用新方法
-		result, err := applyOperations(string(jsonData), operations, conditionContext)
+		result, err := applyOperations(string(workingJSON), operations, conditionContext)
 		return []byte(result), err
 	}
 
@@ -129,6 +139,20 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
 	return applyOperationsLegacy(jsonData, paramOverride)
 }
 
+func buildLegacyParamOverride(paramOverride map[string]interface{}) map[string]interface{} {
+	if len(paramOverride) == 0 {
+		return nil
+	}
+	legacy := make(map[string]interface{}, len(paramOverride))
+	for key, value := range paramOverride {
+		if strings.EqualFold(strings.TrimSpace(key), "operations") {
+			continue
+		}
+		legacy[key] = value
+	}
+	return legacy
+}
+
 func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) {
 	paramOverride := getParamOverrideMap(info)
 	if len(paramOverride) == 0 {

+ 80 - 0
relay/common/override_test.go

@@ -74,6 +74,48 @@ func TestApplyParamOverrideTrimNoop(t *testing.T) {
 	assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
 }
 
+func TestApplyParamOverrideMixedLegacyAndOperations(t *testing.T) {
+	input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
+	override := map[string]interface{}{
+		"temperature": 0.2,
+		"top_p":       0.95,
+		"operations": []interface{}{
+			map[string]interface{}{
+				"path":  "model",
+				"mode":  "trim_prefix",
+				"value": "openai/",
+			},
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, nil)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"model":"gpt-4","temperature":0.2,"top_p":0.95}`, string(out))
+}
+
+func TestApplyParamOverrideMixedLegacyAndOperationsConflictPrefersOperations(t *testing.T) {
+	input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
+	override := map[string]interface{}{
+		"model":       "legacy-model",
+		"temperature": 0.2,
+		"operations": []interface{}{
+			map[string]interface{}{
+				"path":  "model",
+				"mode":  "set",
+				"value": "op-model",
+			},
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, nil)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"model":"op-model","temperature":0.2}`, string(out))
+}
+
 func TestApplyParamOverrideTrimRequiresValue(t *testing.T) {
 	// trim_prefix requires value example:
 	// {"operations":[{"path":"model","mode":"trim_prefix"}]}
@@ -1429,6 +1471,44 @@ func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) {
 	}
 }
 
+func TestApplyParamOverrideWithRelayInfoMixedLegacyAndOperations(t *testing.T) {
+	info := &RelayInfo{
+		RequestHeaders: map[string]string{
+			"Originator": "Codex CLI",
+		},
+		ChannelMeta: &ChannelMeta{
+			ParamOverride: map[string]interface{}{
+				"temperature": 0.2,
+				"operations": []interface{}{
+					map[string]interface{}{
+						"mode":  "pass_headers",
+						"value": []interface{}{"Originator"},
+					},
+				},
+			},
+			HeadersOverride: map[string]interface{}{
+				"X-Static": "legacy-static",
+			},
+		},
+	}
+
+	out, err := ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5","temperature":0.7}`), info)
+	if err != nil {
+		t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"model":"gpt-5","temperature":0.2}`, string(out))
+
+	if !info.UseRuntimeHeadersOverride {
+		t.Fatalf("expected runtime header override to be enabled")
+	}
+	if info.RuntimeHeadersOverride["x-static"] != "legacy-static" {
+		t.Fatalf("expected x-static to be preserved, got: %v", info.RuntimeHeadersOverride["x-static"])
+	}
+	if info.RuntimeHeadersOverride["originator"] != "Codex CLI" {
+		t.Fatalf("expected originator header to be passed, got: %v", info.RuntimeHeadersOverride["originator"])
+	}
+}
+
 func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
 	info := &RelayInfo{
 		ChannelMeta: &ChannelMeta{

+ 1 - 1
relay/helper/valid_request.go

@@ -229,7 +229,7 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
 
 func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
 	textRequest = &dto.ClaudeRequest{}
-	err = c.ShouldBindJSON(textRequest)
+	err = common.UnmarshalBodyReusable(c, textRequest)
 	if err != nil {
 		return nil, err
 	}

+ 35 - 0
service/channel_affinity.go

@@ -436,11 +436,46 @@ func mergeChannelOverride(base map[string]interface{}, tpl map[string]interface{
 	}
 	out := cloneStringAnyMap(base)
 	for k, v := range tpl {
+		if strings.EqualFold(strings.TrimSpace(k), "operations") {
+			baseOps, hasBaseOps := extractParamOperations(out[k])
+			tplOps, hasTplOps := extractParamOperations(v)
+			if hasTplOps {
+				if hasBaseOps {
+					out[k] = append(tplOps, baseOps...)
+				} else {
+					out[k] = tplOps
+				}
+				continue
+			}
+		}
+		if _, exists := out[k]; exists {
+			continue
+		}
 		out[k] = v
 	}
 	return out
 }
 
+func extractParamOperations(value interface{}) ([]interface{}, bool) {
+	switch ops := value.(type) {
+	case []interface{}:
+		if len(ops) == 0 {
+			return []interface{}{}, true
+		}
+		cloned := make([]interface{}, 0, len(ops))
+		cloned = append(cloned, ops...)
+		return cloned, true
+	case []map[string]interface{}:
+		cloned := make([]interface{}, 0, len(ops))
+		for _, op := range ops {
+			cloned = append(cloned, op)
+		}
+		return cloned, true
+	default:
+		return nil, false
+	}
+}
+
 func appendChannelAffinityTemplateAdminInfo(c *gin.Context, meta channelAffinityMeta) {
 	if c == nil {
 		return

+ 43 - 1
service/channel_affinity_template_test.go

@@ -56,7 +56,7 @@ func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) {
 
 	merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base)
 	require.True(t, applied)
-	require.Equal(t, 0.2, merged["temperature"])
+	require.Equal(t, 0.7, merged["temperature"])
 	require.Equal(t, 0.95, merged["top_p"])
 	require.Equal(t, 2000, merged["max_tokens"])
 	require.Equal(t, 0.7, base["temperature"])
@@ -74,6 +74,48 @@ func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) {
 	require.EqualValues(t, 2, overrideInfo["param_override_keys"])
 }
 
+func TestApplyChannelAffinityOverrideTemplate_MergeOperations(t *testing.T) {
+	ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
+		RuleName: "rule-with-ops-template",
+		ParamTemplate: map[string]interface{}{
+			"operations": []map[string]interface{}{
+				{
+					"mode":  "pass_headers",
+					"value": []string{"Originator"},
+				},
+			},
+		},
+	})
+	base := map[string]interface{}{
+		"temperature": 0.7,
+		"operations": []map[string]interface{}{
+			{
+				"path":  "model",
+				"mode":  "trim_prefix",
+				"value": "openai/",
+			},
+		},
+	}
+
+	merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base)
+	require.True(t, applied)
+	require.Equal(t, 0.7, merged["temperature"])
+
+	opsAny, ok := merged["operations"]
+	require.True(t, ok)
+	ops, ok := opsAny.([]interface{})
+	require.True(t, ok)
+	require.Len(t, ops, 2)
+
+	firstOp, ok := ops[0].(map[string]interface{})
+	require.True(t, ok)
+	require.Equal(t, "pass_headers", firstOp["mode"])
+
+	secondOp, ok := ops[1].(map[string]interface{})
+	require.True(t, ok)
+	require.Equal(t, "trim_prefix", secondOp["mode"])
+}
+
 func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 

+ 11 - 0
web/src/components/table/channels/modals/EditChannelModal.jsx

@@ -759,6 +759,10 @@ const EditChannelModal = (props) => {
     }
   };
 
+  const clearParamOverride = () => {
+    handleInputChange('param_override', '');
+  };
+
   const loadChannel = async () => {
     setLoading(true);
     let res = await API.get(`/api/channel/${channelId}`);
@@ -3356,6 +3360,13 @@ const EditChannelModal = (props) => {
                           >
                             {t('填充旧模板')}
                           </Button>
+                          <Button
+                            size='small'
+                            type='tertiary'
+                            onClick={clearParamOverride}
+                          >
+                            {t('清空')}
+                          </Button>
                         </Space>
                       </div>
                       <Text type='tertiary' size='small'>