Explorar el Código

fix: baidu max_output_tokens (close #353)

CalciumIon hace 1 año
padre
commit
02545e4856
Se han modificado 2 ficheros con 12 adiciones y 9 borrados
  1. 1 1
      relay/channel/baidu/dto.go
  2. 11 8
      relay/channel/baidu/relay-baidu.go

+ 1 - 1
relay/channel/baidu/dto.go

@@ -19,7 +19,7 @@ type BaiduChatRequest struct {
 	System          string         `json:"system,omitempty"`
 	System          string         `json:"system,omitempty"`
 	DisableSearch   bool           `json:"disable_search,omitempty"`
 	DisableSearch   bool           `json:"disable_search,omitempty"`
 	EnableCitation  bool           `json:"enable_citation,omitempty"`
 	EnableCitation  bool           `json:"enable_citation,omitempty"`
-	MaxOutputTokens int            `json:"max_output_tokens,omitempty"`
+	MaxOutputTokens *int           `json:"max_output_tokens,omitempty"`
 	UserId          string         `json:"user_id,omitempty"`
 	UserId          string         `json:"user_id,omitempty"`
 }
 }
 
 

+ 11 - 8
relay/channel/baidu/relay-baidu.go

@@ -23,14 +23,17 @@ var baiduTokenStore sync.Map
 
 
 func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 	baiduRequest := BaiduChatRequest{
 	baiduRequest := BaiduChatRequest{
-		Temperature:     request.Temperature,
-		TopP:            request.TopP,
-		PenaltyScore:    request.FrequencyPenalty,
-		Stream:          request.Stream,
-		DisableSearch:   false,
-		EnableCitation:  false,
-		MaxOutputTokens: int(request.MaxTokens),
-		UserId:          request.User,
+		Temperature:    request.Temperature,
+		TopP:           request.TopP,
+		PenaltyScore:   request.FrequencyPenalty,
+		Stream:         request.Stream,
+		DisableSearch:  false,
+		EnableCitation: false,
+		UserId:         request.User,
+	}
+	if request.MaxTokens != 0 {
+		maxTokens := int(request.MaxTokens)
+		baiduRequest.MaxOutputTokens = &maxTokens
 	}
 	}
 	for _, message := range request.Messages {
 	for _, message := range request.Messages {
 		if message.Role == "system" {
 		if message.Role == "system" {