Просмотр исходного кода

fix(aws): extract HTTP status code from AWS SDK errors

jason.mei 3 месяцев назад
Родитель
Сommit
f2e51963dc
1 измененных файлов с 21 добавлено и 3 удалено
  1. 21 3
      relay/channel/aws/relay-aws.go

+ 21 - 3
relay/channel/aws/relay-aws.go

@@ -22,9 +22,24 @@ import (
 	"github.com/aws/aws-sdk-go-v2/credentials"
 	"github.com/aws/aws-sdk-go-v2/credentials"
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
 	bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
 	bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
+	"github.com/aws/smithy-go"
 	"github.com/aws/smithy-go/auth/bearer"
 	"github.com/aws/smithy-go/auth/bearer"
 )
 )
 
 
+// getAwsErrorStatusCode extracts HTTP status code from AWS SDK error
+func getAwsErrorStatusCode(err error) int {
+	var apiErr smithy.APIError
+	if errors.As(err, &apiErr) {
+		// Check for HTTP response error which contains status code
+		var httpErr interface{ HTTPStatusCode() int }
+		if errors.As(err, &httpErr) {
+			return httpErr.HTTPStatusCode()
+		}
+	}
+	// Default to 500 if we can't determine the status code
+	return http.StatusInternalServerError
+}
+
 func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
 func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
 	var (
 	var (
 		httpClient *http.Client
 		httpClient *http.Client
@@ -173,7 +188,8 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
 
 
 	awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
 	awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
 	if err != nil {
 	if err != nil {
-		return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
+		statusCode := getAwsErrorStatusCode(err)
+		return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
 	}
 	}
 
 
 	claudeInfo := &claude.ClaudeResponseInfo{
 	claudeInfo := &claude.ClaudeResponseInfo{
@@ -199,7 +215,8 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
 func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
 func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
 	awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
 	awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
 	if err != nil {
 	if err != nil {
-		return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
+		statusCode := getAwsErrorStatusCode(err)
+		return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
 	}
 	}
 	stream := awsResp.GetStream()
 	stream := awsResp.GetStream()
 	defer stream.Close()
 	defer stream.Close()
@@ -238,7 +255,8 @@ func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor)
 
 
 	awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
 	awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
 	if err != nil {
 	if err != nil {
-		return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
+		statusCode := getAwsErrorStatusCode(err)
+		return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
 	}
 	}
 
 
 	// 解析Nova响应
 	// 解析Nova响应