Просмотр исходного кода

fix: 支持aws 通过全局参数透传或者渠道参数透传来 调用 (#2423)

* fix: 支持aws 通过全局参数透传或者渠道参数透传来 调用

* fix(aws): replace json.Unmarshal with common.Unmarshal for request body processing

---------

Co-authored-by: r0 <liangchunlei@01.ai>
Co-authored-by: CaIon <i@caion.me>
zdwy5 2 месяцев назад
Родитель
Сommit
e1bee48152
1 измененных файлов с 21 добавлено и 2 удалено
  1. 21 2
      relay/channel/aws/relay-aws.go

+ 21 - 2
relay/channel/aws/relay-aws.go

@@ -18,6 +18,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
 
+	"github.com/QuantumNous/new-api/setting/model_setting"
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/credentials"
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
@@ -129,7 +130,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 				Accept:      aws.String("application/json"),
 				ContentType: aws.String("application/json"),
 			}
-			awsReq.Body, err = common.Marshal(awsClaudeReq)
+			awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
 			if err != nil {
 				return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
 			}
@@ -141,7 +142,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 				Accept:      aws.String("application/json"),
 				ContentType: aws.String("application/json"),
 			}
-			awsReq.Body, err = common.Marshal(awsClaudeReq)
+			awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
 			if err != nil {
 				return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
 			}
@@ -151,6 +152,24 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 	}
 }
 
+// buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
+func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
+		body, err := common.GetRequestBody(c)
+		if err != nil {
+			return nil, errors.Wrap(err, "get request body for pass-through fail")
+		}
+		var data map[string]interface{}
+		if err := common.Unmarshal(body, &data); err != nil {
+			return nil, errors.Wrap(err, "pass-through unmarshal request body fail")
+		}
+		delete(data, "model")
+		delete(data, "stream")
+		return common.Marshal(data)
+	}
+	return common.Marshal(awsClaudeReq)
+}
+
 func getAwsRegionPrefix(awsRegionId string) string {
 	parts := strings.Split(awsRegionId, "-")
 	regionPrefix := ""