瀏覽代碼

feat: unify param/header overrides with retry-aware conditions and flexible header operations

Seefs 1 周之前
父節點
當前提交
91b300f522

+ 1 - 1
controller/channel-test.go

@@ -385,7 +385,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
 	//}
 
 	if len(info.ParamOverride) > 0 {
-		jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+		jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 		if err != nil {
 			if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
 				return testResult{

+ 10 - 2
relay/channel/api_request.go

@@ -168,11 +168,19 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
 // Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
 func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
 	headerOverride := make(map[string]string)
+	if info == nil {
+		return headerOverride, nil
+	}
+
+	headerOverrideSource := info.HeadersOverride
+	if info.UseRuntimeHeadersOverride {
+		headerOverrideSource = info.RuntimeHeadersOverride
+	}
 
 	passAll := false
 	var passthroughRegex []*regexp.Regexp
 	if !info.IsChannelTest {
-		for k := range info.HeadersOverride {
+		for k := range headerOverrideSource {
 			key := strings.TrimSpace(k)
 			if key == "" {
 				continue
@@ -232,7 +240,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
 		}
 	}
 
-	for k, v := range info.HeadersOverride {
+	for k, v := range headerOverrideSource {
 		if isHeaderPassthroughRuleKey(k) {
 			continue
 		}

+ 31 - 0
relay/channel/api_request_test.go

@@ -79,3 +79,34 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
 	require.NoError(t, err)
 	require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
 }
+
+func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest:             false,
+		UseRuntimeHeadersOverride: true,
+		RuntimeHeadersOverride: map[string]any{
+			"X-Static":  "runtime-value",
+			"X-Runtime": "runtime-only",
+		},
+		ChannelMeta: &relaycommon.ChannelMeta{
+			HeadersOverride: map[string]any{
+				"X-Static": "legacy-value",
+				"X-Legacy": "legacy-only",
+			},
+		},
+	}
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	require.Equal(t, "runtime-value", headers["X-Static"])
+	require.Equal(t, "runtime-only", headers["X-Runtime"])
+	_, ok := headers["X-Legacy"]
+	require.False(t, ok)
+}

+ 1 - 2
relay/chat_completions_via_responses.go

@@ -70,7 +70,6 @@ func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, requ
 }
 
 func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) {
-	overrideCtx := relaycommon.BuildParamOverrideContext(info)
 	chatJSON, err := common.Marshal(request)
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
@@ -82,7 +81,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
 	}
 
 	if len(info.ParamOverride) > 0 {
-		chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
+		chatJSON, err = relaycommon.ApplyParamOverrideWithRelayInfo(chatJSON, info)
 		if err != nil {
 			return nil, newAPIErrorFromParamOverride(err)
 		}

+ 1 - 1
relay/claude_handler.go

@@ -153,7 +153,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+			jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 			if err != nil {
 				return newAPIErrorFromParamOverride(err)
 			}

+ 412 - 72
relay/common/override.go

@@ -10,12 +10,20 @@ import (
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/types"
+	"github.com/samber/lo"
 	"github.com/tidwall/gjson"
 	"github.com/tidwall/sjson"
 )
 
 var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
 
+const (
+	paramOverrideContextRequestHeaders           = "request_headers"
+	paramOverrideContextRequestHeadersRaw        = "request_headers_raw"
+	paramOverrideContextHeaderOverride           = "header_override"
+	paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
+)
+
 type ConditionOperation struct {
 	Path           string      `json:"path"`             // JSON路径
 	Mode           string      `json:"mode"`             // full, prefix, suffix, contains, gt, gte, lt, lte
@@ -26,7 +34,7 @@ type ConditionOperation struct {
 
 type ParamOperation struct {
 	Path       string               `json:"path"`
-	Mode       string               `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects
+	Mode       string               `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects, set_header, delete_header, copy_header, move_header
 	Value      interface{}          `json:"value"`
 	KeepOrigin bool                 `json:"keep_origin"`
 	From       string               `json:"from,omitempty"`
@@ -121,6 +129,35 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
 	return applyOperationsLegacy(jsonData, paramOverride)
 }
 
+func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) {
+	paramOverride := getParamOverrideMap(info)
+	if len(paramOverride) == 0 {
+		return jsonData, nil
+	}
+
+	overrideCtx := BuildParamOverrideContext(info)
+	result, err := ApplyParamOverride(jsonData, paramOverride, overrideCtx)
+	if err != nil {
+		return nil, err
+	}
+	syncRuntimeHeaderOverrideFromContext(info, overrideCtx)
+	return result, nil
+}
+
+func getParamOverrideMap(info *RelayInfo) map[string]interface{} {
+	if info == nil || info.ChannelMeta == nil {
+		return nil
+	}
+	return info.ChannelMeta.ParamOverride
+}
+
+func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
+	if info == nil || info.ChannelMeta == nil {
+		return nil
+	}
+	return info.ChannelMeta.HeadersOverride
+}
+
 func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
 	// 检查是否包含 "operations" 字段
 	if opsValue, exists := paramOverride["operations"]; exists {
@@ -161,29 +198,11 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
 
 					// 解析条件
 					if conditions, exists := opMap["conditions"]; exists {
-						if condSlice, ok := conditions.([]interface{}); ok {
-							for _, cond := range condSlice {
-								if condMap, ok := cond.(map[string]interface{}); ok {
-									condition := ConditionOperation{}
-									if path, ok := condMap["path"].(string); ok {
-										condition.Path = path
-									}
-									if mode, ok := condMap["mode"].(string); ok {
-										condition.Mode = mode
-									}
-									if value, ok := condMap["value"]; ok {
-										condition.Value = value
-									}
-									if invert, ok := condMap["invert"].(bool); ok {
-										condition.Invert = invert
-									}
-									if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
-										condition.PassMissingKey = passMissingKey
-									}
-									operation.Conditions = append(operation.Conditions, condition)
-								}
-							}
+						parsedConditions, err := parseConditionOperations(conditions)
+						if err != nil {
+							return nil, false
 						}
+						operation.Conditions = append(operation.Conditions, parsedConditions...)
 					}
 
 					operations = append(operations, operation)
@@ -212,20 +231,9 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio
 	}
 
 	if strings.ToUpper(logic) == "AND" {
-		for _, result := range results {
-			if !result {
-				return false, nil
-			}
-		}
-		return true, nil
-	} else {
-		for _, result := range results {
-			if result {
-				return true, nil
-			}
-		}
-		return false, nil
+		return lo.EveryBy(results, func(item bool) bool { return item }), nil
 	}
+	return lo.SomeBy(results, func(item bool) bool { return item }), nil
 }
 
 func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
@@ -382,13 +390,10 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
 }
 
 func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
-	var contextJSON string
-	if conditionContext != nil && len(conditionContext) > 0 {
-		ctxBytes, err := common.Marshal(conditionContext)
-		if err != nil {
-			return "", fmt.Errorf("failed to marshal condition context: %v", err)
-		}
-		contextJSON = string(ctxBytes)
+	context := ensureContextMap(conditionContext)
+	contextJSON, err := marshalContextJSON(context)
+	if err != nil {
+		return "", fmt.Errorf("failed to marshal condition context: %v", err)
 	}
 
 	result := jsonStr
@@ -453,6 +458,42 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
 			return "", returnErr
 		case "prune_objects":
 			result, err = pruneObjects(result, opPath, contextJSON, op.Value)
+		case "set_header":
+			err = setHeaderOverrideInContext(context, op.Path, op.Value, op.KeepOrigin)
+			if err == nil {
+				contextJSON, err = marshalContextJSON(context)
+			}
+		case "delete_header":
+			err = deleteHeaderOverrideInContext(context, op.Path)
+			if err == nil {
+				contextJSON, err = marshalContextJSON(context)
+			}
+		case "copy_header":
+			sourceHeader := strings.TrimSpace(op.From)
+			targetHeader := strings.TrimSpace(op.To)
+			if sourceHeader == "" {
+				sourceHeader = strings.TrimSpace(op.Path)
+			}
+			if targetHeader == "" {
+				targetHeader = strings.TrimSpace(op.Path)
+			}
+			err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
+			if err == nil {
+				contextJSON, err = marshalContextJSON(context)
+			}
+		case "move_header":
+			sourceHeader := strings.TrimSpace(op.From)
+			targetHeader := strings.TrimSpace(op.To)
+			if sourceHeader == "" {
+				sourceHeader = strings.TrimSpace(op.Path)
+			}
+			if targetHeader == "" {
+				targetHeader = strings.TrimSpace(op.Path)
+			}
+			err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
+			if err == nil {
+				contextJSON, err = marshalContextJSON(context)
+			}
 		default:
 			return "", fmt.Errorf("unknown operation: %s", op.Mode)
 		}
@@ -543,6 +584,276 @@ func parseOverrideInt(v interface{}) (int, bool) {
 	}
 }
 
