|
|
@@ -165,6 +165,30 @@ func GetAllChannels(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+func buildFetchModelsHeaders(channel *model.Channel, key string) (http.Header, error) {
|
|
|
+ var headers http.Header
|
|
|
+ switch channel.Type {
|
|
|
+ case constant.ChannelTypeAnthropic:
|
|
|
+ headers = GetClaudeAuthHeader(key)
|
|
|
+ default:
|
|
|
+ headers = GetAuthHeader(key)
|
|
|
+ }
|
|
|
+
|
|
|
+ headerOverride := channel.GetHeaderOverride()
|
|
|
+ for k, v := range headerOverride {
|
|
|
+ str, ok := v.(string)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("invalid header override for key %s", k)
|
|
|
+ }
|
|
|
+ if strings.Contains(str, "{api_key}") {
|
|
|
+ str = strings.ReplaceAll(str, "{api_key}", key)
|
|
|
+ }
|
|
|
+ headers.Set(k, str)
|
|
|
+ }
|
|
|
+
|
|
|
+ return headers, nil
|
|
|
+}
|
|
|
+
|
|
|
func FetchUpstreamModels(c *gin.Context) {
|
|
|
id, err := strconv.Atoi(c.Param("id"))
|
|
|
if err != nil {
|
|
|
@@ -223,14 +247,13 @@ func FetchUpstreamModels(c *gin.Context) {
|
|
|
}
|
|
|
key = strings.TrimSpace(key)
|
|
|
|
|
|
- // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
|
|
- var body []byte
|
|
|
- switch channel.Type {
|
|
|
- case constant.ChannelTypeAnthropic:
|
|
|
- body, err = GetResponseBody("GET", url, channel, GetClaudeAuthHeader(key))
|
|
|
- default:
|
|
|
- body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
|
|
|
+ headers, err := buildFetchModelsHeaders(channel, key)
|
|
|
+ if err != nil {
|
|
|
+ common.ApiError(c, err)
|
|
|
+ return
|
|
|
}
|
|
|
+
|
|
|
+ body, err := GetResponseBody("GET", url, channel, headers)
|
|
|
if err != nil {
|
|
|
common.ApiError(c, err)
|
|
|
return
|