Explorar el Código

feat(aws): Add support for anthropic-beta header in AwsClaudeRequest

CaIon hace 3 meses
padre
commit
e1a52f1d5a
Se han modificado 2 ficheros con 19 adiciones y 19 borrados
  1. 14 17
      relay/channel/aws/dto.go
  2. 5 2
      relay/channel/aws/relay-aws.go

+ 14 - 17
relay/channel/aws/dto.go

@@ -1,7 +1,9 @@
 package aws
 
 import (
+	"encoding/json"
 	"io"
+	"net/http"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
@@ -10,6 +12,7 @@ import (
 type AwsClaudeRequest struct {
 	// AnthropicVersion should be "bedrock-2023-05-31"
 	AnthropicVersion string              `json:"anthropic_version"`
+	AnthropicBeta    json.RawMessage     `json:"anthropic_beta"`
 	System           any                 `json:"system,omitempty"`
 	Messages         []dto.ClaudeMessage `json:"messages"`
 	MaxTokens        uint                `json:"max_tokens,omitempty"`
@@ -22,29 +25,23 @@ type AwsClaudeRequest struct {
 	Thinking         *dto.Thinking       `json:"thinking,omitempty"`
 }
 
-func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
-	return &AwsClaudeRequest{
-		AnthropicVersion: "bedrock-2023-05-31",
-		System:           req.System,
-		Messages:         req.Messages,
-		MaxTokens:        req.MaxTokens,
-		Temperature:      req.Temperature,
-		TopP:             req.TopP,
-		TopK:             req.TopK,
-		StopSequences:    req.StopSequences,
-		Tools:            req.Tools,
-		ToolChoice:       req.ToolChoice,
-		Thinking:         req.Thinking,
-	}
-}
-
-func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
+func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {
 	var awsClaudeRequest AwsClaudeRequest
 	err := common.DecodeJson(requestBody, &awsClaudeRequest)
 	if err != nil {
 		return nil, err
 	}
 	awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
+
+	// check header anthropic-beta
+	anthropicBetaValues := requestHeader.Values("anthropic-beta")
+	if len(anthropicBetaValues) > 0 {
+		betaJson, err := json.Marshal(anthropicBetaValues)
+		if err != nil {
+			return nil, err
+		}
+		awsClaudeRequest.AnthropicBeta = json.RawMessage(betaJson)
+	}
 	return &awsClaudeRequest, nil
 }
 

+ 5 - 2
relay/channel/aws/relay-aws.go

@@ -73,7 +73,6 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 	}
 	a.AwsClient = awsCli
 
-	println(info.UpstreamModelName)
 	// 获取对应的AWS模型ID
 	awsModelId := getAwsModelID(info.UpstreamModelName)
 
@@ -83,6 +82,10 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 		awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
 	}
 
+	// init empty request.header
+	requestHeader := http.Header{}
+	a.SetupRequestHeader(c, &requestHeader, info)
+
 	if isNovaModel(awsModelId) {
 		var novaReq *NovaRequest
 		err = common.DecodeJson(requestBody, &novaReq)
@@ -104,7 +107,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 		awsReq.Body = reqBody
 		return nil, nil
 	} else {
-		awsClaudeReq, err := formatRequest(requestBody)
+		awsClaudeReq, err := formatRequest(requestBody, requestHeader)
 		if err != nil {
 			return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
 		}