Преглед изворни кода

fix: fix the proxyURL is empty, not using the default HTTP client configuration && the AWS calling side did not apply the relay timeout.

Seefs пре 2 месеци
родитељ
комит
5f37a1e97c
2 измењених фајлова са 25 додато и 3 уклоњено
  1. 22 3
      relay/channel/aws/relay-aws.go
  2. 3 0
      service/http_client.go

+ 22 - 3
relay/channel/aws/relay-aws.go

@@ -1,11 +1,13 @@
 package aws
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
 	"strings"
+	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
@@ -37,6 +39,13 @@ func getAwsErrorStatusCode(err error) int {
 	return http.StatusInternalServerError
 }
 
+func newAwsInvokeContext() (context.Context, context.CancelFunc) {
+	if common.RelayTimeout <= 0 {
+		return context.Background(), func() {}
+	}
+	return context.WithTimeout(context.Background(), time.Duration(common.RelayTimeout)*time.Second)
+}
+
 func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
 	var (
 		httpClient *http.Client
@@ -117,6 +126,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
 			return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
 		}
 		awsReq.Body = reqBody
+		a.AwsReq = awsReq
 		return nil, nil
 	} else {
 		awsClaudeReq, err := formatRequest(requestBody, requestHeader)
@@ -201,7 +211,10 @@ func getAwsModelID(requestModel string) string {
 
 func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
 
-	awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
+	ctx, cancel := newAwsInvokeContext()
+	defer cancel()
+
+	awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
 	if err != nil {
 		statusCode := getAwsErrorStatusCode(err)
 		return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
@@ -228,7 +241,10 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
 }
 
 func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
-	awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
+	ctx, cancel := newAwsInvokeContext()
+	defer cancel()
+
+	awsResp, err := a.AwsClient.InvokeModelWithResponseStream(ctx, a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
 	if err != nil {
 		statusCode := getAwsErrorStatusCode(err)
 		return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
@@ -268,7 +284,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
 // Nova模型处理函数
 func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
 
-	awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
+	ctx, cancel := newAwsInvokeContext()
+	defer cancel()
+
+	awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
 	if err != nil {
 		statusCode := getAwsErrorStatusCode(err)
 		return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil

+ 3 - 0
service/http_client.go

@@ -82,6 +82,9 @@ func ResetProxyClientCache() {
 // NewProxyHttpClient 创建支持代理的 HTTP 客户端
 func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
 	if proxyURL == "" {
+		if client := GetHttpClient(); client != nil {
+			return client, nil
+		}
 		return http.DefaultClient, nil
 	}