瀏覽代碼

feats:add custom headers override

Nekohy 6 月之前
父節點
當前提交
abcb353793

+ 1 - 0
constant/context_key.go

@@ -27,6 +27,7 @@ const (
 	ContextKeyChannelSetting           ContextKey = "channel_setting"
 	ContextKeyChannelOtherSetting      ContextKey = "channel_other_setting"
 	ContextKeyChannelParamOverride     ContextKey = "param_override"
+	ContextKeyChannelHeaderOverride    ContextKey = "header_override"
 	ContextKeyChannelOrganization      ContextKey = "channel_organization"
 	ContextKeyChannelAutoBan           ContextKey = "auto_ban"
 	ContextKeyChannelModelMapping      ContextKey = "model_mapping"

+ 1 - 0
middleware/distributor.go

@@ -248,6 +248,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
 	common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
 	common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
+	common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride())
 	if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
 		common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
 	}

+ 12 - 0
model/channel.go

@@ -46,6 +46,7 @@ type Channel struct {
 	Tag               *string `json:"tag" gorm:"index"`
 	Setting           *string `json:"setting" gorm:"type:text"` // 渠道额外设置
 	ParamOverride     *string `json:"param_override" gorm:"type:text"`
+	HeaderOverride    *string `json:"header_override" gorm:"type:text"`
 	// add after v0.8.5
 	ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
 
@@ -875,6 +876,17 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
 	return paramOverride
 }
 
+func (channel *Channel) GetHeaderOverride() map[string]interface{} {
+	headerOverride := make(map[string]interface{})
+	if channel.HeaderOverride != nil && *channel.HeaderOverride != "" {
+		err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride)
+		if err != nil {
+			common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err))
+		}
+	}
+	return headerOverride
+}
+
 func GetChannelsByIds(ids []int) ([]*Channel, error) {
 	var channels []*Channel
 	err := DB.Where("id in (?)", ids).Find(&channels).Error

+ 27 - 3
relay/channel/api_request.go

@@ -13,6 +13,7 @@ import (
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting/operation_setting"
+	"one-api/types"
 	"sync"
 	"time"
 
@@ -47,7 +48,19 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	if err != nil {
 		return nil, fmt.Errorf("new request failed: %w", err)
 	}
-	err = a.SetupRequestHeader(c, &req.Header, info)
+	headers := req.Header
+	headerOverride := make(map[string]string)
+	for k, v := range info.HeadersOverride {
+		if str, ok := v.(string); ok {
+			headerOverride[k] = str
+		} else {
+			return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
+		}
+	}
+	for key, value := range headerOverride {
+		headers.Set(key, value)
+	}
+	err = a.SetupRequestHeader(c, &headers, info)
 	if err != nil {
 		return nil, fmt.Errorf("setup request header failed: %w", err)
 	}
@@ -72,8 +85,19 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 	}
 	// set form data
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-
-	err = a.SetupRequestHeader(c, &req.Header, info)
+	headers := req.Header
+	headerOverride := make(map[string]string)
+	for k, v := range info.HeadersOverride {
+		if str, ok := v.(string); ok {
+			headerOverride[k] = str
+		} else {
+			return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
+		}
+	}
+	for key, value := range headerOverride {
+		headers.Set(key, value)
+	}
+	err = a.SetupRequestHeader(c, &headers, info)
 	if err != nil {
 		return nil, fmt.Errorf("setup request header failed: %w", err)
 	}

+ 3 - 0
relay/common/relay_info.go

@@ -63,6 +63,7 @@ type ChannelMeta struct {
 	Organization         string
 	ChannelCreateTime    int64
 	ParamOverride        map[string]interface{}
+	HeadersOverride      map[string]interface{}
 	ChannelSetting       dto.ChannelSettings
 	ChannelOtherSettings dto.ChannelOtherSettings
 	UpstreamModelName    string
@@ -120,6 +121,7 @@ type RelayInfo struct {
 func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
 	channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
 	paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
+	headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
 	apiType, _ := common.ChannelType2APIType(channelType)
 	channelMeta := &ChannelMeta{
 		ChannelType:          channelType,
@@ -133,6 +135,7 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
 		Organization:         c.GetString("channel_organization"),
 		ChannelCreateTime:    c.GetInt64("channel_create_time"),
 		ParamOverride:        paramOverride,
+		HeadersOverride:      headerOverride,
 		UpstreamModelName:    common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
 		IsModelMapped:        false,
 		SupportStreamOptions: false,

+ 7 - 6
types/error.go

@@ -48,12 +48,13 @@ const (
 	ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"
 
 	// channel error
-	ErrorCodeChannelNoAvailableKey       ErrorCode = "channel:no_available_key"
-	ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid"
-	ErrorCodeChannelModelMappedError     ErrorCode = "channel:model_mapped_error"
-	ErrorCodeChannelAwsClientError       ErrorCode = "channel:aws_client_error"
-	ErrorCodeChannelInvalidKey           ErrorCode = "channel:invalid_key"
-	ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
+	ErrorCodeChannelNoAvailableKey        ErrorCode = "channel:no_available_key"
+	ErrorCodeChannelParamOverrideInvalid  ErrorCode = "channel:param_override_invalid"
+	ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid"
+	ErrorCodeChannelModelMappedError      ErrorCode = "channel:model_mapped_error"
+	ErrorCodeChannelAwsClientError        ErrorCode = "channel:aws_client_error"
+	ErrorCodeChannelInvalidKey            ErrorCode = "channel:invalid_key"
+	ErrorCodeChannelResponseTimeExceeded  ErrorCode = "channel:response_time_exceeded"
 
 	// client request error
 	ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"

+ 25 - 0
web/src/components/table/channels/modals/EditChannelModal.jsx

@@ -1699,6 +1699,31 @@ const EditChannelModal = (props) => {
                     showClear
                   />
 
+                  <Form.TextArea
+                      field='header_override'
+                      label={t('请求头覆盖')}
+                      placeholder={
+                          t('此项可选,用于覆盖请求头参数') +
+                          '\n' + t('格式示例:') +
+                          '\n{\n  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0"\n}'
+                      }
+                      autosize
+                      onChange={(value) => handleInputChange('header_override', value)}
+                      extraText={
+                        <div className="flex gap-2 flex-wrap">
+                          <Text
+                              className="!text-semi-color-primary cursor-pointer"
+                              onClick={() => handleInputChange('header_override', JSON.stringify({
+                                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0"
+                              }, null, 2))}
+                          >
+                            {t('格式模板')}
+                          </Text>
+                        </div>
+                      }
+                      showClear
+                  />
+
 
                   <JSONEditor
                     key={`status_code_mapping-${isEdit ? channelId : 'new'}`}