Browse Source

fix: make the token number calculation more accurate (#101)

* Make token calculation more accurate.

* fix: make the token number calculation more accurate

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
quzard 2 years ago
parent
commit
7c6bf3e97b
2 changed files with 66 additions and 15 deletions
  1. controller/relay-utils.go (+61 -0)
  2. controller/relay.go (+5 -15)

+ 61 - 0
controller/relay-utils.go

@@ -0,0 +1,61 @@
+package controller
+
+import (
+	"fmt"
+	"github.com/pkoukk/tiktoken-go"
+	"one-api/common"
+	"strings"
+)
+
+var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
+
+func getTokenEncoder(model string) *tiktoken.Tiktoken {
+	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
+		return tokenEncoder
+	}
+	tokenEncoder, err := tiktoken.EncodingForModel(model)
+	if err != nil {
+		common.FatalLog(fmt.Sprintf("failed to get token encoder for model %s: %s", model, err.Error()))
+	}
+	tokenEncoderMap[model] = tokenEncoder
+	return tokenEncoder
+}
+
+func countTokenMessages(messages []Message, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	// Reference:
+	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	// https://github.com/pkoukk/tiktoken-go/issues/6
+	//
+	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
+	var tokensPerMessage int
+	var tokensPerName int
+	if strings.HasPrefix(model, "gpt-3.5") {
+		tokensPerMessage = 4
+		tokensPerName = -1 // If there's a name, the role is omitted
+	} else if strings.HasPrefix(model, "gpt-4") {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	} else {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	}
+	tokenNum := 0
+	for _, message := range messages {
+		tokenNum += tokensPerMessage
+		tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
+		tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
+		if message.Name != "" {
+			tokenNum += tokensPerName
+			tokenNum += len(tokenEncoder.Encode(message.Name, nil, nil))
+		}
+	}
+	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+	return tokenNum
+}
+
+func countTokenText(text string, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	token := tokenEncoder.Encode(text, nil, nil)
+	return len(token)
+}
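
The new helper follows the OpenAI cookbook recipe referenced in the comments: a fixed per-message overhead, the encoded length of each field, and a 3-token prime for the assistant reply, with one encoder cached per model in tokenEncoderMap. A minimal standalone sketch of the same accounting (the model name and sample messages are illustrative, not part of the commit):

package main

import (
	"fmt"

	"github.com/pkoukk/tiktoken-go"
)

func main() {
	// Illustrative model; any model known to tiktoken-go works here.
	enc, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
	if err != nil {
		panic(err)
	}
	messages := []struct{ Role, Content string }{
		{"system", "You are a helpful assistant."},
		{"user", "Hello!"},
	}
	tokenNum := 0
	for _, m := range messages {
		tokenNum += 4 // tokensPerMessage for gpt-3.5: <|start|>{role}\n{content}<|end|>\n
		tokenNum += len(enc.Encode(m.Role, nil, nil))
		tokenNum += len(enc.Encode(m.Content, nil, nil))
	}
	tokenNum += 3 // every reply is primed with <|start|>assistant<|message|>
	fmt.Println("prompt tokens:", tokenNum)
}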

+ 5 - 15
controller/relay.go

@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
-	"github.com/pkoukk/tiktoken-go"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
@@ -17,6 +16,7 @@ import (
 type Message struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
+	Name    string `json:"name"`
 }
 
 type ChatRequest struct {
@@ -65,13 +65,6 @@ type StreamResponse struct {
 	} `json:"choices"`
 }
 
-var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
-
-func countToken(text string) int {
-	token := tokenEncoder.Encode(text, nil, nil)
-	return len(token)
-}
-
 func Relay(c *gin.Context) {
 	err := relayHelper(c)
 	if err != nil {
@@ -149,11 +142,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 	}
-	var promptText string
-	for _, message := range textRequest.Messages {
-		promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
-	}
-	promptTokens := countToken(promptText) + 3
+
+	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
@@ -206,8 +196,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 				completionRatio = 2
 			}
 			if isStream {
-				completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
-				quota = promptTokens + countToken(completionText)*completionRatio
+				responseTokens := countTokenText(streamResponseText, textRequest.Model)
+				quota = promptTokens + responseTokens*completionRatio
 			} else {
 				quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 			}
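
For streamed responses, where no usage object comes back from upstream, the quota is now the counted prompt tokens plus the counted completion tokens scaled by completionRatio. A worked sketch of that arithmetic with assumed figures (all numbers hypothetical):

package main

import "fmt"

// streamQuota mirrors the streaming branch above: prompt tokens are billed
// at face value, completion tokens at the model's completion ratio.
func streamQuota(promptTokens, responseTokens, completionRatio int) int {
	return promptTokens + responseTokens*completionRatio
}

func main() {
	// Assumed counts: countTokenMessages(...) == 57,
	// countTokenText(streamResponseText, model) == 120, ratio 2.
	fmt.Println(streamQuota(57, 120, 2)) // prints 297
}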