# conversation.py

from dataclasses import dataclass, field
from typing import Literal

import torch

from .tokenizer import MODALITY_TOKENS, FishTokenizer

CODEBOOK_PAD_TOKEN_ID = 0


@dataclass(kw_only=True)
class BasePart:
    pass


@dataclass(kw_only=True)
class VQPart(BasePart):
    codes: torch.Tensor


@dataclass(kw_only=True)
class TextPart(BasePart):
    text: str


@dataclass(kw_only=True)
class EncodedMessage:
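    """A message rendered to tensors.

    `tokens` and `labels` are aligned token-for-token; label positions that
    should not contribute to the loss are set to -100. `vq_parts` holds the
    raw multi-codebook code tensors, and the vq masks mark which token
    positions correspond to them.
    """
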
    tokens: torch.Tensor
    labels: torch.Tensor
    vq_mask_tokens: torch.Tensor | None = None
    vq_mask_labels: torch.Tensor | None = None
    vq_parts: list[torch.Tensor]
    vq_require_losses: torch.Tensor | None = None


@dataclass(kw_only=True)
class Message:
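    """One chat turn: a role plus a list of text and/or VQ (audio) parts."""
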
    role: Literal["system", "user", "assistant"]
    parts: list[VQPart | TextPart] = field(default_factory=list)
    add_im_start: bool = True
    add_im_end: bool = True
    cal_loss: bool = False
    modality: Literal["text", "voice", "interleave"] | None = None

    # By default, ignore the loss of the auto-generated im_start token
    ignore_im_start_loss: bool = True

    def encode(
        self: "Message",
        tokenizer: FishTokenizer,
    ) -> EncodedMessage:
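        """Render this message into an EncodedMessage.

        Wraps the parts in an <|im_start|>{role} ... <|im_end|> frame when
        enabled, tokenizes text parts, and maps the first codebook row of
        each VQ part to semantic token ids so it can sit in the text stream.
        """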
        all_tokens = []
        all_labels = []

        # Multi-modal tokens
        vq_parts = []
        vq_masks = []

        parts = self.parts.copy()

        if self.add_im_start:
            modality_token = MODALITY_TOKENS[self.modality] if self.modality else ""
            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}"))

        if self.add_im_end:
            parts.append(TextPart(text="<|im_end|>"))

        for part in parts:
            if isinstance(part, TextPart):
                tokens = torch.tensor(
                    tokenizer.encode(part.text),
                    dtype=torch.int,
                )
            elif isinstance(part, VQPart):
                # Represent the VQ part in the text stream via its first
                # codebook row, mapped to semantic token ids.
                curr_codes = part.codes.clone()
                tokens = torch.tensor(
                    [
                        tokenizer.semantic_id_to_token_id[i.item()]
                        for i in curr_codes[0].int()
                    ],
                    dtype=torch.int,
                )
                vq_parts.append(curr_codes)
            else:
                raise ValueError(f"Unsupported part type: {type(part)}")

            all_tokens.append(tokens)
            if isinstance(part, VQPart):
                vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
            else:
                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))

            if self.cal_loss:
                all_labels.append(tokens.clone())
            else:
                # -100 is the conventional ignore_index for cross-entropy
                all_labels.append(torch.full_like(tokens, -100))

        tokens = torch.cat(all_tokens, dim=0)
        labels = torch.cat(all_labels, dim=0)
        vq_masks = torch.cat(vq_masks, dim=0)

        assert tokens.shape == labels.shape == vq_masks.shape

        if self.ignore_im_start_loss and self.add_im_start:
            labels[: len(all_tokens[0])] = -100

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            vq_mask_tokens=vq_masks,
            vq_mask_labels=vq_masks,
        )


@dataclass
class Conversation:
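    """An ordered list of messages that encodes to a single token sequence."""
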
    messages: list[Message]

    # The explicit __init__ overrides the dataclass-generated one so that
    # messages can default to an empty list.
    def __init__(self: "Conversation", messages: list[Message] | None = None):
        self.messages = messages or []

    def encode(
        self: "Conversation",
        tokenizer: FishTokenizer,
        add_shift: bool = True,
        ignore_loss_tokens: list[str] = [],
    ) -> EncodedMessage:
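        """Concatenate every message into one tokens/labels pair.

        With add_shift=True the outputs are shifted by one position so that
        labels[t] is the prediction target for tokens[t] (next-token setup).
        Any token listed in ignore_loss_tokens is masked to -100 in labels.
        """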
        # Build the input_ids and labels
        tokens = []
        labels = []
        vq_parts = []
        vq_mask_tokens = []
        vq_mask_labels = []
        vq_require_losses = []
        ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]

        for message in self.messages:
            encoded = message.encode(
                tokenizer,
            )
            tokens.append(encoded.tokens)
            labels.append(encoded.labels)
            vq_parts.extend(encoded.vq_parts)
            vq_mask_tokens.append(encoded.vq_mask_tokens)
            vq_mask_labels.append(encoded.vq_mask_labels)
            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))

        tokens = torch.cat(tokens, dim=0)
        labels = torch.cat(labels, dim=0)
        vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
        vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

        if add_shift:
            tokens = tokens[:-1]
            labels = labels[1:]
            vq_mask_tokens = vq_mask_tokens[:-1]
            vq_mask_labels = vq_mask_labels[1:]

        for i in ignore_loss_token_ids:
            assert i != -100 and i is not None
            labels[labels == i] = -100

        assert tokens.dtype in [
            torch.int,
            torch.long,
        ], f"Invalid dtype: {tokens.dtype}, conv: {self}"

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            vq_mask_tokens=vq_mask_tokens,
            vq_mask_labels=vq_mask_labels,
            vq_require_losses=vq_require_losses,
        )

    def encode_for_inference(
        self: "Conversation",
        tokenizer: FishTokenizer,
        num_codebooks: int,
    ) -> torch.Tensor:
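        """Encode for generation as a (num_codebooks + 1, seq_len) tensor.

        Row 0 carries the text/semantic token ids; rows 1..num_codebooks
        carry the raw codebook values at VQ positions (zeros elsewhere).
        """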
        # self.visualize(tokenizer)

        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens
        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
        values[0] = tokens

        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
            return values

        vq_parts = encoded.vq_parts
        vq_parts = [part.to(values.device) for part in vq_parts]
        vq_parts = torch.cat(vq_parts, dim=1)
        values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
        values[1:, encoded.vq_mask_tokens] = vq_parts

        return values

    def visualize(
        self: "Conversation",
        tokenizer: FishTokenizer,
        ignore_loss_tokens: list[str] = [],
    ):
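        """Print the decoded tokens color-coded by loss masking: green for
        positions ignored in the loss (-100), blue for supervised positions,
        with alternating shades to separate adjacent tokens.
        """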
        encoded = self.encode(
            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
        )

        # Colors for alternating tokens
        colors = {
            "blue": "\033[94m",  # Light blue
            "cyan": "\033[96m",  # Cyan
            "green": "\033[92m",  # Light green
            "dark_green": "\033[32m",  # Dark green
        }

        blue_idx = 0
        green_idx = 0

        def print_in_blue(x):
            nonlocal blue_idx
            color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
            print(f"{color}{x}\033[0m", end="")
            blue_idx += 1

        def print_in_green(x):
            nonlocal green_idx
            color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
            print(f"{color}{x}\033[0m", end="")
            green_idx += 1

        for tok, lab in zip(encoded.tokens, encoded.labels):
            val = tokenizer.decode([tok])
            if lab == -100:
                print_in_green(val)
            else:
                print_in_blue(val)

        print()

    def append(self: "Conversation", message: Message):
        self.messages.append(message)


if __name__ == "__main__":
    message0 = Message(
        role="user",
        parts=[
            TextPart(text="Hello, how are you?"),
            # Codebook ids are integers; use an int dtype so the codes can
            # be written into the integer prompt tensor at inference time.
            VQPart(codes=torch.zeros((4, 10), dtype=torch.int)),
        ],
        cal_loss=False,
    )

    message1 = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,
    )

    conversation = Conversation([message0, message1])
    tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
    conversation.visualize(tokenizer)

    encoded = conversation.encode(tokenizer)
    print(encoded)
    print(tokenizer.batch_decode(encoded.tokens))
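
    # A minimal sketch of inference-time encoding; num_codebooks=4 is an
    # assumption matching the (4, 10) VQPart above -- use the value that
    # matches your model checkpoint.
    prompt = conversation.encode_for_inference(tokenizer, num_codebooks=4)
    print(prompt.shape)  # torch.Size([num_codebooks + 1, seq_len])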