Quellcode durchsuchen

Merge pull request #2735 from seefs001/feature/header-throughpass

feat: header passthrough
Seefs vor 1 Monat
Ursprung
Commit
ac8f17c827
1 geänderte Dateien mit 50 neuen und 10 gelöschten Zeilen
  1. 50 10
      relay/channel/api_request.go

+ 50 - 10
relay/channel/api_request.go

@@ -38,9 +38,46 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea
 	}
 }
 
-// processHeaderOverride 处理请求头覆盖,支持变量替换
-// 支持的变量:{api_key}
-func processHeaderOverride(info *common.RelayInfo) (map[string]string, error) {
+const clientHeaderPlaceholderPrefix = "{client_header:"
+
+func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
+	trimmed := strings.TrimSpace(template)
+	if strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
+		afterPrefix := trimmed[len(clientHeaderPlaceholderPrefix):]
+		end := strings.Index(afterPrefix, "}")
+		if end < 0 || end != len(afterPrefix)-1 {
+			return "", false, fmt.Errorf("client_header placeholder must be the full value: %q", template)
+		}
+
+		name := strings.TrimSpace(afterPrefix[:end])
+		if name == "" {
+			return "", false, fmt.Errorf("client_header placeholder name is empty: %q", template)
+		}
+		if c == nil || c.Request == nil {
+			return "", false, fmt.Errorf("missing request context for client_header placeholder")
+		}
+		clientHeaderValue := c.Request.Header.Get(name)
+		if strings.TrimSpace(clientHeaderValue) == "" {
+			return "", false, nil
+		}
+		// Do not interpolate {api_key} inside client-supplied content.
+		return clientHeaderValue, true, nil
+	}
+
+	if strings.Contains(template, "{api_key}") {
+		template = strings.ReplaceAll(template, "{api_key}", apiKey)
+	}
+	if strings.TrimSpace(template) == "" {
+		return "", false, nil
+	}
+	return template, true, nil
+}
+
+// processHeaderOverride applies channel header overrides, with placeholder substitution.
+// Supported placeholders:
+//   - {api_key}: resolved to the channel API key
+//   - {client_header:<name>}: resolved to the incoming request header value
+func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
 	headerOverride := make(map[string]string)
 	for k, v := range info.HeadersOverride {
 		str, ok := v.(string)
@@ -48,12 +85,15 @@ func processHeaderOverride(info *common.RelayInfo) (map[string]string, error) {
 			return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
 		}
 
-		// 替换支持的变量
-		if strings.Contains(str, "{api_key}") {
-			str = strings.ReplaceAll(str, "{api_key}", info.ApiKey)
+		value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
+		if err != nil {
+			return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
+		}
+		if !include {
+			continue
 		}
 
-		headerOverride[k] = str
+		headerOverride[k] = value
 	}
 	return headerOverride, nil
 }
@@ -77,7 +117,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	}
 	// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
 	// 这样可以覆盖默认的 Authorization header 设置
-	headerOverride, err := processHeaderOverride(info)
+	headerOverride, err := processHeaderOverride(info, c)
 	if err != nil {
 		return nil, err
 	}
@@ -112,7 +152,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 	}
 	// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
 	// 这样可以覆盖默认的 Authorization header 设置
-	headerOverride, err := processHeaderOverride(info)
+	headerOverride, err := processHeaderOverride(info, c)
 	if err != nil {
 		return nil, err
 	}
@@ -138,7 +178,7 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	}
 	// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
 	// 这样可以覆盖默认的 Authorization header 设置
-	headerOverride, err := processHeaderOverride(info)
+	headerOverride, err := processHeaderOverride(info, c)
 	if err != nil {
 		return nil, err
 	}