conversation.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from dataclasses import dataclass, field
  2. from typing import Literal
  3. import torch
  4. from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
# Special tokens used by the chat/audio template.
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
# Placeholder token emitted once per frame of VQ codes.
SEMANTIC_TOKEN = "<|semantic|>"
# Placeholder token emitted once per mel frame.
MEL_TOKEN = "<|mel|>"
PHONEME_START_TOKEN = "<|phoneme_start|>"
PHONEME_END_TOKEN = "<|phoneme_end|>"
ALL_SPECIAL_TOKENS = [
    IM_START_TOKEN,
    IM_END_TOKEN,
    SEMANTIC_TOKEN,
    MEL_TOKEN,
    PHONEME_START_TOKEN,
    PHONEME_END_TOKEN,
]
# Codebook id reserved for padding (codes are shifted by +1 at encode time
# so that 0 never collides with a real code).
CODEBOOK_PAD_TOKEN_ID = 0
class FishTokenizerConfig(PretrainedConfig):
    """Config that carries codebook settings for the Fish tokenizer.

    NOTE(review): these are class-level attributes, not values set via
    ``__init__`` — every instance shares the defaults unless they are
    overwritten after construction; confirm this matches how the config
    is serialized/loaded elsewhere.
    """

    # Whether all codebooks index into one shared embedding table.
    share_codebook_embeddings: bool = True
    codebook_size: int = 1024
    num_codebooks: int = 8
  24. class FishTokenizerFast(PreTrainedTokenizerFast):
  25. def __init__(self, *args, **kwargs):
  26. super().__init__(*args, **kwargs)
  27. self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
  28. self.codebook_size = kwargs.pop("codebook_size", 1024)
  29. self.num_codebooks = kwargs.pop("num_codebooks", 8)
# Let AutoTokenizer.from_pretrained resolve FishTokenizerConfig to our fast tokenizer.
AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
@dataclass(kw_only=True)
class BasePart:
    """Common base for all message parts (text, VQ codes, mel frames)."""

    pass
@dataclass(kw_only=True)
class VQPart(BasePart):
    """A span of VQ codebook codes."""

    # codes: (num_codebooks, num_frames) — assumed from the encode() indexing;
    # TODO(review) confirm against the audio codec producing these.
    codes: torch.Tensor
@dataclass(kw_only=True)
class TextPart(BasePart):
    """A span of plain text."""

    text: str
@dataclass(kw_only=True)
class MelPart(BasePart):
    """A span of mel-spectrogram frames."""

    # mels: second dimension is the frame count (encode() reads mels.shape[1]).
    mels: torch.Tensor
