|
|
@@ -6,6 +6,7 @@ import (
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
+ "regexp"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
@@ -40,6 +41,86 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea
|
|
|
|
|
|
const clientHeaderPlaceholderPrefix = "{client_header:"
|
|
|
|
|
|
+const (
|
|
|
+ headerPassthroughAllKey = "*"
|
|
|
+ headerPassthroughRegexPrefix = "re:"
|
|
|
+ headerPassthroughRegexPrefixV2 = "regex:"
|
|
|
+)
|
|
|
+
|
|
|
+var passthroughSkipHeaderNamesLower = map[string]struct{}{
|
|
|
+ // RFC 7230 hop-by-hop headers.
|
|
|
+ "connection": {},
|
|
|
+ "keep-alive": {},
|
|
|
+ "proxy-authenticate": {},
|
|
|
+ "proxy-authorization": {},
|
|
|
+ "te": {},
|
|
|
+ "trailer": {},
|
|
|
+ "transfer-encoding": {},
|
|
|
+ "upgrade": {},
|
|
|
+
|
|
|
+ // Additional headers that should not be forwarded by name-matching passthrough rules.
|
|
|
+ "host": {},
|
|
|
+ "content-length": {},
|
|
|
+
|
|
|
+ // Do not passthrough credentials by wildcard/regex.
|
|
|
+ "authorization": {},
|
|
|
+ "x-api-key": {},
|
|
|
+ "x-goog-api-key": {},
|
|
|
+
|
|
|
+ // WebSocket handshake headers are generated by the client/dialer.
|
|
|
+ "sec-websocket-key": {},
|
|
|
+ "sec-websocket-version": {},
|
|
|
+ "sec-websocket-extensions": {},
|
|
|
+}
|
|
|
+
|
|
|
+var headerPassthroughRegexCache sync.Map // map[string]*regexp.Regexp
|
|
|
+
|
|
|
+func getHeaderPassthroughRegex(pattern string) (*regexp.Regexp, error) {
|
|
|
+ pattern = strings.TrimSpace(pattern)
|
|
|
+ if pattern == "" {
|
|
|
+ return nil, errors.New("empty regex pattern")
|
|
|
+ }
|
|
|
+ if v, ok := headerPassthroughRegexCache.Load(pattern); ok {
|
|
|
+ if re, ok := v.(*regexp.Regexp); ok {
|
|
|
+ return re, nil
|
|
|
+ }
|
|
|
+ headerPassthroughRegexCache.Delete(pattern)
|
|
|
+ }
|
|
|
+ compiled, err := regexp.Compile(pattern)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ actual, _ := headerPassthroughRegexCache.LoadOrStore(pattern, compiled)
|
|
|
+ if re, ok := actual.(*regexp.Regexp); ok {
|
|
|
+ return re, nil
|
|
|
+ }
|
|
|
+ return compiled, nil
|
|
|
+}
|
|
|
+
|
|
|
+func isHeaderPassthroughRuleKey(key string) bool {
|
|
|
+ key = strings.TrimSpace(key)
|
|
|
+ if key == "" {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if key == headerPassthroughAllKey {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ lower := strings.ToLower(key)
|
|
|
+ return strings.HasPrefix(lower, headerPassthroughRegexPrefix) || strings.HasPrefix(lower, headerPassthroughRegexPrefixV2)
|
|
|
+}
|
|
|
+
|
|
|
+func shouldSkipPassthroughHeader(name string) bool {
|
|
|
+ name = strings.TrimSpace(name)
|
|
|
+ if name == "" {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ lower := strings.ToLower(name)
|
|
|
+ if _, ok := passthroughSkipHeaderNamesLower[lower]; ok {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
|
|
|
trimmed := strings.TrimSpace(template)
|
|
|
if strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
|
|
|
@@ -77,9 +158,85 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
|
|
|
// Supported placeholders:
|
|
|
// - {api_key}: resolved to the channel API key
|
|
|
// - {client_header:<name>}: resolved to the incoming request header value
|
|
|
+//
|
|
|
+// Header passthrough rules (keys only; values are ignored):
|
|
|
+// - "*": passthrough all incoming headers by name (excluding unsafe headers)
|
|
|
+// - "re:<regex>" / "regex:<regex>": passthrough headers whose names match the regex (Go regexp)
|
|
|
+//
|
|
|
+// Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
|
|
|
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
|
|
|
headerOverride := make(map[string]string)
|
|
|
+
|
|
|
+ passAll := false
|
|
|
+ var passthroughRegex []*regexp.Regexp
|
|
|
+ for k := range info.HeadersOverride {
|
|
|
+ key := strings.TrimSpace(k)
|
|
|
+ if key == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if key == headerPassthroughAllKey {
|
|
|
+ passAll = true
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ lower := strings.ToLower(key)
|
|
|
+ var pattern string
|
|
|
+ switch {
|
|
|
+ case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
|
|
|
+ pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
|
|
+ case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
|
|
|
+ pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
|
|
+ default:
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if pattern == "" {
|
|
|
+ return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
+ }
|
|
|
+ compiled, err := getHeaderPassthroughRegex(pattern)
|
|
|
+ if err != nil {
|
|
|
+ return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
+ }
|
|
|
+ passthroughRegex = append(passthroughRegex, compiled)
|
|
|
+ }
|
|
|
+
|
|
|
+ if passAll || len(passthroughRegex) > 0 {
|
|
|
+ if c == nil || c.Request == nil {
|
|
|
+ return nil, types.NewError(fmt.Errorf("missing request context for header passthrough"), types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
+ }
|
|
|
+ for name := range c.Request.Header {
|
|
|
+ if shouldSkipPassthroughHeader(name) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if !passAll {
|
|
|
+ matched := false
|
|
|
+ for _, re := range passthroughRegex {
|
|
|
+ if re.MatchString(name) {
|
|
|
+ matched = true
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if !matched {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+ value := strings.TrimSpace(c.Request.Header.Get(name))
|
|
|
+ if value == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ headerOverride[name] = value
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
for k, v := range info.HeadersOverride {
|
|
|
+ if isHeaderPassthroughRuleKey(k) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ key := strings.TrimSpace(k)
|
|
|
+ if key == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
str, ok := v.(string)
|
|
|
if !ok {
|
|
|
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
@@ -93,7 +250,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- headerOverride[k] = value
|
|
|
+ headerOverride[key] = value
|
|
|
}
|
|
|
return headerOverride, nil
|
|
|
}
|