+func ensureContextMap(conditionContext map[string]interface{}) map[string]interface{} {
+	if conditionContext != nil {
+		return conditionContext
+	}
+	return make(map[string]interface{})
+}
+
+func marshalContextJSON(context map[string]interface{}) (string, error) {
+	if context == nil || len(context) == 0 {
+		return "", nil
+	}
+	ctxBytes, err := common.Marshal(context)
+	if err != nil {
+		return "", err
+	}
+	return string(ctxBytes), nil
+}
+
+func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error {
+	headerName = strings.TrimSpace(headerName)
+	if headerName == "" {
+		return fmt.Errorf("header name is required")
+	}
+	if keepOrigin {
+		if _, exists := getHeaderValueFromContext(context, headerName); exists {
+			return nil
+		}
+	}
+	if value == nil {
+		return fmt.Errorf("header value is required")
+	}
+	headerValue := strings.TrimSpace(fmt.Sprintf("%v", value))
+	if headerValue == "" {
+		return fmt.Errorf("header value is required")
+	}
+
+	rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
+	rawHeaders[headerName] = headerValue
+
+	normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
+	normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue
+	return nil
+}
+
+func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
+	fromHeader = strings.TrimSpace(fromHeader)
+	toHeader = strings.TrimSpace(toHeader)
+	if fromHeader == "" || toHeader == "" {
+		return fmt.Errorf("copy_header from/to is required")
+	}
+	value, exists := getHeaderValueFromContext(context, fromHeader)
+	if !exists {
+		return fmt.Errorf("source header does not exist: %s", fromHeader)
+	}
+	return setHeaderOverrideInContext(context, toHeader, value, keepOrigin)
+}
+
+func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
+	fromHeader = strings.TrimSpace(fromHeader)
+	toHeader = strings.TrimSpace(toHeader)
+	if fromHeader == "" || toHeader == "" {
+		return fmt.Errorf("move_header from/to is required")
+	}
+	if err := copyHeaderInContext(context, fromHeader, toHeader, keepOrigin); err != nil {
+		return err
+	}
+	if strings.EqualFold(fromHeader, toHeader) {
+		return nil
+	}
+	return deleteHeaderOverrideInContext(context, fromHeader)
+}
+
+func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error {
+	headerName = strings.TrimSpace(headerName)
+	if headerName == "" {
+		return fmt.Errorf("header name is required")
+	}
+	rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
+	for key := range rawHeaders {
+		if strings.EqualFold(strings.TrimSpace(key), headerName) {
+			delete(rawHeaders, key)
+		}
+	}
+
+	normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
+	delete(normalizedHeaders, normalizeHeaderContextKey(headerName))
+	return nil
+}
+
+func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} {
+	if context == nil {
+		return map[string]interface{}{}
+	}
+	if existing, ok := context[key]; ok {
+		if mapVal, ok := existing.(map[string]interface{}); ok {
+			return mapVal
+		}
+	}
+	result := make(map[string]interface{})
+	context[key] = result
+	return result
+}
+
+func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) {
+	headerName = strings.TrimSpace(headerName)
+	if headerName == "" {
+		return "", false
+	}
+	if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverride), headerName); ok {
+		return value, true
+	}
+	if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeadersRaw), headerName); ok {
+		return value, true
+	}
+
+	normalizedName := normalizeHeaderContextKey(headerName)
+	if normalizedName == "" {
+		return "", false
+	}
+	if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized), normalizedName); ok {
+		return value, true
+	}
+	if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeaders), normalizedName); ok {
+		return value, true
+	}
+	return "", false
+}
+
+func findHeaderValueInMap(source map[string]interface{}, key string) (string, bool) {
+	if len(source) == 0 {
+		return "", false
+	}
+	entries := lo.Entries(source)
+	entry, ok := lo.Find(entries, func(item lo.Entry[string, interface{}]) bool {
+		return strings.EqualFold(strings.TrimSpace(item.Key), key)
+	})
+	if !ok {
+		return "", false
+	}
+	value := strings.TrimSpace(fmt.Sprintf("%v", entry.Value))
+	if value == "" {
+		return "", false
+	}
+	return value, true
+}
+
+func normalizeHeaderContextKey(key string) string {
+	key = strings.TrimSpace(strings.ToLower(key))
+	if key == "" {
+		return ""
+	}
+	var b strings.Builder
+	b.Grow(len(key))
+	previousUnderscore := false
+	for _, r := range key {
+		switch {
+		case r >= 'a' && r <= 'z':
+			b.WriteRune(r)
+			previousUnderscore = false
+		case r >= '0' && r <= '9':
+			b.WriteRune(r)
+			previousUnderscore = false
+		default:
+			if !previousUnderscore {
+				b.WriteByte('_')
+				previousUnderscore = true
+			}
+		}
+	}
+	result := strings.Trim(b.String(), "_")
+	return result
+}
+
+func buildNormalizedHeaders(headers map[string]string) map[string]interface{} {
+	if len(headers) == 0 {
+		return map[string]interface{}{}
+	}
+	entries := lo.Entries(headers)
+	normalizedEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
+		normalized := normalizeHeaderContextKey(item.Key)
+		value := strings.TrimSpace(item.Value)
+		if normalized == "" || value == "" {
+			return lo.Entry[string, string]{}, false
+		}
+		return lo.Entry[string, string]{Key: normalized, Value: value}, true
+	})
+	return lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) {
+		return item.Key, item.Value
+	})
+}
+
+func buildRawHeaders(headers map[string]string) map[string]interface{} {
+	if len(headers) == 0 {
+		return map[string]interface{}{}
+	}
+	entries := lo.Entries(headers)
+	rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
+		key := strings.TrimSpace(item.Key)
+		value := strings.TrimSpace(item.Value)
+		if key == "" || value == "" {
+			return lo.Entry[string, string]{}, false
+		}
+		return lo.Entry[string, string]{Key: key, Value: value}, true
+	})
+	return lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
+		return item.Key, item.Value
+	})
+}
+
+func buildHeaderOverrideContext(headers map[string]interface{}) (map[string]interface{}, map[string]interface{}) {
+	if len(headers) == 0 {
+		return map[string]interface{}{}, map[string]interface{}{}
+	}
+	entries := lo.Entries(headers)
+	rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, string], bool) {
+		key := strings.TrimSpace(item.Key)
+		value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
+		if key == "" || value == "" {
+			return lo.Entry[string, string]{}, false
+		}
+		return lo.Entry[string, string]{Key: key, Value: value}, true
+	})
+
+	raw := lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
+		return item.Key, item.Value
+	})
+	normalizedEntries := lo.FilterMap(rawEntries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
+		normalized := normalizeHeaderContextKey(item.Key)
+		if normalized == "" {
+			return lo.Entry[string, string]{}, false
+		}
+		return lo.Entry[string, string]{Key: normalized, Value: item.Value}, true
+	})
+	normalized := lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) {
+		return item.Key, item.Value
+	})
+	return raw, normalized
+}
+
+func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) {
+	if info == nil || context == nil {
+		return
+	}
+	raw, exists := context[paramOverrideContextHeaderOverride]
+	if !exists {
+		return
+	}
+	rawMap, ok := raw.(map[string]interface{})
+	if !ok {
+		return
+	}
+
+	entries := lo.Entries(rawMap)
+	sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, interface{}], bool) {
+		key := strings.TrimSpace(item.Key)
+		if key == "" {
+			return lo.Entry[string, interface{}]{}, false
+		}
+		value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
+		if value == "" {
+			return lo.Entry[string, interface{}]{}, false
+		}
+		return lo.Entry[string, interface{}]{Key: key, Value: value}, true
+	})
+	info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
+		return item.Key, item.Value
+	})
+	info.UseRuntimeHeadersOverride = true
+}
+
 func moveValue(jsonStr, fromPath, toPath string) (string, error) {
 	sourceValue := gjson.Get(jsonStr, fromPath)
 	if !sourceValue.Exists() {
@@ -824,38 +1135,56 @@ func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
 }
 
 func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
-	items, ok := raw.([]interface{})
-	if !ok {
-		return nil, fmt.Errorf("conditions must be an array")
-	}
-
-	result := make([]ConditionOperation, 0, len(items))
-	for _, item := range items {
-		itemMap, ok := item.(map[string]interface{})
-		if !ok {
-			return nil, fmt.Errorf("condition must be object")
-		}
-		path, _ := itemMap["path"].(string)
-		mode, _ := itemMap["mode"].(string)
-		if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
-			return nil, fmt.Errorf("condition path/mode is required")
-		}
-		condition := ConditionOperation{
-			Path: path,
-			Mode: mode,
-		}
-		if value, exists := itemMap["value"]; exists {
-			condition.Value = value
-		}
-		if invert, ok := itemMap["invert"].(bool); ok {
-			condition.Invert = invert
+	switch typed := raw.(type) {
+	case map[string]interface{}:
+		entries := lo.Entries(typed)
+		conditions := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (ConditionOperation, bool) {
+			path := strings.TrimSpace(item.Key)
+			if path == "" {
+				return ConditionOperation{}, false
+			}
+			return ConditionOperation{
+				Path:  path,
+				Mode:  "full",
+				Value: item.Value,
+			}, true
+		})
+		if len(conditions) == 0 {
+			return nil, fmt.Errorf("conditions object must contain at least one key")
 		}
-		if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
-			condition.PassMissingKey = passMissingKey
+		return conditions, nil
+	case []interface{}:
+		items := typed
+		result := make([]ConditionOperation, 0, len(items))
+		for _, item := range items {
+			itemMap, ok := item.(map[string]interface{})
+			if !ok {
+				return nil, fmt.Errorf("condition must be object")
+			}
+			path, _ := itemMap["path"].(string)
+			mode, _ := itemMap["mode"].(string)
+			if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
+				return nil, fmt.Errorf("condition path/mode is required")
+			}
+			condition := ConditionOperation{
+				Path: path,
+				Mode: mode,
+			}
+			if value, exists := itemMap["value"]; exists {
+				condition.Value = value
+			}
+			if invert, ok := itemMap["invert"].(bool); ok {
+				condition.Invert = invert
+			}
+			if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
+				condition.PassMissingKey = passMissingKey
+			}
+			result = append(result, condition)
 		}
