Browse Source

feat: 修复智谱GLM-4V流模式异常

1808837298@qq.com 2 years ago
parent
commit
84cac72a45

+ 2 - 2
relay/channel/openai/adaptor.go

@@ -71,10 +71,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		var responseText string
 		var responseText string
-		err, responseText = openaiStreamHandler(c, resp, info.RelayMode)
+		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 	} else {
-		err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	}
 	return
 	return
 }
 }

+ 2 - 2
relay/channel/openai/relay-openai.go

@@ -16,7 +16,7 @@ import (
 	"time"
 	"time"
 )
 )
 
 
-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
 	var responseTextBuilder strings.Builder
 	var responseTextBuilder strings.Builder
 	scanner := bufio.NewScanner(resp.Body)
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -111,7 +111,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 	return nil, responseTextBuilder.String()
 	return nil, responseTextBuilder.String()
 }
 }
 
 
-func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	var textResponse dto.TextResponse
 	var textResponse dto.TextResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {

+ 7 - 3
relay/channel/zhipu_v4/adaptor.go → relay/channel/zhipu_4v/adaptor.go

@@ -1,4 +1,4 @@
-package zhipu_v4
+package zhipu_4v
 
 
 import (
 import (
 	"errors"
 	"errors"
@@ -8,7 +8,9 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/dto"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 )
 )
 
 
 type Adaptor struct {
 type Adaptor struct {
@@ -41,9 +43,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
-		err, usage = zhipuStreamHandler(c, resp)
+		var responseText string
+		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 	} else {
-		err, usage = zhipuHandler(c, resp)
+		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	}
 	return
 	return
 }
 }

+ 2 - 2
relay/channel/zhipu_v4/constants.go → relay/channel/zhipu_4v/constants.go

@@ -1,7 +1,7 @@
-package zhipu_v4
+package zhipu_4v
 
 
 var ModelList = []string{
 var ModelList = []string{
 	"glm-4", "glm-4v", "glm-3-turbo",
 	"glm-4", "glm-4v", "glm-3-turbo",
 }
 }
 
 
-var ChannelName = "zhipu_v4"
+var ChannelName = "zhipu_4v"

+ 1 - 1
relay/channel/zhipu_v4/dto.go → relay/channel/zhipu_4v/dto.go

@@ -1,4 +1,4 @@
-package zhipu_v4
+package zhipu_4v
 
 
 import (
 import (
 	"one-api/dto"
 	"one-api/dto"

+ 1 - 1
relay/channel/zhipu_v4/relay-zhipu_v4.go → relay/channel/zhipu_4v/relay-zhipu_v4.go

@@ -1,4 +1,4 @@
-package zhipu_v4
+package zhipu_4v
 
 
 import (
 import (
 	"bufio"
 	"bufio"

+ 2 - 2
relay/relay_adaptor.go

@@ -11,7 +11,7 @@ import (
 	"one-api/relay/channel/tencent"
 	"one-api/relay/channel/tencent"
 	"one-api/relay/channel/xunfei"
 	"one-api/relay/channel/xunfei"
 	"one-api/relay/channel/zhipu"
 	"one-api/relay/channel/zhipu"
-	"one-api/relay/channel/zhipu_v4"
+	"one-api/relay/channel/zhipu_4v"
 	"one-api/relay/constant"
 	"one-api/relay/constant"
 )
 )
 
 
@@ -38,7 +38,7 @@ func GetAdaptor(apiType int) channel.Adaptor {
 	case constant.APITypeZhipu:
 	case constant.APITypeZhipu:
 		return &zhipu.Adaptor{}
 		return &zhipu.Adaptor{}
 	case constant.APITypeZhipu_v4:
 	case constant.APITypeZhipu_v4:
-		return &zhipu_v4.Adaptor{}
+		return &zhipu_4v.Adaptor{}
 	}
 	}
 	return nil
 	return nil
 }
 }