@dataclass(kw_only=True)
class EncodedMessage:
    """Result of encoding a Message or Conversation."""

    # Flat token-id sequence, shape (seq_len,).
    tokens: torch.Tensor
    # Training labels aligned with tokens; -100 where loss is ignored.
    labels: torch.Tensor
    # One tensor per VQPart encountered, in order.
    vq_parts: list[torch.Tensor]
    # One tensor per MelPart encountered, in order.
    mel_parts: list[torch.Tensor]
    # Per-VQ-part loss mask; only populated by Conversation.encode.
    vq_require_losses: torch.Tensor | None = None
  50. @dataclass(kw_only=True)
  51. class Message:
  52. role: Literal["system", "user", "assistant"]
  53. parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
  54. add_im_start: bool = True
  55. add_im_end: bool = True
  56. cal_loss: bool = False
  57. # By default, ignore the loss of the auto-generated im_start token
  58. ignore_im_start_loss: bool = True
  59. def encode(
  60. self: "Message",
  61. tokenizer: AutoTokenizer,
  62. ) -> EncodedMessage:
  63. all_tokens = []
  64. all_labels = []
  65. # Multi-modal tokens
  66. vq_parts = []
  67. mel_parts = []
  68. semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
  69. [SEMANTIC_TOKEN, MEL_TOKEN]
  70. )
  71. parts = self.parts.copy()
  72. if self.add_im_start:
  73. parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n"))
  74. if self.add_im_end:
  75. parts.append(TextPart(text="<|im_end|>"))
  76. for part in parts:
  77. if isinstance(part, TextPart):
  78. tokens = tokenizer.encode(
  79. part.text,
  80. add_special_tokens=False,
  81. truncation=False,
  82. return_tensors="pt",
  83. ).int()[0]
  84. elif isinstance(part, VQPart):
  85. tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
  86. codes = part.codes.clone() + 1
  87. if getattr(tokenizer, "share_codebook_embeddings", True) is False:
  88. for i in range(len(codes)):
  89. codes[i] += tokenizer.codebook_size * i
  90. vq_parts.append(codes)
  91. elif isinstance(part, MelPart):
  92. tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
  93. mel_parts.append(part.mels)
  94. else:
  95. raise ValueError(f"Unsupported part type: {type(part)}")
  96. all_tokens.append(tokens)
  97. if self.cal_loss:
  98. all_labels.append(tokens.clone())
  99. else:
  100. all_labels.append(torch.full_like(tokens, -100))
  101. tokens = torch.cat(all_tokens, dim=0)
  102. labels = torch.cat(all_labels, dim=0)
  103. assert tokens.shape == labels.shape
  104. if self.ignore_im_start_loss and self.add_im_start:
  105. labels[: len(all_tokens[0])] = -100
  106. return EncodedMessage(
  107. tokens=tokens,
  108. labels=labels,
  109. vq_parts=vq_parts,
  110. mel_parts=mel_parts,
  111. )
  112. @dataclass
  113. class Conversation:
  114. messages: list[Message]
  115. def encode(
  116. self: "Conversation",
  117. tokenizer: AutoTokenizer,
  118. add_shift: bool = True,
  119. ) -> EncodedMessage:
  120. # Build the input_ids and labels
  121. tokens = []
  122. labels = []
  123. vq_parts = []
  124. mel_parts = []
  125. vq_require_losses = []
  126. for message in self.messages:
  127. encoded = message.encode(
  128. tokenizer,
  129. )
  130. tokens.append(encoded.tokens)
  131. labels.append(encoded.labels)
  132. vq_parts.extend(encoded.vq_parts)
  133. mel_parts.extend(encoded.mel_parts)
  134. vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
  135. tokens = torch.cat(tokens, dim=0)
  136. labels = torch.cat(labels, dim=0)
  137. vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
  138. if add_shift:
  139. tokens = tokens[:-1]
  140. labels = labels[1:]
  141. assert tokens.dtype in [
  142. torch.int,
  143. torch.long,
  144. ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
  145. return EncodedMessage(
  146. tokens=tokens,
  147. labels=labels,
  148. vq_parts=vq_parts,
  149. mel_parts=mel_parts,
  150. vq_require_losses=vq_require_losses,
  151. )
  152. def encode_for_inference(
  153. self: "Conversation",
  154. tokenizer: AutoTokenizer,
  155. num_codebooks: int,
  156. ) -> EncodedMessage:
  157. encoded = self.encode(tokenizer, add_shift=False)
  158. tokens = encoded.tokens
  159. values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
  160. values[0] = tokens
  161. if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
  162. return values
  163. semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
  164. [SEMANTIC_TOKEN, MEL_TOKEN]
  165. )
  166. vq_parts = encoded.vq_parts
  167. vq_parts = torch.cat(vq_parts, dim=1)
  168. values[1:, tokens == semantic_id] = vq_parts
  169. return values
  170. def visualize(self: "Conversation", tokenizer: AutoTokenizer):
  171. encoded = self.encode(tokenizer, add_shift=False)
  172. print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
  173. print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
  174. for tok, lab in zip(encoded.tokens, encoded.labels):
  175. val = tokenizer.decode(tok, skip_special_tokens=False)
  176. if val == "\n":
  177. val = "\\n\n"
  178. if lab == -100:
  179. print_in_green(val)
  180. else:
  181. print_in_blue(val)
  182. print()
if __name__ == "__main__":
    # Smoke-test / demo: build a two-turn conversation, visualize the loss
    # mask, then dump the encoded result.
    message0 = Message(
        role="user",
        parts=[
            TextPart(text="Hello, how are you?"),
            # NOTE(review): zeros((4, 10)) are float codes for 4 codebooks x
            # 10 frames; real VQ codes are presumably integer — confirm.
            VQPart(codes=torch.zeros((4, 10))),
        ],
        cal_loss=False,
    )
    message1 = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,
    )
    conversation = Conversation([message0, message1])
    # Requires a locally downloaded checkpoint at this hard-coded path.
    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
    conversation.visualize(tokenizer)

    encoded = conversation.encode(tokenizer)
    print(encoded)
    print(tokenizer.batch_decode(encoded.tokens))