|
@@ -1,11 +1,13 @@
|
|
|
package aws
|
|
package aws
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
|
+ "context"
|
|
|
"encoding/json"
|
|
"encoding/json"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"io"
|
|
"io"
|
|
|
"net/http"
|
|
"net/http"
|
|
|
"strings"
|
|
"strings"
|
|
|
|
|
+ "time"
|
|
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
"github.com/QuantumNous/new-api/dto"
|
|
@@ -37,6 +39,13 @@ func getAwsErrorStatusCode(err error) int {
|
|
|
return http.StatusInternalServerError
|
|
return http.StatusInternalServerError
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func newAwsInvokeContext() (context.Context, context.CancelFunc) {
|
|
|
|
|
+ if common.RelayTimeout <= 0 {
|
|
|
|
|
+ return context.Background(), func() {}
|
|
|
|
|
+ }
|
|
|
|
|
+ return context.WithTimeout(context.Background(), time.Duration(common.RelayTimeout)*time.Second)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
|
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
|
|
var (
|
|
var (
|
|
|
httpClient *http.Client
|
|
httpClient *http.Client
|
|
@@ -117,6 +126,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
|
|
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
|
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
|
|
}
|
|
}
|
|
|
awsReq.Body = reqBody
|
|
awsReq.Body = reqBody
|
|
|
|
|
+ a.AwsReq = awsReq
|
|
|
return nil, nil
|
|
return nil, nil
|
|
|
} else {
|
|
} else {
|
|
|
awsClaudeReq, err := formatRequest(requestBody, requestHeader)
|
|
awsClaudeReq, err := formatRequest(requestBody, requestHeader)
|
|
@@ -201,7 +211,10 @@ func getAwsModelID(requestModel string) string {
|
|
|
|
|
|
|
|
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
|
|
|
|
|
|
- awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
|
|
|
|
|
|
+ ctx, cancel := newAwsInvokeContext()
|
|
|
|
|
+ defer cancel()
|
|
|
|
|
+
|
|
|
|
|
+ awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
statusCode := getAwsErrorStatusCode(err)
|
|
statusCode := getAwsErrorStatusCode(err)
|
|
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
|
@@ -228,7 +241,10 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
|
- awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
|
|
|
|
|
|
+ ctx, cancel := newAwsInvokeContext()
|
|
|
|
|
+ defer cancel()
|
|
|
|
|
+
|
|
|
|
|
+ awsResp, err := a.AwsClient.InvokeModelWithResponseStream(ctx, a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
statusCode := getAwsErrorStatusCode(err)
|
|
statusCode := getAwsErrorStatusCode(err)
|
|
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
|
@@ -268,7 +284,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
|
|
|
// Nova模型处理函数
|
|
// Nova模型处理函数
|
|
|
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
|
|
|
|
|
|
|
- awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
|
|
|
|
|
|
+ ctx, cancel := newAwsInvokeContext()
|
|
|
|
|
+ defer cancel()
|
|
|
|
|
+
|
|
|
|
|
+ awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
statusCode := getAwsErrorStatusCode(err)
|
|
statusCode := getAwsErrorStatusCode(err)
|
|
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
|