Jelajahi Sumber

fix: merge runtime and channel header overrides, skip missing source headers

Seefs 1 Minggu lalu
induk
melakukan
305dbce4ad

+ 1 - 4
relay/channel/api_request.go

@@ -173,10 +173,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
 		return headerOverride, nil
 	}
 
-	headerOverrideSource := info.HeadersOverride
-	if info.UseRuntimeHeadersOverride {
-		headerOverrideSource = info.RuntimeHeadersOverride
-	}
+	headerOverrideSource := common.GetEffectiveHeaderOverride(info)
 
 	passAll := false
 	var passthroughRegex []*regexp.Regexp

+ 2 - 3
relay/channel/api_request_test.go

@@ -80,7 +80,7 @@ func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T)
 	require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
 }
 
-func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
+func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testing.T) {
 	t.Parallel()
 
 	gin.SetMode(gin.TestMode)
@@ -107,8 +107,7 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
 	require.NoError(t, err)
 	require.Equal(t, "runtime-value", headers["X-Static"])
 	require.Equal(t, "runtime-only", headers["X-Runtime"])
-	_, ok := headers["X-Legacy"]
-	require.False(t, ok)
+	require.Equal(t, "legacy-only", headers["X-Legacy"])
 }
 
 func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {

+ 131 - 6
relay/common/override.go

@@ -22,6 +22,7 @@ const (
 	paramOverrideContextRequestHeadersRaw        = "request_headers_raw"
 	paramOverrideContextHeaderOverride           = "header_override"
 	paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
+	paramOverrideContextHeaderOverrideDeleted    = "header_override_deleted_normalized"
 )
 
 var errSourceHeaderNotFound = errors.New("source header does not exist")
@@ -160,6 +161,84 @@ func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
 	return info.ChannelMeta.HeadersOverride
 }
 
+func cloneHeaderOverrideMap(source map[string]interface{}) map[string]interface{} {
+	if len(source) == 0 {
+		return map[string]interface{}{}
+	}
+	target := make(map[string]interface{}, len(source))
+	for key, value := range source {
+		target[key] = value
+	}
+	return target
+}
+
+func setHeaderOverrideEntry(target map[string]interface{}, key string, value interface{}) {
+	key = strings.TrimSpace(key)
+	if key == "" {
+		return
+	}
+	for existing := range target {
+		if strings.EqualFold(strings.TrimSpace(existing), key) {
+			delete(target, existing)
+		}
+	}
+	target[key] = value
+}
+
+func isHeaderDeletedByRuntime(headerName string, deleted map[string]bool) bool {
+	if len(deleted) == 0 {
+		return false
+	}
+	normalized := normalizeHeaderContextKey(headerName)
+	if normalized == "" {
+		return false
+	}
+	return deleted[normalized]
+}
+
+func mergeHeaderOverrideSource(base, runtime map[string]interface{}, deleted map[string]bool) map[string]interface{} {
+	merged := make(map[string]interface{}, len(base)+len(runtime))
+	for key, value := range base {
+		if isHeaderDeletedByRuntime(key, deleted) {
+			continue
+		}
+		setHeaderOverrideEntry(merged, key, value)
+	}
+	for key, value := range runtime {
+		setHeaderOverrideEntry(merged, key, value)
+	}
+	return merged
+}
+
+func cloneDeletedHeaderKeys(source map[string]bool) map[string]bool {
+	if len(source) == 0 {
+		return map[string]bool{}
+	}
+	target := make(map[string]bool, len(source))
+	for key, value := range source {
+		if !value {
+			continue
+		}
+		normalized := normalizeHeaderContextKey(key)
+		if normalized == "" {
+			continue
+		}
+		target[normalized] = true
+	}
+	return target
+}
+
+func GetEffectiveHeaderOverride(info *RelayInfo) map[string]interface{} {
+	if info == nil {
+		return map[string]interface{}{}
+	}
+	base := getHeaderOverrideMap(info)
+	if !info.UseRuntimeHeadersOverride {
+		return cloneHeaderOverrideMap(base)
+	}
+	return mergeHeaderOverrideSource(base, info.RuntimeHeadersOverride, cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized))
+}
+
 func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
 	// 检查是否包含 "operations" 字段
 	if opsValue, exists := paramOverride["operations"]; exists {
@@ -480,6 +559,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
 				targetHeader = strings.TrimSpace(op.Path)
 			}
 			err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
