# tokenizer.py
import base64
import json
import logging
import re
from pathlib import Path

import tiktoken
  7. logger = logging.getLogger(__name__)
  8. # This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
  9. FISH_TIKTOKEN_PATTERN = "|".join(
  10. [
  11. r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
  12. r"\p{P}",
  13. r"[^\r\n\p{L}\p{N}]?\p{L}+",
  14. r"\p{N}",
  15. r" ?[^\s\p{L}\p{N}]+[\r\n]*",
  16. r"\s*[\r\n]+",
  17. r"\s+(\?!\S)",
  18. r"\s+",
  19. ]
  20. )
  21. TIKTOKEN_MAX_ENCODE_CHARS = 400_000
  22. BOS_TOKEN = "<|begin_of_text|>"
  23. EOS_TOKEN = "<|end_of_text|>"
  24. PAD_TOKEN = "<|pad|>"
  25. IM_START_TOKEN = "<|im_start|>"
  26. IM_END_TOKEN = "<|im_end|>"
  27. PHONEME_START_TOKEN = "<|phoneme_start|>"
  28. PHONEME_END_TOKEN = "<|phoneme_end|>"
  29. TOOL_CALL_START_TOKEN = "<|tool_call_start|>"
  30. TOOL_CALL_END_TOKEN = "<|tool_call_end|>"
  31. MODALITY_TEXT_TOKEN = "<|text|>"
  32. MODALITY_VOICE_TOKEN = "<|voice|>"
  33. MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
  34. AUDIO_START_TOKEN = "<|audio_start|>"
  35. AUDIO_END_TOKEN = "<|audio_end|>"
  36. AUDIO_EMBED_TOKEN = "<|audio|>"
  37. MODALITY_TOKENS = {
  38. "text": MODALITY_TEXT_TOKEN,
  39. "voice": MODALITY_VOICE_TOKEN,
  40. "interleave": MODALITY_INTERLEAVE_TOKEN,
  41. }
  42. SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
  43. SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(4096)]
  44. # Warning: when you add a new special token, you should only add it to the end of the list.
  45. ALL_SPECIAL_TOKENS = [
  46. BOS_TOKEN,
  47. EOS_TOKEN,
  48. PAD_TOKEN,
  49. IM_START_TOKEN,
  50. IM_END_TOKEN,
  51. PHONEME_START_TOKEN,
  52. PHONEME_END_TOKEN,
  53. TOOL_CALL_START_TOKEN,
  54. TOOL_CALL_END_TOKEN,
  55. MODALITY_TEXT_TOKEN,
  56. MODALITY_VOICE_TOKEN,
  57. MODALITY_INTERLEAVE_TOKEN,
  58. AUDIO_START_TOKEN,
  59. AUDIO_END_TOKEN,
  60. AUDIO_EMBED_TOKEN,
  61. *SEMANTIC_TOKENS,
  62. ]
  63. class FishTokenizer:
  64. def __init__(
  65. self, model_path: str, special_tokens: list[str] = ALL_SPECIAL_TOKENS
  66. ) -> None:
  67. mergeable_ranks = self.load_tiktoken_bpe(model_path)
  68. special_token_begin = len(mergeable_ranks)
  69. self.all_special_tokens_with_ids = {
  70. token: special_token_begin + i for i, token in enumerate(special_tokens)
  71. }
  72. self.semantic_id_to_token_id = {}
  73. end_idx = 0
  74. for token in special_tokens:
  75. if token.startswith("<|semantic:"):
  76. idx = int(re.match(r"<\|semantic:(\d+)\|>", token).group(1))
  77. self.semantic_id_to_token_id[idx] = self.all_special_tokens_with_ids[
  78. token
  79. ]
  80. if idx > end_idx:
  81. end_idx = idx
  82. self.semantic_begin_id = self.semantic_id_to_token_id[0]
  83. self.semantic_end_id = self.semantic_id_to_token_id[end_idx]
  84. self.tkt_model = tiktoken.core.Encoding(
  85. name=Path(model_path).stem,
  86. pat_str=FISH_TIKTOKEN_PATTERN,
  87. mergeable_ranks=mergeable_ranks,
  88. special_tokens=self.all_special_tokens_with_ids,
  89. )
  90. @property
  91. def vocab_size(self):
  92. return len(self.tkt_model._mergeable_ranks)
  93. @property
  94. def num_special_tokens(self):
  95. return len(self.all_special_tokens_with_ids)
  96. @staticmethod
  97. def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
  98. data = {}
  99. for line in open(tiktoken_bpe_file).read().splitlines():
  100. if not line:
  101. continue
  102. token, rank = line.split()
  103. if token == "=":
  104. continue
  105. data[base64.b64decode(token)] = int(rank)
  106. return data
  107. def get_token_id(self, token: str) -> int:
  108. return self.all_special_tokens_with_ids[token]
  109. def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
  110. assert isinstance(s, str)
  111. subs = []
  112. for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
  113. subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
  114. if allowed_special is True:
  115. allowed_special = self.tkt_model.special_tokens_set
  116. elif allowed_special is False:
  117. allowed_special = set()
  118. return sum(
  119. self.tkt_model.encode_batch(
  120. subs, allowed_special=allowed_special, disallowed_special=set()
  121. ),
  122. start=[],
  123. )
  124. def decode(self, tokens: list[int]) -> str:
  125. return self.tkt_model.decode(tokens)
  126. def save_pretrained(self, path: str):
  127. path = Path(path)
  128. path.mkdir(parents=True, exist_ok=True)
  129. with open(path / "tokenizer.tiktoken", "w") as f:
  130. for token, rank in self.tkt_model._mergeable_ranks.items():
  131. a = base64.b64encode(token).decode()
  132. if a == "":
  133. a = "="
  134. f.write(f"{a} {rank}\n")
  135. with open(path / "special_tokens.json", "w") as f:
  136. json.dump(
  137. self.all_special_tokens_with_ids,
  138. f,
  139. indent=2,
  140. ensure_ascii=False,
  141. )
  142. @staticmethod
  143. def from_pretrained(path: str):
  144. special_tokens_path = Path(path) / "special_tokens.json"
  145. if special_tokens_path.exists():
  146. with open(special_tokens_path) as f:
  147. all_special_tokens_with_ids = json.load(f)
  148. else:
  149. all_special_tokens_with_ids = ALL_SPECIAL_TOKENS
  150. return FishTokenizer(
  151. Path(path) / "tokenizer.tiktoken", all_special_tokens_with_ids
  152. )