# tokenizer.py
  1. import base64
  2. import json
  3. import logging
  4. from pathlib import Path
  5. import tiktoken
  6. logger = logging.getLogger(__name__)
  7. # This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
  8. FISH_TIKTOKEN_PATTERN = "|".join(
  9. [
  10. r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
  11. r"\p{P}",
  12. r"[^\r\n\p{L}\p{N}]?\p{L}+",
  13. r"\p{N}",
  14. r" ?[^\s\p{L}\p{N}]+[\r\n]*",
  15. r"\s*[\r\n]+",
  16. r"\s+(\?!\S)",
  17. r"\s+",
  18. ]
  19. )
  20. TIKTOKEN_MAX_ENCODE_CHARS = 400_000
  21. BOS_TOKEN = "<|begin_of_text|>"
  22. EOS_TOKEN = "<|end_of_text|>"
  23. PAD_TOKEN = "<|pad|>"
  24. IM_START_TOKEN = "<|im_start|>"
  25. IM_END_TOKEN = "<|im_end|>"
  26. MODALITY_TEXT_TOKEN = "<|text|>"
  27. MODALITY_VOICE_TOKEN = "<|voice|>"
  28. MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
  29. MODALITY_TOKENS = {
  30. "text": MODALITY_TEXT_TOKEN,
  31. "voice": MODALITY_VOICE_TOKEN,
  32. "interleave": MODALITY_INTERLEAVE_TOKEN,
  33. }
  34. PLACEHOLDER_TOKEN = [""] * 4
  35. for i in range(4):
  36. PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
  37. SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
  38. SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
  39. # Warning: when you add a new special token, you should only add it to the end of the list.
  40. ALL_SPECIAL_TOKENS = [
  41. BOS_TOKEN,
  42. EOS_TOKEN,
  43. PAD_TOKEN,
  44. IM_START_TOKEN,
  45. IM_END_TOKEN,
  46. PLACEHOLDER_TOKEN[0],
  47. PLACEHOLDER_TOKEN[1],
  48. PLACEHOLDER_TOKEN[2],
  49. PLACEHOLDER_TOKEN[3],
  50. MODALITY_TEXT_TOKEN,
  51. MODALITY_VOICE_TOKEN,
  52. MODALITY_INTERLEAVE_TOKEN,
  53. *SEMANTIC_TOKENS,
  54. ]
  55. class FishTokenizer:
  56. def __init__(self, model_path: str) -> None:
  57. mergeable_ranks = self.load_tiktoken_bpe(model_path)
  58. special_token_begin = len(mergeable_ranks)
  59. self.all_special_tokens_with_ids = {
  60. token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
  61. }
  62. self.semantic_id_to_token_id = {
  63. i: self.all_special_tokens_with_ids[token]
  64. for i, token in enumerate(SEMANTIC_TOKENS)
  65. }
  66. self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
  67. self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
  68. self.tkt_model = tiktoken.core.Encoding(
  69. name=Path(model_path).stem,
  70. pat_str=FISH_TIKTOKEN_PATTERN,
  71. mergeable_ranks=mergeable_ranks,
  72. special_tokens=self.all_special_tokens_with_ids,
  73. )
  74. @staticmethod
  75. def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
  76. data = {}
  77. for line in open(tiktoken_bpe_file).read().splitlines():
  78. if not line:
  79. continue
  80. token, rank = line.split()
  81. data[base64.b64decode(token)] = int(rank)
  82. return data
  83. def get_token_id(self, token: str) -> int:
  84. return self.all_special_tokens_with_ids[token]
  85. def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
  86. assert isinstance(s, str)
  87. subs = []
  88. for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
  89. subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
  90. if allowed_special is True:
  91. allowed_special = self.tkt_model.special_tokens_set
  92. elif allowed_special is False:
  93. allowed_special = set()
  94. return sum(
  95. self.tkt_model.encode_batch(
  96. subs, allowed_special=allowed_special, disallowed_special=set()
  97. ),
  98. start=[],
  99. )
  100. def decode(self, tokens: list[int]) -> str:
  101. return self.tkt_model.decode(tokens)
  102. def save_pretrained(self, path: str):
  103. path = Path(path)
  104. path.mkdir(parents=True, exist_ok=True)
  105. with open(path / "tokenizer.tiktoken", "w") as f:
  106. for token, rank in self.tkt_model._mergeable_ranks.items():
  107. f.write(f"{base64.b64encode(token).decode()} {rank}\n")
  108. with open(path / "special_tokens.json", "w") as f:
  109. json.dump(
  110. self.all_special_tokens_with_ids,
  111. f,
  112. indent=2,
  113. ensure_ascii=False,
  114. )
  115. @staticmethod
  116. def from_pretrained(path: str):
  117. return FishTokenizer(Path(path) / "tokenizer.tiktoken")
  118. if __name__ == "__main__":
  119. tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
  120. tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
  121. tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
  122. print(
  123. [
  124. tokenizer.decode([i])
  125. for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
  126. ]
  127. )