Просмотр исходного кода

fix: now the input field can be array type now (close #149)

JustSong 2 лет назад
Родитель
Сommit
7c7eb6b7ec
2 измененных файлов с 16 добавлено и 2 удалено
  1. 14 0
      controller/relay-utils.go
  2. 2 2
      controller/relay.go

+ 14 - 0
controller/relay-utils.go

@@ -58,6 +58,20 @@ func countTokenMessages(messages []Message, model string) int {
 	return tokenNum
 	return tokenNum
 }
 }
 
 
+func countTokenInput(input any, model string) int {
+	switch input.(type) {
+	case string:
+		return countTokenText(input.(string), model)
+	case []string:
+		text := ""
+		for _, s := range input.([]string) {
+			text += s
+		}
+		return countTokenText(text, model)
+	}
+	return 0
+}
+
 func countTokenText(text string, model string) int {
 func countTokenText(text string, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	tokenEncoder := getTokenEncoder(model)
 	token := tokenEncoder.Encode(text, nil, nil)
 	token := tokenEncoder.Encode(text, nil, nil)

+ 2 - 2
controller/relay.go

@@ -38,7 +38,7 @@ type GeneralOpenAIRequest struct {
 	Temperature float64   `json:"temperature"`
 	Temperature float64   `json:"temperature"`
 	TopP        float64   `json:"top_p"`
 	TopP        float64   `json:"top_p"`
 	N           int       `json:"n"`
 	N           int       `json:"n"`
-	Input       string    `json:"input"`
+	Input       any       `json:"input"`
 }
 }
 
 
 type ChatRequest struct {
 type ChatRequest struct {
@@ -189,7 +189,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	case RelayModeCompletions:
 	case RelayModeCompletions:
 		promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
 		promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
 	case RelayModeModeration:
 	case RelayModeModeration:
-		promptTokens = countTokenText(textRequest.Input, textRequest.Model)
+		promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
 	}
 	}
 	preConsumedTokens := common.PreConsumedQuota
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 	if textRequest.MaxTokens != 0 {