-		result = append(result, condition)
+		return result, nil
+	default:
+		return nil, fmt.Errorf("conditions must be an array or object")
 	}
-	return result, nil
 }
 
 func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
@@ -970,6 +1299,17 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
 		}
 	}
 
+	ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
+	ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
+
+	headerOverrideSource := getHeaderOverrideMap(info)
+	if info.UseRuntimeHeadersOverride {
+		headerOverrideSource = info.RuntimeHeadersOverride
+	}
+	rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
+	ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
+	ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
+
 	ctx["retry_index"] = info.RetryIndex
 	ctx["is_retry"] = info.RetryIndex > 0
 	ctx["retry"] = map[string]interface{}{

+ 248 - 0
relay/common/override_test.go

@@ -956,6 +956,254 @@ func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) {
 	assertJSONEqual(t, `{"temperature":0.1}`, string(out))
 }
 
+func TestApplyParamOverrideConditionFromRequestHeaders(t *testing.T) {
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"path":  "temperature",
+				"mode":  "set",
+				"value": 0.1,
+				"conditions": []interface{}{
+					map[string]interface{}{
+						"path":  "request_headers.authorization",
+						"mode":  "contains",
+						"value": "Bearer ",
+					},
+				},
+			},
+		},
+	}
+	ctx := map[string]interface{}{
+		"request_headers": map[string]interface{}{
+			"authorization": "Bearer token-123",
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, ctx)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.1}`, string(out))
+}
+
+func TestApplyParamOverrideSetHeaderAndUseInLaterCondition(t *testing.T) {
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode":  "set_header",
+				"path":  "X-Debug-Mode",
+				"value": "enabled",
+			},
+			map[string]interface{}{
+				"path":  "temperature",
+				"mode":  "set",
+				"value": 0.1,
+				"conditions": []interface{}{
+					map[string]interface{}{
+						"path":  "header_override_normalized.x_debug_mode",
+						"mode":  "full",
+						"value": "enabled",
+					},
+				},
+			},
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, nil)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.1}`, string(out))
+}
+
+func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) {
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode": "copy_header",
+				"from": "Authorization",
+				"to":   "X-Upstream-Auth",
+			},
+			map[string]interface{}{
+				"path":  "temperature",
+				"mode":  "set",
+				"value": 0.1,
+				"conditions": []interface{}{
+					map[string]interface{}{
+						"path":  "header_override_normalized.x_upstream_auth",
+						"mode":  "contains",
+						"value": "Bearer ",
+					},
+				},
+			},
+		},
+	}
+	ctx := map[string]interface{}{
+		"request_headers_raw": map[string]interface{}{
+			"Authorization": "Bearer token-123",
+		},
+		"request_headers": map[string]interface{}{
+			"authorization": "Bearer token-123",
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, ctx)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.1}`, string(out))
+}
+
+func TestApplyParamOverrideSetHeaderKeepOrigin(t *testing.T) {
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode":        "set_header",
+				"path":        "X-Feature-Flag",
+				"value":       "new-value",
+				"keep_origin": true,
+			},
+		},
+	}
+	ctx := map[string]interface{}{
+		"header_override": map[string]interface{}{
+			"X-Feature-Flag": "legacy-value",
+		},
+		"header_override_normalized": map[string]interface{}{
+			"x_feature_flag": "legacy-value",
+		},
+	}
+
+	_, err := ApplyParamOverride(input, override, ctx)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	headers, ok := ctx["header_override"].(map[string]interface{})
+	if !ok {
+		t.Fatalf("expected header_override context map")
+	}
+	if headers["X-Feature-Flag"] != "legacy-value" {
+		t.Fatalf("expected keep_origin to preserve old value, got: %v", headers["X-Feature-Flag"])
+	}
+}
+
+func TestApplyParamOverrideConditionsObjectShorthand(t *testing.T) {
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"path":  "temperature",
+				"mode":  "set",
+				"value": 0.1,
+				"logic": "AND",
+				"conditions": map[string]interface{}{
+					"is_retry":               true,
+					"last_error.status_code": 400.0,
+				},
+			},
+		},
+	}
+	ctx := map[string]interface{}{
+		"is_retry": true,
+		"last_error": map[string]interface{}{
+			"status_code": 400.0,
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, ctx)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.1}`, string(out))
+}
+
+func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) {
+	info := &RelayInfo{
+		ChannelMeta: &ChannelMeta{
+			ParamOverride: map[string]interface{}{
+				"operations": []interface{}{
+					map[string]interface{}{
+						"mode":  "set_header",
+						"path":  "X-Injected-By-Param-Override",
+						"value": "enabled",
+					},
+					map[string]interface{}{
+						"mode": "delete_header",
+						"path": "X-Delete-Me",
+					},
+				},
+			},
+			HeadersOverride: map[string]interface{}{
+				"X-Delete-Me": "legacy",
+				"X-Keep-Me":   "keep",
+			},
+		},
+	}
+
+	input := []byte(`{"temperature":0.7}`)
+	out, err := ApplyParamOverrideWithRelayInfo(input, info)
+	if err != nil {
+		t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.7}`, string(out))
+
+	if !info.UseRuntimeHeadersOverride {
+		t.Fatalf("expected runtime header override to be enabled")
+	}
+	if info.RuntimeHeadersOverride["X-Keep-Me"] != "keep" {
+		t.Fatalf("expected X-Keep-Me header to be preserved, got: %v", info.RuntimeHeadersOverride["X-Keep-Me"])
+	}
+	if info.RuntimeHeadersOverride["X-Injected-By-Param-Override"] != "enabled" {
+		t.Fatalf("expected X-Injected-By-Param-Override header to be set, got: %v", info.RuntimeHeadersOverride["X-Injected-By-Param-Override"])
+	}
+	if _, exists := info.RuntimeHeadersOverride["X-Delete-Me"]; exists {
+		t.Fatalf("expected X-Delete-Me header to be deleted")
+	}
+}
+
+func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
+	info := &RelayInfo{
+		ChannelMeta: &ChannelMeta{
+			ParamOverride: map[string]interface{}{
+				"operations": []interface{}{
+					map[string]interface{}{
+						"mode": "move_header",
+						"from": "X-Legacy-Trace",
+						"to":   "X-Trace",
+					},
+					map[string]interface{}{
+						"mode": "copy_header",
+						"from": "X-Trace",
+						"to":   "X-Trace-Backup",
+					},
+				},
+			},
+			HeadersOverride: map[string]interface{}{
+				"X-Legacy-Trace": "trace-123",
+			},
+		},
+	}
+
+	input := []byte(`{"temperature":0.7}`)
+	_, err := ApplyParamOverrideWithRelayInfo(input, info)
+	if err != nil {
+		t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
+	}
+	if _, exists := info.RuntimeHeadersOverride["X-Legacy-Trace"]; exists {
+		t.Fatalf("expected source header to be removed after move")
+	}
+	if info.RuntimeHeadersOverride["X-Trace"] != "trace-123" {
+		t.Fatalf("expected X-Trace to be set, got: %v", info.RuntimeHeadersOverride["X-Trace"])
+	}
+	if info.RuntimeHeadersOverride["X-Trace-Backup"] != "trace-123" {
+		t.Fatalf("expected X-Trace-Backup to be copied, got: %v", info.RuntimeHeadersOverride["X-Trace-Backup"])
+	}
+}
+
 func assertJSONEqual(t *testing.T, want, got string) {
 	t.Helper()
 

+ 25 - 0
relay/common/relay_info.go

@@ -101,6 +101,7 @@ type RelayInfo struct {
 	RelayMode              int
 	OriginModelName        string
 	RequestURLPath         string
+	RequestHeaders         map[string]string
 	ShouldIncludeUsage     bool
 	DisablePing            bool // 是否禁止向下游发送自定义 Ping
 	ClientWs               *websocket.Conn
@@ -142,6 +143,8 @@ type RelayInfo struct {
 	IsChannelTest                         bool // channel test request
 	RetryIndex                            int
 	LastError                             *types.NewAPIError
+	RuntimeHeadersOverride                map[string]interface{}
+	UseRuntimeHeadersOverride             bool
 
 	PriceData types.PriceData
 
@@ -458,6 +461,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
 		isFirstResponse: true,
 		RelayMode:       relayconstant.Path2RelayMode(c.Request.URL.Path),
 		RequestURLPath:  c.Request.URL.String(),
+		RequestHeaders:  cloneRequestHeaders(c),
 		IsStream:        isStream,
 
 		StartTime:         startTime,
@@ -490,6 +494,27 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
 	return info
 }
 
+func cloneRequestHeaders(c *gin.Context) map[string]string {
+	if c == nil || c.Request == nil {
+		return nil
+	}
+	if len(c.Request.Header) == 0 {
+		return nil
+	}
+	headers := make(map[string]string, len(c.Request.Header))
+	for key := range c.Request.Header {
+		value := strings.TrimSpace(c.Request.Header.Get(key))
+		if value == "" {
+			continue
+		}
+		headers[key] = value
+	}
+	if len(headers) == 0 {
+		return nil
+	}
+	return headers
+}
+
 func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
 	var info *RelayInfo
 	var err error

+ 1 - 1
relay/compatible_handler.go

@@ -172,7 +172,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+			jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 			if err != nil {
 				return newAPIErrorFromParamOverride(err)
 			}

+ 1 - 1
relay/embedding_handler.go

@@ -51,7 +51,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 	}
 
 	if len(info.ParamOverride) > 0 {
-		jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+		jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 		if err != nil {
 			return newAPIErrorFromParamOverride(err)
 		}

+ 2 - 2
relay/gemini_handler.go

@@ -157,7 +157,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+			jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 			if err != nil {
 				return newAPIErrorFromParamOverride(err)
 			}
@@ -257,7 +257,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
 
 	// apply param override
 	if len(info.ParamOverride) > 0 {
-		jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+		jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 		if err != nil {
 			return newAPIErrorFromParamOverride(err)
 		}

+ 1 - 1
relay/image_handler.go

@@ -70,7 +70,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 
 			// apply param override
 			if len(info.ParamOverride) > 0 {
-				jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+				jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 				if err != nil {
 					return newAPIErrorFromParamOverride(err)
 				}

+ 1 - 1
relay/rerank_handler.go

@@ -61,7 +61,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+			jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 			if err != nil {
 				return newAPIErrorFromParamOverride(err)
 			}

+ 1 - 1
relay/responses_handler.go

@@ -96,7 +96,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+			jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
 			if err != nil {
 				return newAPIErrorFromParamOverride(err)
 			}