tokenizer.py

import inspect
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, List, Union

import torch
from transformers import AutoTokenizer

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerFast

logger = logging.getLogger(__name__)

# Special-token definitions
EOS_TOKEN = "<|endoftext|>"
PAD_TOKEN = "<|pad|>"
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PHONEME_START_TOKEN = "<|phoneme_start|>"
PHONEME_END_TOKEN = "<|phoneme_end|>"
MODALITY_TEXT_TOKEN = "<|text|>"
MODALITY_VOICE_TOKEN = "<|voice|>"
MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
AUDIO_START_TOKEN = "<|audio_start|>"
AUDIO_END_TOKEN = "<|audio_end|>"
AUDIO_EMBED_TOKEN = "<|audio_pad|>"

MODALITY_TOKENS = {
    "text": MODALITY_TEXT_TOKEN,
    "voice": MODALITY_VOICE_TOKEN,
    "interleave": MODALITY_INTERLEAVE_TOKEN,
}

# One token per semantic (audio codec) code: "<|semantic:0|>" .. "<|semantic:4095|>"
SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(4096)]

ALL_SPECIAL_TOKENS = [
    EOS_TOKEN, PAD_TOKEN, IM_START_TOKEN, IM_END_TOKEN,
    PHONEME_START_TOKEN, PHONEME_END_TOKEN, MODALITY_TEXT_TOKEN,
    MODALITY_VOICE_TOKEN, MODALITY_INTERLEAVE_TOKEN, AUDIO_START_TOKEN,
    AUDIO_END_TOKEN, AUDIO_EMBED_TOKEN, *SEMANTIC_TOKENS,
]
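
# Sketch (assumption): this file does not show how the tokens above enter the
# vocabulary. With a stock Hugging Face tokenizer they would typically be
# registered once while preparing the checkpoint, e.g.:
#
#   base = AutoTokenizer.from_pretrained("path/to/base-tokenizer")  # placeholder
#   base.add_special_tokens({"additional_special_tokens": ALL_SPECIAL_TOKENS})
#   base.save_pretrained("path/to/fish-tokenizer")                  # placeholder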


class FishTokenizer:
    """Thin wrapper around a Hugging Face tokenizer with semantic-token lookup."""

    def __init__(self, model_path: str):
        self._tokenizer: "PreTrainedTokenizerFast" = AutoTokenizer.from_pretrained(model_path)

        # Map semantic code index (0..4095) -> vocabulary token id.
        self.semantic_id_to_token_id = {}
        vocab = self._tokenizer.get_vocab()
        valid_ids = []
        for code_idx in range(4096):
            token = SEMANTIC_TOKEN_TEMPLATE.format(i=code_idx)
            if token in vocab:
                token_id = vocab[token]
                self.semantic_id_to_token_id[code_idx] = token_id
                valid_ids.append(token_id)

        if not valid_ids:
            logger.error(
                "CRITICAL: no semantic tokens found in the vocabulary; "
                "audio cannot be synthesized."
            )
            self.semantic_begin_id = 0
            self.semantic_end_id = 0
            # Dummy tensor so downstream code does not crash immediately,
            # though generation will still fail.
            self.semantic_map_tensor = torch.zeros(4096, dtype=torch.long)
        else:
            self.semantic_begin_id = min(valid_ids)
            self.semantic_end_id = max(valid_ids)
            # Lookup tensor, so indexing stays correct even if the token ids
            # are not contiguous in the vocabulary.
            self.semantic_map_tensor = torch.zeros(4096, dtype=torch.long)
            for k, v in self.semantic_id_to_token_id.items():
                self.semantic_map_tensor[k] = v

        logger.info(
            "Loaded tokenizer; semantic token id range: "
            f"{self.semantic_begin_id} -> {self.semantic_end_id}"
        )
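
    # Usage sketch: `semantic_map_tensor` gives a vectorized conversion from
    # codec code indices to vocabulary token ids, without assuming the ids are
    # contiguous. Assuming `codes` is a LongTensor of indices in [0, 4096):
    #
    #   codes = torch.tensor([12, 580, 4095])
    #   token_ids = tokenizer.semantic_map_tensor[codes]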

    @property
    def vocab_size(self):
        # Note: for Hugging Face tokenizers this excludes added tokens;
        # use len(self._tokenizer) for the full size including them.
        return self._tokenizer.vocab_size

    @property
    def pad_token_id(self):
        return self._tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        return self._tokenizer.eos_token_id

    def get_token_id(self, token: str) -> int:
        return self._tokenizer.convert_tokens_to_ids(token)

    def encode(self, text: str, add_special_tokens: bool = False, **kwargs) -> List[int]:
        # Some backends (e.g. tiktoken-based Qwen tokenizers) refuse to encode
        # special tokens appearing inline in the text unless explicitly
        # allowed, so opt in when the keyword is supported.
        sig = inspect.signature(self._tokenizer.encode)
        if "allowed_special" in sig.parameters and "allowed_special" not in kwargs:
            kwargs["allowed_special"] = "all"
        return self._tokenizer.encode(text, add_special_tokens=add_special_tokens, **kwargs)
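
    # Usage sketch: with allowed_special="all" in effect, special tokens that
    # appear literally in the input encode to single ids, e.g.:
    #
    #   ids = tokenizer.encode(f"{AUDIO_START_TOKEN}{AUDIO_END_TOKEN}")
    #   # -> two ids, not a byte-level split of the literal text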

    def decode(self, tokens: Union[List[int], int], **kwargs) -> str:
        return self._tokenizer.decode(tokens, **kwargs)

    def save_pretrained(self, path: str):
        self._tokenizer.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, path: str):
        return cls(path)

    def __getattr__(self, name):
        # Delegate anything not defined here to the wrapped tokenizer. The
        # guard avoids infinite recursion when `_tokenizer` itself is absent
        # (e.g. during unpickling, before __init__ has run).
        if name == "_tokenizer":
            raise AttributeError(name)
        return getattr(self._tokenizer, name)
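

if __name__ == "__main__":
    # Minimal usage sketch; "checkpoints/fish-tokenizer" is a placeholder for
    # the directory holding the tokenizer files.
    logging.basicConfig(level=logging.INFO)
    tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-tokenizer")
    ids = tokenizer.encode(f"{IM_START_TOKEN}user\nHello{IM_END_TOKEN}")
    print(ids)
    print(tokenizer.decode(ids))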