Просмотр исходного кода

Merge pull request #2738 from Li-Xingyu/main

Seefs 1 месяц назад
Родитель
Сommit
df43193600
2 измененных файлов с 35 добавлено и 14 удалено
  1. 18 12
      relay/channel/api_request.go
  2. 17 2
      relay/channel/openai/adaptor.go

+ 18 - 12
relay/channel/api_request.go

@@ -71,6 +71,12 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 		return nil, fmt.Errorf("new request failed: %w", err)
 	}
 	headers := req.Header
+	err = a.SetupRequestHeader(c, &headers, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
+	// 这样可以覆盖默认的 Authorization header 设置
 	headerOverride, err := processHeaderOverride(info)
 	if err != nil {
 		return nil, err
@@ -78,10 +84,6 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	for key, value := range headerOverride {
 		headers.Set(key, value)
 	}
-	err = a.SetupRequestHeader(c, &headers, info)
-	if err != nil {
-		return nil, fmt.Errorf("setup request header failed: %w", err)
-	}
 	resp, err := doRequest(c, req, info)
 	if err != nil {
 		return nil, fmt.Errorf("do request failed: %w", err)
@@ -104,6 +106,12 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 	// set form data
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	headers := req.Header
+	err = a.SetupRequestHeader(c, &headers, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
+	// 这样可以覆盖默认的 Authorization header 设置
 	headerOverride, err := processHeaderOverride(info)
 	if err != nil {
 		return nil, err
@@ -111,10 +119,6 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 	for key, value := range headerOverride {
 		headers.Set(key, value)
 	}
-	err = a.SetupRequestHeader(c, &headers, info)
-	if err != nil {
-		return nil, fmt.Errorf("setup request header failed: %w", err)
-	}
 	resp, err := doRequest(c, req, info)
 	if err != nil {
 		return nil, fmt.Errorf("do request failed: %w", err)
@@ -128,6 +132,12 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 		return nil, fmt.Errorf("get request url failed: %w", err)
 	}
 	targetHeader := http.Header{}
+	err = a.SetupRequestHeader(c, &targetHeader, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	// 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
+	// 这样可以覆盖默认的 Authorization header 设置
 	headerOverride, err := processHeaderOverride(info)
 	if err != nil {
 		return nil, err
@@ -135,10 +145,6 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	for key, value := range headerOverride {
 		targetHeader.Set(key, value)
 	}
-	err = a.SetupRequestHeader(c, &targetHeader, info)
-	if err != nil {
-		return nil, fmt.Errorf("setup request header failed: %w", err)
-	}
 	targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
 	if err != nil {

+ 17 - 2
relay/channel/openai/adaptor.go

@@ -187,6 +187,17 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
 	if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
 		header.Set("OpenAI-Organization", info.Organization)
 	}
+	// 检查 Header Override 是否已设置 Authorization,如果已设置则跳过默认设置
+	// 这样可以避免在 Header Override 应用时被覆盖(虽然 Header Override 会在之后应用,但这里作为额外保护)
+	hasAuthOverride := false
+	if len(info.HeadersOverride) > 0 {
+		for k := range info.HeadersOverride {
+			if strings.EqualFold(k, "Authorization") {
+				hasAuthOverride = true
+				break
+			}
+		}
+	}
 	if info.RelayMode == relayconstant.RelayModeRealtime {
 		swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
 		if swp != "" {
@@ -201,10 +212,14 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
 			//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
 		} else {
 			header.Set("openai-beta", "realtime=v1")
-			header.Set("Authorization", "Bearer "+info.ApiKey)
+			if !hasAuthOverride {
+				header.Set("Authorization", "Bearer "+info.ApiKey)
+			}
 		}
 	} else {
-		header.Set("Authorization", "Bearer "+info.ApiKey)
+		if !hasAuthOverride {
+			header.Set("Authorization", "Bearer "+info.ApiKey)
+		}
 	}
 	if info.ChannelType == constant.ChannelTypeOpenRouter {
 		header.Set("HTTP-Referer", "https://www.newapi.ai")