tokenizer.py

import inspect
import logging
from typing import TYPE_CHECKING, List, Union

import torch
from transformers import AutoTokenizer

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerFast

logger = logging.getLogger(__name__)

# Special token definitions
EOS_TOKEN = "<|endoftext|>"
PAD_TOKEN = "<|pad|>"
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PHONEME_START_TOKEN = "<|phoneme_start|>"
PHONEME_END_TOKEN = "<|phoneme_end|>"
MODALITY_TEXT_TOKEN = "<|text|>"
MODALITY_VOICE_TOKEN = "<|voice|>"
MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
AUDIO_START_TOKEN = "<|audio_start|>"
AUDIO_END_TOKEN = "<|audio_end|>"
AUDIO_EMBED_TOKEN = "<|audio_pad|>"

MODALITY_TOKENS = {
    "text": MODALITY_TEXT_TOKEN,
    "voice": MODALITY_VOICE_TOKEN,
    "interleave": MODALITY_INTERLEAVE_TOKEN,
}

SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(4096)]

ALL_SPECIAL_TOKENS = [
    EOS_TOKEN,
    PAD_TOKEN,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PHONEME_START_TOKEN,
    PHONEME_END_TOKEN,
    MODALITY_TEXT_TOKEN,
    MODALITY_VOICE_TOKEN,
    MODALITY_INTERLEAVE_TOKEN,
    AUDIO_START_TOKEN,
    AUDIO_END_TOKEN,
    AUDIO_EMBED_TOKEN,
    *SEMANTIC_TOKENS,
]
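
# A minimal sketch (an assumption, not part of this module) of how ALL_SPECIAL_TOKENS
# could be registered on a base tokenizer before training; `base_model_path` and
# `model` are hypothetical placeholders:
#
#     tok = AutoTokenizer.from_pretrained(base_model_path)
#     tok.add_special_tokens({"additional_special_tokens": ALL_SPECIAL_TOKENS})
#     model.resize_token_embeddings(len(tok))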


class FishTokenizer:
    """Wraps a Hugging Face tokenizer and indexes the semantic audio tokens."""

    def __init__(self, model_path: str):
        self._tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Map each semantic code index (0..4095) to its vocabulary token id.
        self.semantic_id_to_token_id = {}
        vocab = self._tokenizer.get_vocab()
        valid_ids = []
        for code_idx, token in enumerate(SEMANTIC_TOKENS):
            if token in vocab:
                token_id = vocab[token]
                self.semantic_id_to_token_id[code_idx] = token_id
                valid_ids.append(token_id)

        if not valid_ids:
            logger.error(
                "CRITICAL ERROR: No semantic tokens found in vocab! Audio cannot be synthesized."
            )
            self.semantic_begin_id = 0
            self.semantic_end_id = 0
            # Dummy tensor to prevent a crash, though generation will fail.
            self.semantic_map_tensor = torch.zeros(len(SEMANTIC_TOKENS), dtype=torch.long)
        else:
            self.semantic_begin_id = min(valid_ids)
            self.semantic_end_id = max(valid_ids)
            # Lookup tensor that safely handles potential gaps in the token id range.
            self.semantic_map_tensor = torch.zeros(len(SEMANTIC_TOKENS), dtype=torch.long)
            for code_idx, token_id in self.semantic_id_to_token_id.items():
                self.semantic_map_tensor[code_idx] = token_id
            logger.info(
                f"Loaded tokenizer. Semantic range: {self.semantic_begin_id} -> {self.semantic_end_id}"
            )

    @property
    def vocab_size(self):
        # Note: HF `vocab_size` excludes added tokens; use len(self._tokenizer)
        # if the full embedding-table size is needed.
        return self._tokenizer.vocab_size

    @property
    def pad_token_id(self):
        return self._tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        return self._tokenizer.eos_token_id

    def get_token_id(self, token: str) -> int:
        return self._tokenizer.convert_tokens_to_ids(token)

    def encode(
        self, text: str, add_special_tokens: bool = False, **kwargs
    ) -> List[int]:
        # [FIX] Force Qwen/tiktoken-style backends to parse special tokens inline.
        sig = inspect.signature(self._tokenizer.encode)
        if "allowed_special" in sig.parameters and "allowed_special" not in kwargs:
            kwargs["allowed_special"] = "all"
        return self._tokenizer.encode(
            text, add_special_tokens=add_special_tokens, **kwargs
        )

    def decode(self, tokens: Union[List[int], int], **kwargs) -> str:
        return self._tokenizer.decode(tokens, **kwargs)

    def save_pretrained(self, path: str):
        self._tokenizer.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, path: str):
        return cls(path)

    def __getattr__(self, name):
        # Guard against infinite recursion if `_tokenizer` is not yet set
        # (e.g. when __init__ raised, or during unpickling).
        if name == "_tokenizer":
            raise AttributeError(name)
        # Delegate everything else to the wrapped tokenizer.
        return getattr(self._tokenizer, name)
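

if __name__ == "__main__":
    # Minimal usage sketch, assuming a local checkpoint; "path/to/tokenizer" is a
    # placeholder and must contain a vocab with the <|semantic:{i}|> tokens.
    tokenizer = FishTokenizer.from_pretrained("path/to/tokenizer")
    ids = tokenizer.encode(f"{IM_START_TOKEN}user\nhello{IM_END_TOKEN}")
    print(ids, tokenizer.decode(ids))

    # Map a batch of semantic code indices (0..4095) to vocabulary token ids.
    codes = torch.tensor([0, 1, 4095])
    print(tokenizer.semantic_map_tensor[codes])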