Просмотр исходного кода

feat: support aws stream_options

CalciumIon 1 год назад
Родитель
Сommit
9896ba0a64
3 измененных файлов с 14 добавлено и 5 удалено
  1. 1 1
      relay/channel/aws/adaptor.go
  2. 12 3
      relay/channel/aws/relay-aws.go
  3. 1 1
      relay/channel/claude/relay-claude.go

+ 1 - 1
relay/channel/aws/adaptor.go

@@ -68,7 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
-		err, usage = awsStreamHandler(c, info, a.RequestMode)
+		err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
 	} else {
 	} else {
 		err, usage = awsHandler(c, info, a.RequestMode)
 		err, usage = awsHandler(c, info, a.RequestMode)
 	}
 	}

+ 12 - 3
relay/channel/aws/relay-aws.go

@@ -13,6 +13,7 @@ import (
 	relaymodel "one-api/dto"
 	relaymodel "one-api/dto"
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
@@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 	return nil, &usage
 	return nil, &usage
 }
 }
 
 
-func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
+func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
 	awsCli, err := newAwsClient(c, info)
 	awsCli, err := newAwsClient(c, info)
 	if err != nil {
 	if err != nil {
 		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
 		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
 	c.Stream(func(w io.Writer) bool {
 	c.Stream(func(w io.Writer) bool {
 		event, ok := <-stream.Events()
 		event, ok := <-stream.Events()
 		if !ok {
 		if !ok {
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 			return false
 		}
 		}
 
 
@@ -214,6 +214,15 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
 			return false
 			return false
 		}
 		}
 	})
 	})
-
+	response := service.GenerateFinalUsageResponse(id, createdTime, model, usage)
+	err = service.ObjectData(c, response)
+	if err != nil {
+		common.SysError("send final response failed: " + err.Error())
+	}
+	service.Done(c)
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
 	return nil, &usage
 	return nil, &usage
 }
 }

+ 1 - 1
relay/channel/claude/relay-claude.go

@@ -352,7 +352,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
 	response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
 	err := service.ObjectData(c, response)
 	err := service.ObjectData(c, response)
 	if err != nil {
 	if err != nil {
-		common.SysError(err.Error())
+		common.SysError("send final response failed: " + err.Error())
 	}
 	}
 	service.Done(c)
 	service.Done(c)
 	err = resp.Body.Close()
 	err = resp.Body.Close()