瀏覽代碼

fix pass_headers

Seefs 1 周之前
父節點
當前提交
bb0c663dbe
共有 3 個文件被更改,包括 97 次插入1 次删除
  1. 53 0
      relay/channel/api_request_test.go
  2. 7 1
      relay/common/override.go
  3. 37 0
      relay/common/override_test.go

+ 53 - 0
relay/channel/api_request_test.go

@@ -137,3 +137,56 @@ func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
 	_, hasAcceptEncoding := headers["Accept-Encoding"]
 	require.False(t, hasAcceptEncoding)
 }
+
+func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+	ctx.Request.Header.Set("Originator", "Codex CLI")
+	ctx.Request.Header.Set("Session_id", "sess-123")
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest: false,
+		RequestHeaders: map[string]string{
+			"Originator": "Codex CLI",
+			"Session_id": "sess-123",
+		},
+		ChannelMeta: &relaycommon.ChannelMeta{
+			ParamOverride: map[string]any{
+				"operations": []any{
+					map[string]any{
+						"mode":  "pass_headers",
+						"value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"},
+					},
+				},
+			},
+			HeadersOverride: map[string]any{
+				"X-Static": "legacy-value",
+			},
+		},
+	}
+
+	_, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info)
+	require.NoError(t, err)
+	require.True(t, info.UseRuntimeHeadersOverride)
+	require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["Originator"])
+	require.Equal(t, "sess-123", info.RuntimeHeadersOverride["Session_id"])
+	_, exists := info.RuntimeHeadersOverride["X-Codex-Beta-Features"]
+	require.False(t, exists)
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	require.Equal(t, "Codex CLI", headers["Originator"])
+	require.Equal(t, "sess-123", headers["Session_id"])
+	_, exists = headers["X-Codex-Beta-Features"]
+	require.False(t, exists)
+
+	upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil)
+	applyHeaderOverrideToRequest(upstreamReq, headers)
+	require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator"))
+	require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id"))
+	require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features"))
+}

+ 7 - 1
relay/common/override.go

@@ -24,6 +24,8 @@ const (
 	paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
 )
 
+var errSourceHeaderNotFound = errors.New("source header does not exist")
+
 type ConditionOperation struct {
 	Path           string      `json:"path"`             // JSON路径
 	Mode           string      `json:"mode"`             // full, prefix, suffix, contains, gt, gte, lt, lte
@@ -501,6 +503,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
 			}
 			for _, headerName := range headerNames {
 				if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil {
+					if errors.Is(err, errSourceHeaderNotFound) {
+						err = nil
+						continue
+					}
 					break
 				}
 			}
@@ -654,7 +660,7 @@ func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader st
 	}
 	value, exists := getHeaderValueFromContext(context, fromHeader)
 	if !exists {
-		return fmt.Errorf("source header does not exist: %s", fromHeader)
+		return fmt.Errorf("%w: %s", errSourceHeaderNotFound, fromHeader)
 	}
 	return setHeaderOverrideInContext(context, toHeader, value, keepOrigin)
 }

+ 37 - 0
relay/common/override_test.go

@@ -1060,6 +1060,43 @@ func TestApplyParamOverrideCopyHeaderFromRequestHeaders(t *testing.T) {
 	assertJSONEqual(t, `{"temperature":0.1}`, string(out))
 }
 
+func TestApplyParamOverridePassHeadersSkipsMissingHeaders(t *testing.T) {
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode":  "pass_headers",
+				"value": []interface{}{"X-Codex-Beta-Features", "Session_id"},
+			},
+		},
+	}
+	ctx := map[string]interface{}{
+		"request_headers_raw": map[string]interface{}{
+			"Session_id": "sess-123",
+		},
+		"request_headers": map[string]interface{}{
+			"session_id": "sess-123",
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, ctx)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.7}`, string(out))
+
+	headers, ok := ctx["header_override"].(map[string]interface{})
+	if !ok {
+		t.Fatalf("expected header_override context map")
+	}
+	if headers["Session_id"] != "sess-123" {
+		t.Fatalf("expected Session_id to be passed, got: %v", headers["Session_id"])
+	}
+	if _, exists := headers["X-Codex-Beta-Features"]; exists {
+		t.Fatalf("expected missing header to be skipped")
+	}
+}
+
 func TestApplyParamOverrideSyncFieldsHeaderToJSON(t *testing.T) {
 	input := []byte(`{"model":"gpt-4"}`)
 	override := map[string]interface{}{