|
|
@@ -24,6 +24,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
|
return tokenEncoder
|
|
|
}
|
|
|
|
|
|
+func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|
|
+ if common.ApproximateTokenEnabled {
|
|
|
+ return int(float64(len(text)) * 0.38)
|
|
|
+ }
|
|
|
+ return len(tokenEncoder.Encode(text, nil, nil))
|
|
|
+}
|
|
|
+
|
|
|
func countTokenMessages(messages []Message, model string) int {
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
// Reference:
|
|
|
@@ -43,11 +50,11 @@ func countTokenMessages(messages []Message, model string) int {
|
|
|
tokenNum := 0
|
|
|
for _, message := range messages {
|
|
|
tokenNum += tokensPerMessage
|
|
|
- tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
|
|
|
- tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
|
|
|
+ tokenNum += getTokenNum(tokenEncoder, message.Content)
|
|
|
+ tokenNum += getTokenNum(tokenEncoder, message.Role)
|
|
|
if message.Name != nil {
|
|
|
tokenNum += tokensPerName
|
|
|
- tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
|
|
|
+ tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
|
|
}
|
|
|
}
|
|
|
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
|
|
@@ -70,8 +77,7 @@ func countTokenInput(input any, model string) int {
|
|
|
|
|
|
func countTokenText(text string, model string) int {
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
- token := tokenEncoder.Encode(text, nil, nil)
|
|
|
- return len(token)
|
|
|
+ return getTokenNum(tokenEncoder, text)
|
|
|
}
|
|
|
|
|
|
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
|