token_counter.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "math"
  7. "path/filepath"
  8. "strings"
  9. "unicode/utf8"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. relaycommon "github.com/QuantumNous/new-api/relay/common"
  14. constant2 "github.com/QuantumNous/new-api/relay/constant"
  15. "github.com/QuantumNous/new-api/types"
  16. "github.com/gin-gonic/gin"
  17. )
  18. func getImageToken(c *gin.Context, fileMeta *types.FileMeta, model string, stream bool) (int, error) {
  19. if fileMeta == nil || fileMeta.Source == nil {
  20. return 0, fmt.Errorf("image_url_is_nil")
  21. }
  22. // Defaults for 4o/4.1/4.5 family unless overridden below
  23. baseTokens := 85
  24. tileTokens := 170
  25. // Model classification
  26. lowerModel := strings.ToLower(model)
  27. // Special cases from existing behavior
  28. if strings.HasPrefix(lowerModel, "glm-4") {
  29. return 1047, nil
  30. }
  31. // Patch-based models (32x32 patches, capped at 1536, with multiplier)
  32. isPatchBased := false
  33. multiplier := 1.0
  34. switch {
  35. case strings.Contains(lowerModel, "gpt-4.1-mini"):
  36. isPatchBased = true
  37. multiplier = 1.62
  38. case strings.Contains(lowerModel, "gpt-4.1-nano"):
  39. isPatchBased = true
  40. multiplier = 2.46
  41. case strings.HasPrefix(lowerModel, "o4-mini"):
  42. isPatchBased = true
  43. multiplier = 1.72
  44. case strings.HasPrefix(lowerModel, "gpt-5-mini"):
  45. isPatchBased = true
  46. multiplier = 1.62
  47. case strings.HasPrefix(lowerModel, "gpt-5-nano"):
  48. isPatchBased = true
  49. multiplier = 2.46
  50. }
  51. // Tile-based model tokens and bases per doc
  52. if !isPatchBased {
  53. if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
  54. baseTokens = 2833
  55. tileTokens = 5667
  56. } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
  57. baseTokens = 70
  58. tileTokens = 140
  59. } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
  60. baseTokens = 75
  61. tileTokens = 150
  62. } else if strings.Contains(lowerModel, "computer-use-preview") {
  63. baseTokens = 65
  64. tileTokens = 129
  65. } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
  66. baseTokens = 85
  67. tileTokens = 170
  68. }
  69. }
  70. // Respect existing feature flags/short-circuits
  71. if fileMeta.Detail == "low" && !isPatchBased {
  72. return baseTokens, nil
  73. }
  74. // Whether to count image tokens at all
  75. if !constant.GetMediaToken {
  76. return 3 * baseTokens, nil
  77. }
  78. if !constant.GetMediaTokenNotStream && !stream {
  79. return 3 * baseTokens, nil
  80. }
  81. // Normalize detail
  82. if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
  83. fileMeta.Detail = "high"
  84. }
  85. // 使用统一的文件服务获取图片配置
  86. config, format, err := GetImageConfig(c, fileMeta.Source)
  87. if err != nil {
  88. return 0, err
  89. }
  90. fileMeta.MimeType = format
  91. if config.Width == 0 || config.Height == 0 {
  92. // not an image, but might be a valid file
  93. if format != "" {
  94. // file type
  95. return 3 * baseTokens, nil
  96. }
  97. return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", fileMeta.GetIdentifier()))
  98. }
  99. width := config.Width
  100. height := config.Height
  101. log.Printf("format: %s, width: %d, height: %d", format, width, height)
  102. if isPatchBased {
  103. // 32x32 patch-based calculation with 1536 cap and model multiplier
  104. ceilDiv := func(a, b int) int { return (a + b - 1) / b }
  105. rawPatchesW := ceilDiv(width, 32)
  106. rawPatchesH := ceilDiv(height, 32)
  107. rawPatches := rawPatchesW * rawPatchesH
  108. if rawPatches > 1536 {
  109. // scale down
  110. area := float64(width * height)
  111. r := math.Sqrt(float64(32*32*1536) / area)
  112. wScaled := float64(width) * r
  113. hScaled := float64(height) * r
  114. // adjust to fit whole number of patches after scaling
  115. adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
  116. adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
  117. adj := math.Min(adjW, adjH)
  118. if !math.IsNaN(adj) && adj > 0 {
  119. r = r * adj
  120. }
  121. wScaled = float64(width) * r
  122. hScaled = float64(height) * r
  123. patchesW := math.Ceil(wScaled / 32.0)
  124. patchesH := math.Ceil(hScaled / 32.0)
  125. imageTokens := int(patchesW * patchesH)
  126. if imageTokens > 1536 {
  127. imageTokens = 1536
  128. }
  129. return int(math.Round(float64(imageTokens) * multiplier)), nil
  130. }
  131. // below cap
  132. imageTokens := rawPatches
  133. return int(math.Round(float64(imageTokens) * multiplier)), nil
  134. }
  135. // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
  136. // Step 1: fit within 2048x2048 square
  137. maxSide := math.Max(float64(width), float64(height))
  138. fitScale := 1.0
  139. if maxSide > 2048 {
  140. fitScale = maxSide / 2048.0
  141. }
  142. fitW := int(math.Round(float64(width) / fitScale))
  143. fitH := int(math.Round(float64(height) / fitScale))
  144. // Step 2: scale so that shortest side is exactly 768
  145. minSide := math.Min(float64(fitW), float64(fitH))
  146. if minSide == 0 {
  147. return baseTokens, nil
  148. }
  149. shortScale := 768.0 / minSide
  150. finalW := int(math.Round(float64(fitW) * shortScale))
  151. finalH := int(math.Round(float64(fitH) * shortScale))
  152. // Count 512px tiles
  153. tilesW := (finalW + 512 - 1) / 512
  154. tilesH := (finalH + 512 - 1) / 512
  155. tiles := tilesW * tilesH
  156. if common.DebugEnabled {
  157. log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
  158. }
  159. return tiles*tileTokens + baseTokens, nil
  160. }
  161. func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
  162. // 是否统计token
  163. if !constant.CountToken {
  164. return 0, nil
  165. }
  166. if meta == nil {
  167. return 0, errors.New("token count meta is nil")
  168. }
  169. if info.RelayFormat == types.RelayFormatOpenAIRealtime {
  170. return 0, nil
  171. }
  172. if info.RelayMode == constant2.RelayModeAudioTranscription || info.RelayMode == constant2.RelayModeAudioTranslation {
  173. multiForm, err := common.ParseMultipartFormReusable(c)
  174. if err != nil {
  175. return 0, fmt.Errorf("error parsing multipart form: %v", err)
  176. }
  177. fileHeaders := multiForm.File["file"]
  178. totalAudioToken := 0
  179. for _, fileHeader := range fileHeaders {
  180. file, err := fileHeader.Open()
  181. if err != nil {
  182. return 0, fmt.Errorf("error opening audio file: %v", err)
  183. }
  184. defer file.Close()
  185. // get ext and io.seeker
  186. ext := filepath.Ext(fileHeader.Filename)
  187. duration, err := common.GetAudioDuration(c.Request.Context(), file, ext)
  188. if err != nil {
  189. return 0, fmt.Errorf("error getting audio duration: %v", err)
  190. }
  191. // 一分钟 1000 token,与 $price / minute 对齐
  192. totalAudioToken += int(math.Round(math.Ceil(duration) / 60.0 * 1000))
  193. }
  194. return totalAudioToken, nil
  195. }
  196. model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
  197. tkm := 0
  198. if meta.TokenType == types.TokenTypeTextNumber {
  199. tkm += utf8.RuneCountInString(meta.CombineText)
  200. } else {
  201. tkm += CountTextToken(meta.CombineText, model)
  202. }
  203. if info.RelayFormat == types.RelayFormatOpenAI {
  204. tkm += meta.ToolsCount * 8
  205. tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
  206. tkm += meta.NameCount * 3
  207. tkm += 3
  208. }
  209. shouldFetchFiles := true
  210. if info.RelayFormat == types.RelayFormatGemini {
  211. shouldFetchFiles = false
  212. }
  213. // 是否本地计算媒体token数量
  214. if !constant.GetMediaToken {
  215. shouldFetchFiles = false
  216. }
  217. // 是否在非流模式下本地计算媒体token数量
  218. if !constant.GetMediaTokenNotStream && !info.IsStream {
  219. shouldFetchFiles = false
  220. }
  221. // 使用统一的文件服务获取文件类型
  222. for _, file := range meta.Files {
  223. if file.Source == nil {
  224. continue
  225. }
  226. // 如果文件类型未知且需要获取,通过 MIME 类型检测
  227. if file.FileType == "" || (file.Source.IsURL() && shouldFetchFiles) {
  228. mimeType, err := GetMimeType(c, file.Source)
  229. if err != nil {
  230. if shouldFetchFiles {
  231. return 0, fmt.Errorf("error getting file type: %v", err)
  232. }
  233. // 如果不需要获取,使用默认类型
  234. continue
  235. }
  236. file.MimeType = mimeType
  237. file.FileType = DetectFileType(mimeType)
  238. }
  239. }
  240. for i, file := range meta.Files {
  241. switch file.FileType {
  242. case types.FileTypeImage:
  243. if common.IsOpenAITextModel(model) {
  244. token, err := getImageToken(c, file, model, info.IsStream)
  245. if err != nil {
  246. return 0, fmt.Errorf("error counting image token, media index[%d], identifier[%s], err: %v", i, file.GetIdentifier(), err)
  247. }
  248. tkm += token
  249. } else {
  250. tkm += 520
  251. }
  252. case types.FileTypeAudio:
  253. tkm += 256
  254. case types.FileTypeVideo:
  255. tkm += 4096 * 2
  256. case types.FileTypeFile:
  257. tkm += 4096
  258. default:
  259. tkm += 4096 // Default case for unknown file types
  260. }
  261. }
  262. common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
  263. return tkm, nil
  264. }
  265. func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
  266. audioToken := 0
  267. textToken := 0
  268. switch request.Type {
  269. case dto.RealtimeEventTypeSessionUpdate:
  270. if request.Session != nil {
  271. msgTokens := CountTextToken(request.Session.Instructions, model)
  272. textToken += msgTokens
  273. }
  274. case dto.RealtimeEventResponseAudioDelta:
  275. // count audio token
  276. atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
  277. if err != nil {
  278. return 0, 0, fmt.Errorf("error counting audio token: %v", err)
  279. }
  280. audioToken += atk
  281. case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
  282. // count text token
  283. tkm := CountTextToken(request.Delta, model)
  284. textToken += tkm
  285. case dto.RealtimeEventInputAudioBufferAppend:
  286. // count audio token
  287. atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
  288. if err != nil {
  289. return 0, 0, fmt.Errorf("error counting audio token: %v", err)
  290. }
  291. audioToken += atk
  292. case dto.RealtimeEventConversationItemCreated:
  293. if request.Item != nil {
  294. switch request.Item.Type {
  295. case "message":
  296. for _, content := range request.Item.Content {
  297. if content.Type == "input_text" {
  298. tokens := CountTextToken(content.Text, model)
  299. textToken += tokens
  300. }
  301. }
  302. }
  303. }
  304. case dto.RealtimeEventTypeResponseDone:
  305. // count tools token
  306. if !info.IsFirstRequest {
  307. if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
  308. for _, tool := range info.RealtimeTools {
  309. toolTokens := CountTokenInput(tool, model)
  310. textToken += 8
  311. textToken += toolTokens
  312. }
  313. }
  314. }
  315. }
  316. return textToken, audioToken, nil
  317. }
  318. func CountTokenInput(input any, model string) int {
  319. switch v := input.(type) {
  320. case string:
  321. return CountTextToken(v, model)
  322. case []string:
  323. text := ""
  324. for _, s := range v {
  325. text += s
  326. }
  327. return CountTextToken(text, model)
  328. case []interface{}:
  329. text := ""
  330. for _, item := range v {
  331. text += fmt.Sprintf("%v", item)
  332. }
  333. return CountTextToken(text, model)
  334. }
  335. return CountTokenInput(fmt.Sprintf("%v", input), model)
  336. }
  337. func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
  338. if audioBase64 == "" {
  339. return 0, nil
  340. }
  341. duration, err := parseAudio(audioBase64, audioFormat)
  342. if err != nil {
  343. return 0, err
  344. }
  345. return int(duration / 60 * 100 / 0.06), nil
  346. }
  347. func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
  348. if audioBase64 == "" {
  349. return 0, nil
  350. }
  351. duration, err := parseAudio(audioBase64, audioFormat)
  352. if err != nil {
  353. return 0, err
  354. }
  355. return int(duration / 60 * 200 / 0.24), nil
  356. }
  357. // CountTextToken 统计文本的token数量,仅OpenAI模型使用tokenizer,其余模型使用估算
  358. func CountTextToken(text string, model string) int {
  359. if text == "" {
  360. return 0
  361. }
  362. if common.IsOpenAITextModel(model) {
  363. tokenEncoder := getTokenEncoder(model)
  364. return getTokenNum(tokenEncoder, text)
  365. } else {
  366. // 非openai模型,使用tiktoken-go计算没有意义,使用估算节省资源
  367. return EstimateTokenByModel(model, text)
  368. }
  369. }