Просмотр исходного кода

feat: support max_tokens now (#52)

JustSong 2 лет назад
Родитель
Сommit
c9ac5e391f
1 измененных файлов с 15 добавлено и 9 удалено
  1. 15 9
      controller/relay.go

+ 15 - 9
controller/relay.go

@@ -26,9 +26,10 @@ type ChatRequest struct {
 }
 }
 
 
 type TextRequest struct {
 type TextRequest struct {
-	Model    string    `json:"model"`
-	Messages []Message `json:"messages"`
-	Prompt   string    `json:"prompt"`
+	Model     string    `json:"model"`
+	Messages  []Message `json:"messages"`
+	Prompt    string    `json:"prompt"`
+	MaxTokens int       `json:"max_tokens"`
 	//Stream   bool      `json:"stream"`
 	//Stream   bool      `json:"stream"`
 }
 }
 
 
@@ -128,8 +129,17 @@ func relayHelper(c *gin.Context) error {
 		model_ = strings.TrimSuffix(model_, "-0314")
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 	}
 	}
+	var promptText string
+	for _, message := range textRequest.Messages {
+		promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
+	}
+	promptTokens := countToken(promptText) + 3
+	preConsumedTokens := common.PreConsumedQuota
+	if textRequest.MaxTokens != 0 {
+		preConsumedTokens = promptTokens + textRequest.MaxTokens
+	}
 	ratio := common.GetModelRatio(textRequest.Model)
 	ratio := common.GetModelRatio(textRequest.Model)
-	preConsumedQuota := int(float64(common.PreConsumedQuota) * ratio)
+	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	if consumeQuota {
 	if consumeQuota {
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 		if err != nil {
 		if err != nil {
@@ -176,12 +186,8 @@ func relayHelper(c *gin.Context) error {
 				completionRatio = 2
 				completionRatio = 2
 			}
 			}
 			if isStream {
 			if isStream {
-				var promptText string
-				for _, message := range textRequest.Messages {
-					promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
-				}
 				completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
 				completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
-				quota = countToken(promptText) + countToken(completionText)*completionRatio + 3
+				quota = promptTokens + countToken(completionText)*completionRatio
 			} else {
 			} else {
 				quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 				quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 			}
 			}