+			if errors.Is(err, errSourceHeaderNotFound) {
+				err = nil
+			}
 			if err == nil {
 				contextJSON, err = marshalContextJSON(context)
 			}
@@ -493,6 +575,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
 				targetHeader = strings.TrimSpace(op.Path)
 			}
 			err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
+			if errors.Is(err, errSourceHeaderNotFound) {
+				err = nil
+			}
 			if err == nil {
 				contextJSON, err = marshalContextJSON(context)
 			}
@@ -647,8 +732,13 @@ func setHeaderOverrideInContext(context map[string]interface{}, headerName strin
 	rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
 	rawHeaders[headerName] = headerValue
 
+	normalizedHeaderName := normalizeHeaderContextKey(headerName)
 	normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
-	normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue
+	normalizedHeaders[normalizedHeaderName] = headerValue
+	if normalizedHeaderName != "" {
+		deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
+		delete(deletedHeaders, normalizedHeaderName)
+	}
 	return nil
 }
 
@@ -693,7 +783,12 @@ func deleteHeaderOverrideInContext(context map[string]interface{}, headerName st
 	}
 
 	normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
-	delete(normalizedHeaders, normalizeHeaderContextKey(headerName))
+	normalizedHeaderName := normalizeHeaderContextKey(headerName)
+	delete(normalizedHeaders, normalizedHeaderName)
+	if normalizedHeaderName != "" {
+		deletedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideDeleted)
+		deletedHeaders[normalizedHeaderName] = true
+	}
 	return nil
 }
 
@@ -1062,9 +1157,39 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
 	info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
 		return item.Key, item.Value
 	})
+	info.RuntimeHeadersDeletedNormalized = sanitizeRuntimeDeletedHeadersFromContext(context)
 	info.UseRuntimeHeadersOverride = true
 }
 
+func sanitizeRuntimeDeletedHeadersFromContext(context map[string]interface{}) map[string]bool {
+	deletedRaw, exists := context[paramOverrideContextHeaderOverrideDeleted]
+	if !exists {
+		return nil
+	}
+	deletedMap, ok := deletedRaw.(map[string]interface{})
+	if !ok || len(deletedMap) == 0 {
+		return nil
+	}
+	entries := lo.Entries(deletedMap)
+	sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (string, bool) {
+		if keep, ok := item.Value.(bool); ok && !keep {
+			return "", false
+		}
+		normalized := normalizeHeaderContextKey(item.Key)
+		if normalized == "" {
+			return "", false
+		}
+		return normalized, true
+	})
+	if len(sanitized) == 0 {
+		return nil
+	}
+	keys := lo.Uniq(sanitized)
+	return lo.SliceToMap(keys, func(item string) (string, bool) {
+		return item, true
+	})
+}
+
 func moveValue(jsonStr, fromPath, toPath string) (string, error) {
 	sourceValue := gjson.Get(jsonStr, fromPath)
 	if !sourceValue.Exists() {
@@ -1513,13 +1638,13 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
 	ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
 	ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
 
-	headerOverrideSource := getHeaderOverrideMap(info)
-	if info.UseRuntimeHeadersOverride {
-		headerOverrideSource = info.RuntimeHeadersOverride
-	}
+	headerOverrideSource := GetEffectiveHeaderOverride(info)
 	rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
 	ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
 	ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
+	ctx[paramOverrideContextHeaderOverrideDeleted] = lo.SliceToMap(lo.Keys(cloneDeletedHeaderKeys(info.RuntimeHeadersDeletedNormalized)), func(item string) (string, interface{}) {
+		return item, true
+	})
 
 	ctx["retry_index"] = info.RetryIndex
 	ctx["is_retry"] = info.RetryIndex > 0

+ 99 - 0
relay/common/override_test.go

@@ -1097,6 +1097,76 @@ func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) {
 	}
 }
 
