|
|
@@ -38,9 +38,46 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// processHeaderOverride 处理请求头覆盖,支持变量替换
|
|
|
-// 支持的变量:{api_key}
|
|
|
-func processHeaderOverride(info *common.RelayInfo) (map[string]string, error) {
|
|
|
+const clientHeaderPlaceholderPrefix = "{client_header:"
|
|
|
+
|
|
|
+func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
|
|
|
+ trimmed := strings.TrimSpace(template)
|
|
|
+ if strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
|
|
|
+ afterPrefix := trimmed[len(clientHeaderPlaceholderPrefix):]
|
|
|
+ end := strings.Index(afterPrefix, "}")
|
|
|
+ if end < 0 || end != len(afterPrefix)-1 {
|
|
|
+ return "", false, fmt.Errorf("client_header placeholder must be the full value: %q", template)
|
|
|
+ }
|
|
|
+
|
|
|
+ name := strings.TrimSpace(afterPrefix[:end])
|
|
|
+ if name == "" {
|
|
|
+ return "", false, fmt.Errorf("client_header placeholder name is empty: %q", template)
|
|
|
+ }
|
|
|
+ if c == nil || c.Request == nil {
|
|
|
+ return "", false, fmt.Errorf("missing request context for client_header placeholder")
|
|
|
+ }
|
|
|
+ clientHeaderValue := c.Request.Header.Get(name)
|
|
|
+ if strings.TrimSpace(clientHeaderValue) == "" {
|
|
|
+ return "", false, nil
|
|
|
+ }
|
|
|
+ // Do not interpolate {api_key} inside client-supplied content.
|
|
|
+ return clientHeaderValue, true, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ if strings.Contains(template, "{api_key}") {
|
|
|
+ template = strings.ReplaceAll(template, "{api_key}", apiKey)
|
|
|
+ }
|
|
|
+ if strings.TrimSpace(template) == "" {
|
|
|
+ return "", false, nil
|
|
|
+ }
|
|
|
+ return template, true, nil
|
|
|
+}
|
|
|
+
|
|
|
+// processHeaderOverride applies channel header overrides, with placeholder substitution.
|
|
|
+// Supported placeholders:
|
|
|
+// - {api_key}: resolved to the channel API key
|
|
|
+// - {client_header:<name>}: resolved to the incoming request header value
|
|
|
+func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
|
|
|
headerOverride := make(map[string]string)
|
|
|
for k, v := range info.HeadersOverride {
|
|
|
str, ok := v.(string)
|
|
|
@@ -48,12 +85,15 @@ func processHeaderOverride(info *common.RelayInfo) (map[string]string, error) {
|
|
|
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
}
|
|
|
|
|
|
- // 替换支持的变量
|
|
|
- if strings.Contains(str, "{api_key}") {
|
|
|
- str = strings.ReplaceAll(str, "{api_key}", info.ApiKey)
|
|
|
+ value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
|
|
|
+ if err != nil {
|
|
|
+ return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
|
|
+ }
|
|
|
+ if !include {
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
- headerOverride[k] = str
|
|
|
+ headerOverride[k] = value
|
|
|
}
|
|
|
return headerOverride, nil
|
|
|
}
|
|
|
@@ -77,7 +117,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|
|
}
|
|
|
// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
|
|
|
// 这样可以覆盖默认的 Authorization header 设置
|
|
|
- headerOverride, err := processHeaderOverride(info)
|
|
|
+ headerOverride, err := processHeaderOverride(info, c)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -112,7 +152,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
|
|
}
|
|
|
// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
|
|
|
// 这样可以覆盖默认的 Authorization header 设置
|
|
|
- headerOverride, err := processHeaderOverride(info)
|
|
|
+ headerOverride, err := processHeaderOverride(info, c)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -138,7 +178,7 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
|
|
}
|
|
|
// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
|
|
|
// 这样可以覆盖默认的 Authorization header 设置
|
|
|
- headerOverride, err := processHeaderOverride(info)
|
|
|
+ headerOverride, err := processHeaderOverride(info, c)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|