Explorar o código

fix: azure stream options

CalciumIon hai 1 ano
pai
achega
0f687aab9a

+ 1 - 1
controller/channel-test.go

@@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 
 	adaptor.Init(meta, *request)
 
-	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
+	convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
 	if err != nil {
 		return err, nil
 	}

+ 1 - 1
relay/channel/adapter.go

@@ -14,7 +14,7 @@ type Adaptor interface {
 	InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
-	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
+	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
 	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)

+ 2 - 2
relay/channel/ali/adaptor.go

@@ -42,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 		baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
 		return baiduEmbeddingRequest, nil

+ 1 - 1
relay/channel/aws/adaptor.go

@@ -41,7 +41,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 2 - 2
relay/channel/baidu/adaptor.go

@@ -99,11 +99,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 		baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
 		return baiduEmbeddingRequest, nil

+ 1 - 1
relay/channel/claude/adaptor.go

@@ -53,7 +53,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 2 - 2
relay/channel/cloudflare/adaptor.go

@@ -38,11 +38,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case constant.RelayModeCompletions:
 		return convertCf2CompletionsRequest(*request), nil
 	default:

+ 6 - 0
relay/channel/cloudflare/relay_cloudflare.go

@@ -11,6 +11,7 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+	"time"
 )
 
 func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
@@ -30,6 +31,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	service.SetEventStreamHeaders(c)
 	id := service.GetResponseID(c)
 	var responseText string
+	isFirst := true
 
 	for scanner.Scan() {
 		data := scanner.Text()
@@ -56,6 +58,10 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 		response.Id = id
 		response.Model = info.UpstreamModelName
 		err = service.ObjectData(c, response)
+		if isFirst {
+			isFirst = false
+			info.FirstResponseTime = time.Now()
+		}
 		if err != nil {
 			common.LogError(c, "error_rendering_stream_response: "+err.Error())
 		}

+ 1 - 1
relay/channel/cohere/adaptor.go

@@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	return requestOpenAI2Cohere(*request), nil
 }
 

+ 1 - 1
relay/channel/dify/adaptor.go

@@ -32,7 +32,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 1 - 1
relay/channel/gemini/adaptor.go

@@ -51,7 +51,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 1 - 1
relay/channel/jina/adaptor.go

@@ -36,7 +36,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	return request, nil
 }
 

+ 2 - 2
relay/channel/ollama/adaptor.go

@@ -36,11 +36,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch relayMode {
+	switch info.RelayMode {
 	case relayconstant.RelayModeEmbeddings:
 		return requestOpenAI2Embeddings(*request), nil
 	default:

+ 4 - 1
relay/channel/openai/adaptor.go

@@ -74,10 +74,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	if info.ChannelType != common.ChannelTypeOpenAI {
+		request.StreamOptions = nil
+	}
 	return request, nil
 }
 

+ 1 - 1
relay/channel/palm/adaptor.go

@@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 1 - 1
relay/channel/perplexity/adaptor.go

@@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 1 - 1
relay/channel/tencent/adaptor.go

@@ -47,7 +47,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 1 - 1
relay/channel/xunfei/adaptor.go

@@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 1 - 1
relay/channel/zhipu/adaptor.go

@@ -37,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 1 - 1
relay/channel/zhipu_4v/adaptor.go

@@ -35,7 +35,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
 	return nil
 }
 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

+ 1 - 1
relay/relay-text.go

@@ -153,7 +153,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	adaptor.Init(relayInfo, *textRequest)
 	var requestBody io.Reader
 
-	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
+	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}