Przeglądaj źródła

feat: 完善stream_options

CalciumIon 1 rok temu
rodzic
commit
df6502733c

+ 6 - 4
relay/channel/aws/relay-aws.go

@@ -214,10 +214,12 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 			return false
 			return false
 		}
 		}
 	})
 	})
-	response := service.GenerateFinalUsageResponse(id, createdTime, model, usage)
-	err = service.ObjectData(c, response)
-	if err != nil {
-		common.SysError("send final response failed: " + err.Error())
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
 	}
 	}
 	service.Done(c)
 	service.Done(c)
 	err = resp.Body.Close()
 	err = resp.Body.Close()

+ 7 - 5
relay/channel/claude/relay-claude.go

@@ -349,13 +349,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
 			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
 		}
 		}
 	}
 	}
-	response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
-	err := service.ObjectData(c, response)
-	if err != nil {
-		common.SysError("send final response failed: " + err.Error())
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
 	}
 	}
 	service.Done(c)
 	service.Done(c)
-	err = resp.Body.Close()
+	err := resp.Body.Close()
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}

+ 1 - 15
relay/channel/openai/adaptor.go

@@ -7,7 +7,6 @@ import (
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ai360"
 	"one-api/relay/channel/ai360"
@@ -20,8 +19,7 @@ import (
 )
 )
 
 
 type Adaptor struct {
 type Adaptor struct {
-	ChannelType          int
-	SupportStreamOptions bool
+	ChannelType int
 }
 }
 
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -33,7 +31,6 @@ func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequ
 
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	a.ChannelType = info.ChannelType
 	a.ChannelType = info.ChannelType
-	a.SupportStreamOptions = info.SupportStreamOptions
 }
 }
 
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -81,17 +78,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 	if request == nil {
 		return nil, errors.New("request is nil")
 		return nil, errors.New("request is nil")
 	}
 	}
-	// 如果不支持StreamOptions,将StreamOptions设置为nil
-	if !a.SupportStreamOptions || !request.Stream {
-		request.StreamOptions = nil
-	} else {
-		// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
-		if constant.ForceStreamOption {
-			request.StreamOptions = &dto.StreamOptions{
-				IncludeUsage: true,
-			}
-		}
-	}
 	return request, nil
 	return request, nil
 }
 }
 
 

+ 1 - 0
relay/common/relay_info.go

@@ -28,6 +28,7 @@ type RelayInfo struct {
 	Organization         string
 	Organization         string
 	BaseUrl              string
 	BaseUrl              string
 	SupportStreamOptions bool
 	SupportStreamOptions bool
+	ShouldIncludeUsage   bool
 }
 }
 
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 func GenRelayInfo(c *gin.Context) *RelayInfo {

+ 16 - 0
relay/relay-text.go

@@ -130,6 +130,22 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return openaiErr
 		return openaiErr
 	}
 	}
 
 
+	// 如果不支持StreamOptions,将StreamOptions设置为nil
+	if !relayInfo.SupportStreamOptions || !textRequest.Stream {
+		textRequest.StreamOptions = nil
+	} else {
+		// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
+		if constant.ForceStreamOption {
+			textRequest.StreamOptions = &dto.StreamOptions{
+				IncludeUsage: true,
+			}
+		}
+	}
+
+	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
+		relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
+	}
+
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)