Просмотр исходного кода

feat: Add CountToken configuration and update token counting logic

CaIon 3 месяцев назад
Родитель
Сommit
0952973887
3 измененных файлов с 48 добавлено и 35 удалено
  1. 1 0
      common/init.go
  2. 1 0
      constant/env.go
  3. 46 35
      service/token_counter.go

+ 1 - 0
common/init.go

@@ -111,6 +111,7 @@ func initConstantEnv() {
 	constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
 	// ForceStreamOption 覆盖请求参数,强制返回usage信息
 	constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
+	constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
 	constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
 	constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
 	constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)

+ 1 - 0
constant/env.go

@@ -4,6 +4,7 @@ var StreamingTimeout int
 var DifyDebug bool
 var MaxFileDownloadMB int
 var ForceStreamOption bool
+var CountToken bool
 var GetMediaToken bool
 var GetMediaTokenNotStream bool
 var UpdateTask bool

+ 46 - 35
service/token_counter.go

@@ -143,6 +143,12 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
 	if fileMeta.Detail == "low" && !isPatchBased {
 		return baseTokens, nil
 	}
+
+	// Whether to count image tokens at all
+	if !constant.GetMediaToken {
+		return 3 * baseTokens, nil
+	}
+
 	if !constant.GetMediaTokenNotStream && !stream {
 		return 3 * baseTokens, nil
 	}
@@ -150,10 +156,6 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
 	if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
 		fileMeta.Detail = "high"
 	}
-	// Whether to count image tokens at all
-	if !constant.GetMediaToken {
-		return 3 * baseTokens, nil
-	}
 
 	// Decode image to get dimensions
 	var config image.Config
@@ -256,16 +258,15 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
 }
 
 func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
+	// 是否统计token
+	if !constant.CountToken {
+		return 0, nil
+	}
+
 	if meta == nil {
 		return 0, errors.New("token count meta is nil")
 	}
 
-	if !constant.GetMediaToken {
-		return 0, nil
-	}
-	if !constant.GetMediaTokenNotStream && !info.IsStream {
-		return 0, nil
-	}
 	if info.RelayFormat == types.RelayFormatOpenAIRealtime {
 		return 0, nil
 	}
@@ -316,9 +317,19 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 		shouldFetchFiles = false
 	}
 
-	if shouldFetchFiles {
-		for _, file := range meta.Files {
-			if strings.HasPrefix(file.OriginData, "http") {
+	// 是否本地计算媒体token数量
+	if !constant.GetMediaToken {
+		shouldFetchFiles = false
+	}
+
+	// 是否在非流模式下本地计算媒体token数量
+	if !constant.GetMediaTokenNotStream && !info.IsStream {
+		shouldFetchFiles = false
+	}
+
+	for _, file := range meta.Files {
+		if strings.HasPrefix(file.OriginData, "http") {
+			if shouldFetchFiles {
 				mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
 				if err != nil {
 					return 0, fmt.Errorf("error getting file base64 from url: %v", err)
@@ -333,28 +344,28 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 					file.FileType = types.FileTypeFile
 				}
 				file.MimeType = mineType
-			} else if strings.HasPrefix(file.OriginData, "data:") {
-				// get mime type from base64 header
-				parts := strings.SplitN(file.OriginData, ",", 2)
-				if len(parts) >= 1 {
-					header := parts[0]
-					// Extract mime type from "data:mime/type;base64" format
-					if strings.Contains(header, ":") && strings.Contains(header, ";") {
-						mimeStart := strings.Index(header, ":") + 1
-						mimeEnd := strings.Index(header, ";")
-						if mimeStart < mimeEnd {
-							mineType := header[mimeStart:mimeEnd]
-							if strings.HasPrefix(mineType, "image/") {
-								file.FileType = types.FileTypeImage
-							} else if strings.HasPrefix(mineType, "video/") {
-								file.FileType = types.FileTypeVideo
-							} else if strings.HasPrefix(mineType, "audio/") {
-								file.FileType = types.FileTypeAudio
-							} else {
-								file.FileType = types.FileTypeFile
-							}
-							file.MimeType = mineType
+			}
+		} else if strings.HasPrefix(file.OriginData, "data:") {
+			// get mime type from base64 header
+			parts := strings.SplitN(file.OriginData, ",", 2)
+			if len(parts) >= 1 {
+				header := parts[0]
+				// Extract mime type from "data:mime/type;base64" format
+				if strings.Contains(header, ":") && strings.Contains(header, ";") {
+					mimeStart := strings.Index(header, ":") + 1
+					mimeEnd := strings.Index(header, ";")
+					if mimeStart < mimeEnd {
+						mineType := header[mimeStart:mimeEnd]
+						if strings.HasPrefix(mineType, "image/") {
+							file.FileType = types.FileTypeImage
+						} else if strings.HasPrefix(mineType, "video/") {
+							file.FileType = types.FileTypeVideo
+						} else if strings.HasPrefix(mineType, "audio/") {
+							file.FileType = types.FileTypeAudio
+						} else {
+							file.FileType = types.FileTypeFile
 						}
+						file.MimeType = mineType
 					}
 				}
 			}
@@ -365,7 +376,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 		switch file.FileType {
 		case types.FileTypeImage:
 			if info.RelayFormat == types.RelayFormatGemini {
-				tkm += 256
+				tkm += 520 // gemini per input image tokens
 			} else {
 				token, err := getImageToken(file, model, info.IsStream)
 				if err != nil {