|
|
@@ -4,6 +4,8 @@ import (
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "github.com/tiktoken-go/tokenizer"
|
|
|
+ "github.com/tiktoken-go/tokenizer/codec"
|
|
|
"image"
|
|
|
"log"
|
|
|
"math"
|
|
|
@@ -11,78 +13,63 @@ import (
|
|
|
"one-api/constant"
|
|
|
"one-api/dto"
|
|
|
relaycommon "one-api/relay/common"
|
|
|
- "one-api/setting/operation_setting"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
"unicode/utf8"
|
|
|
-
|
|
|
- "github.com/pkoukk/tiktoken-go"
|
|
|
)
|
|
|
|
|
|
-// tokenEncoderMap won't grow after initialization
+// defaultTokenEncoder is the fallback cl100k_base encoder used when no model-specific encoder exists
|
|
|
-var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
|
|
-var defaultTokenEncoder *tiktoken.Tiktoken
|
|
|
-var o200kTokenEncoder *tiktoken.Tiktoken
|
|
|
+var defaultTokenEncoder tokenizer.Codec
|
|
|
+
|
|
|
+// tokenEncoderMap is used to store token encoders for different models
|
|
|
+var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
|
|
+
|
|
|
+// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
|
|
+var tokenEncoderMutex sync.RWMutex
|
|
|
|
|
|
func InitTokenEncoders() {
|
|
|
common.SysLog("initializing token encoders")
|
|
|
- cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
|
|
|
- if err != nil {
|
|
|
- common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
|
|
- }
|
|
|
- defaultTokenEncoder = cl100TokenEncoder
|
|
|
- o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
|
|
|
- if err != nil {
|
|
|
- common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
|
|
- }
|
|
|
- for model, _ := range operation_setting.GetDefaultModelRatioMap() {
|
|
|
- if strings.HasPrefix(model, "gpt-3.5") {
|
|
|
- tokenEncoderMap[model] = cl100TokenEncoder
|
|
|
- } else if strings.HasPrefix(model, "gpt-4") {
|
|
|
- if strings.HasPrefix(model, "gpt-4o") {
|
|
|
- tokenEncoderMap[model] = o200kTokenEncoder
|
|
|
- } else {
|
|
|
- tokenEncoderMap[model] = defaultTokenEncoder
|
|
|
- }
|
|
|
- } else if strings.HasPrefix(model, "o") {
|
|
|
- tokenEncoderMap[model] = o200kTokenEncoder
|
|
|
- } else {
|
|
|
- tokenEncoderMap[model] = defaultTokenEncoder
|
|
|
- }
|
|
|
- }
|
|
|
+ defaultTokenEncoder = codec.NewCl100kBase()
|
|
|
common.SysLog("token encoders initialized")
|
|
|
}
|
|
|
|
|
|
-func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
|
- if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
|
|
|
- return o200kTokenEncoder
|
|
|
+func getTokenEncoder(model string) tokenizer.Codec {
|
|
|
+ // First, try to get the encoder from cache with read lock
|
|
|
+ tokenEncoderMutex.RLock()
|
|
|
+ if encoder, exists := tokenEncoderMap[model]; exists {
|
|
|
+ tokenEncoderMutex.RUnlock()
|
|
|
+ return encoder
|
|
|
}
|
|
|
- return defaultTokenEncoder
|
|
|
-}
|
|
|
+ tokenEncoderMutex.RUnlock()
|
|
|
|
|
|
-func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
|
- tokenEncoder, ok := tokenEncoderMap[model]
|
|
|
- if ok && tokenEncoder != nil {
|
|
|
- return tokenEncoder
|
|
|
+ // If not in cache, create new encoder with write lock
|
|
|
+ tokenEncoderMutex.Lock()
|
|
|
+ defer tokenEncoderMutex.Unlock()
|
|
|
+
|
|
|
+ // Double-check if another goroutine already created the encoder
|
|
|
+ if encoder, exists := tokenEncoderMap[model]; exists {
|
|
|
+ return encoder
|
|
|
}
|
|
|
- // 如果ok(即model在tokenEncoderMap中),但是tokenEncoder为nil,说明可能是自定义模型
|
|
|
- if ok {
|
|
|
- tokenEncoder, err := tiktoken.EncodingForModel(model)
|
|
|
- if err != nil {
|
|
|
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
|
|
- tokenEncoder = getModelDefaultTokenEncoder(model)
|
|
|
- }
|
|
|
- tokenEncoderMap[model] = tokenEncoder
|
|
|
- return tokenEncoder
|
|
|
+
|
|
|
+ // Create new encoder
|
|
|
+ modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
|
|
+ if err != nil {
|
|
|
+		common.SysError(fmt.Sprintf("failed to get token encoder for model %s, using default encoder: %s", model, err.Error())) // cache the default below to avoid repeated failures
|
|
|
+ tokenEncoderMap[model] = defaultTokenEncoder
|
|
|
+ return defaultTokenEncoder
|
|
|
}
|
|
|
- // 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder
|
|
|
- return getModelDefaultTokenEncoder(model)
|
|
|
+
|
|
|
+ // Cache the new encoder
|
|
|
+ tokenEncoderMap[model] = modelCodec
|
|
|
+ return modelCodec
|
|
|
}
|
|
|
|
|
|
-func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|
|
+func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
|
|
if text == "" {
|
|
|
return 0
|
|
|
}
|
|
|
- return len(tokenEncoder.Encode(text, nil, nil))
|
|
|
+ ids, _, _ := tokenEncoder.Encode(text)
|
|
|
+ return len(ids)
|
|
|
}
|
|
|
|
|
|
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|