+func TestApplyParamOverrideCopyHeaderSkipsMissingSource(t *testing.T) {
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode": "copy_header",
+				"from": "X-Missing-Header",
+				"to":   "X-Upstream-Auth",
+			},
+		},
+	}
+	ctx := map[string]interface{}{
+		"request_headers_raw": map[string]interface{}{
+			"Authorization": "Bearer token-123",
+		},
+		"request_headers": map[string]interface{}{
+			"authorization": "Bearer token-123",
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, ctx)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.7}`, string(out))
+
+	headers, ok := ctx["header_override"].(map[string]interface{})
+	if !ok {
+		return
+	}
+	if _, exists := headers["X-Upstream-Auth"]; exists {
+		t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
+	}
+}
+
+func TestApplyParamOverrideMoveHeaderSkipsMissingSource(t *testing.T) {
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode": "move_header",
+				"from": "X-Missing-Header",
+				"to":   "X-Upstream-Auth",
+			},
+		},
+	}
+	ctx := map[string]interface{}{
+		"request_headers_raw": map[string]interface{}{
+			"Authorization": "Bearer token-123",
+		},
+		"request_headers": map[string]interface{}{
+			"authorization": "Bearer token-123",
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, ctx)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.7}`, string(out))
+
+	headers, ok := ctx["header_override"].(map[string]interface{})
+	if !ok {
+		return
+	}
+	if _, exists := headers["X-Upstream-Auth"]; exists {
+		t.Fatalf("expected X-Upstream-Auth to be skipped when source header is missing")
+	}
+}
+
 func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) {
 	input := []byte(`{"model":"gpt-4"}`)
 	override := map[string]interface{}{
@@ -1351,6 +1421,35 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
 	}
 }
 
+func TestGetEffectiveHeaderOverrideMergesRuntimeAndChannelOverrides(t *testing.T) {
+	info := &RelayInfo{
+		UseRuntimeHeadersOverride: true,
+		RuntimeHeadersOverride: map[string]interface{}{
+			"X-Runtime": "runtime-only",
+		},
+		RuntimeHeadersDeletedNormalized: map[string]bool{
+			"x-deleted": true,
+		},
+		ChannelMeta: &ChannelMeta{
+			HeadersOverride: map[string]interface{}{
+				"X-Static":  "static-value",
+				"X-Deleted": "should-not-exist",
+			},
+		},
+	}
+
+	effective := GetEffectiveHeaderOverride(info)
+	if effective["X-Static"] != "static-value" {
+		t.Fatalf("expected X-Static from channel override, got: %v", effective["X-Static"])
+	}
+	if effective["X-Runtime"] != "runtime-only" {
+		t.Fatalf("expected X-Runtime from runtime override, got: %v", effective["X-Runtime"])
+	}
+	if _, exists := effective["X-Deleted"]; exists {
+		t.Fatalf("expected deleted headers to stay deleted in effective override")
+	}
+}
+
 func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
 	input := `{
 		"service_tier":"flex",

+ 1 - 0
relay/common/relay_info.go

@@ -148,6 +148,7 @@ type RelayInfo struct {
 	RetryIndex                            int
 	LastError                             *types.NewAPIError
 	RuntimeHeadersOverride                map[string]interface{}
+	RuntimeHeadersDeletedNormalized       map[string]bool
 	UseRuntimeHeadersOverride             bool
 
 	PriceData types.PriceData