content_sequence.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. from dataclasses import dataclass, field
  2. from typing import List, Literal, Union
  3. import numpy as np
  4. import torch
  5. from fish_speech.tokenizer import (
  6. IM_END_TOKEN,
  7. MODALITY_TOKENS,
  8. FishTokenizer,
  9. )
  10. def restore_ndarray(obj, to_tensor: bool = False):
  11. if isinstance(obj, dict) and "__ndarray__" in obj:
  12. obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])
  13. if to_tensor and isinstance(obj, np.ndarray):
  14. obj = torch.from_numpy(obj.copy())
  15. return obj
@dataclass
class BasePart:
    """Common base for all content parts in a sequence."""

    # Discriminator tag; subclasses re-assert their own value in __post_init__.
    type: Literal["text", "vq", "audio"] | None = None
    # Whether tokens produced from this part contribute to the training loss.
    cal_loss: bool = False
@dataclass(kw_only=True)
class VQPart(BasePart):
    """A span of VQ (semantic) codes."""

    type = "vq"
    # Code tensor; may arrive as a dict-serialized ndarray (see restore_ndarray).
    # NOTE(review): downstream indexing (codes[0]) suggests shape (codebooks, T) — confirm.
    codes: torch.Tensor

    def __post_init__(self: "VQPart"):
        # Force the discriminator and revive serialized codes into a tensor.
        self.type = "vq"
        self.codes = restore_ndarray(self.codes, to_tensor=True)
@dataclass(kw_only=True)
class TextPart(BasePart):
    """A span of text, given either as a raw string or as token ids."""

    type = "text"
    # Raw text; tokenized lazily during encode() when tokens is None.
    text: str | None = None
    # Pre-tokenized ids; used as-is when provided.
    tokens: list[int] | None = None

    def __post_init__(self: "TextPart"):
        self.type = "text"
        # At least one of the two representations must be supplied.
        if self.text is None and self.tokens is None:
            raise ValueError("Either text or tokens must be provided")
  36. @dataclass(kw_only=True)
  37. class AudioPart(BasePart):
  38. type = "audio"
  39. features: torch.Tensor
  40. def __post_init__(self: "AudioPart"):
  41. self.type = "audio"
  42. self.features = restore_ndarray(self.features, to_tensor=True)
@dataclass(kw_only=True)
class EncodedMessage:
    """Model-ready tensors produced by ContentSequence.encode().

    Fields without defaults may follow defaulted ones only because
    kw_only=True lifts the usual dataclass ordering restriction.
    """

    # (T,) input token ids.
    tokens: torch.Tensor
    # (T,) target ids; positions holding -100 are excluded from the loss.
    labels: torch.Tensor
    # (T,) bool mask marking input positions occupied by VQ (semantic) tokens.
    vq_mask_tokens: torch.Tensor | None = None
    # (T,) bool mask marking label positions occupied by VQ (semantic) tokens.
    vq_mask_labels: torch.Tensor | None = None
    # One code tensor per VQPart, in sequence order.
    vq_parts: list[torch.Tensor]
    # (len(vq_parts),) bool flags: whether each VQ part contributes to the loss.
    vq_require_losses: torch.Tensor | None = None
    # One feature tensor per AudioPart, in sequence order.
    audio_parts: list[torch.Tensor]
    # (T,) bool mask marking positions backed by audio features.
    audio_masks: torch.Tensor | None = None
    # Arbitrary metadata carried through from the sequence.
    metadata: dict | None = None
  54. @dataclass
  55. class ContentSequence:
  56. """
  57. Flexible sequence of content parts that supports interleaved multimodal format.
  58. Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|>
  59. """
  60. parts: list[BasePart] = field(default_factory=list)
  61. modality: Literal["text", "voice", "interleave"] | None = None
  62. metadata: dict | None = None
  63. def __init__(
  64. self: "ContentSequence",
  65. parts: list[BasePart | dict] | None = None,
  66. modality: Literal["text", "voice", "interleave"] | None = None,
  67. metadata: dict | None = None,
  68. ):
  69. self.modality = modality
  70. self.metadata = metadata or {}
  71. fixed_parts = []
  72. for part in parts or []:
  73. if isinstance(part, dict):
  74. if part["type"] == "vq":
  75. part = VQPart(**part)
  76. elif part["type"] == "audio":
  77. part = AudioPart(**part)
  78. elif part["type"] == "text":
  79. part = TextPart(**part)
  80. else:
  81. raise ValueError(f"Unsupported part type: {part['type']}")
  82. fixed_parts.append(part)
  83. self.parts = fixed_parts
  84. # If modality is specified, add it at the beginning if it's not already there
  85. if self.modality and not (
  86. len(self.parts) > 0
  87. and isinstance(self.parts[0], dict) is False
  88. and isinstance(self.parts[0], TextPart)
  89. and self.parts[0].text is not None
  90. and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality])
  91. ):
  92. modality_token = MODALITY_TOKENS[self.modality]
  93. self.parts.insert(0, TextPart(text=modality_token))
  94. def append(
  95. self: "ContentSequence",
  96. part_or_parts: Union[BasePart, List[BasePart]],
  97. add_end: bool = False,
  98. speaker: Union[str, int] | None = None,
  99. ):
  100. """
  101. Append a part or list of parts to the sequence.
  102. Args:
  103. part_or_parts: A single part or list of parts to add
  104. add_end: Whether to add the IM_END_TOKEN after these parts
  105. speaker: Optional speaker identifier (name or ID) to add before the parts
  106. """
  107. # Convert single part to list
  108. parts_to_add = (
  109. [part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts
  110. )
  111. # Add speaker token if specified
  112. if speaker is not None:
  113. speaker_token = f"<|speaker:{speaker}|>"
  114. self.parts.append(TextPart(text=speaker_token))
  115. # Add all the parts
  116. self.parts.extend(parts_to_add)
  117. # Add end token if requested
  118. if add_end:
  119. self.parts.append(
  120. TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss)
  121. )
  122. def encode(
  123. self: "ContentSequence",
  124. tokenizer: FishTokenizer,
  125. add_shift: bool = True,
  126. ignore_loss_tokens: list[str] = [],
  127. ) -> EncodedMessage:
  128. """
  129. Encode the sequence parts into tokens for the model.
  130. Args:
  131. tokenizer: The tokenizer to use
  132. add_shift: Whether to shift tokens for next-token prediction
  133. ignore_loss_tokens: List of token strings to ignore when calculating loss
  134. Returns:
  135. EncodedMessage with tensors ready for the model
  136. """
  137. all_tokens = []
  138. all_labels = []
  139. # Multi-modal elements
  140. vq_parts = []
  141. vq_masks = []
  142. vq_require_losses = []
  143. audio_parts = []
  144. audio_masks = []
  145. # Optimization: Batch conversion for ignore tokens
  146. ignore_loss_token_ids = []
  147. if ignore_loss_tokens:
  148. # Use the wrapper method which uses convert_tokens_to_ids
  149. ignore_loss_token_ids = [
  150. tokenizer.get_token_id(i) for i in ignore_loss_tokens
  151. ]
  152. for part in self.parts:
  153. if isinstance(part, TextPart):
  154. if part.tokens is None:
  155. assert part.text is not None
  156. # Optimization: Explicitly disable special tokens (BOS/EOS)
  157. # because we are constructing the sequence manually
  158. tokens = tokenizer.encode(part.text, add_special_tokens=False)
  159. else:
  160. tokens = part.tokens
  161. tokens = torch.tensor(tokens, dtype=torch.long)
  162. elif isinstance(part, VQPart):
  163. # Critical Optimization: Vectorized mapping
  164. # Instead of loop lookup: [tokenizer.semantic_id_to_token_id[i] for i in codes]
  165. # We use arithmetic offset: code + semantic_begin_id
  166. # This assumes semantic tokens are contiguous in the vocab (DualAR requirement)
  167. curr_codes = part.codes.clone().to(torch.int)
  168. # Use int64 (long) for token IDs to avoid overflow or type mismatch in embedding
  169. tokens = (curr_codes[0] + tokenizer.semantic_begin_id).to(torch.long)
  170. vq_parts.append(curr_codes)
  171. vq_require_losses.append(part.cal_loss)
  172. else:
  173. raise ValueError(f"Unsupported part type: {type(part)}")
  174. all_tokens.append(tokens)
  175. # Set masks for different part types
  176. if isinstance(part, VQPart):
  177. vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
  178. audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
  179. elif isinstance(part, AudioPart):
  180. vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
  181. audio_mask = torch.ones_like(tokens, dtype=torch.bool)
  182. audio_mask[0] = False # Skip start token
  183. audio_mask[-1] = False # Skip end token
  184. audio_masks.append(audio_mask)
  185. else:
  186. vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
  187. audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
  188. # Set labels based on whether we want to calculate loss for this part
  189. if part.cal_loss and not isinstance(part, AudioPart):
  190. all_labels.append(tokens.clone())
  191. else:
  192. all_labels.append(torch.full_like(tokens, -100))
  193. # Concatenate all tensors
  194. if not all_tokens:
  195. # Handle empty case safely
  196. tokens = torch.empty(0, dtype=torch.long)
  197. labels = torch.empty(0, dtype=torch.long)
  198. vq_masks = torch.empty(0, dtype=torch.bool)
  199. audio_masks = torch.empty(0, dtype=torch.bool)
  200. else:
  201. tokens = torch.cat(all_tokens, dim=0)
  202. labels = torch.cat(all_labels, dim=0)
  203. vq_masks = torch.cat(vq_masks, dim=0)
  204. audio_masks = torch.cat(audio_masks, dim=0)
  205. vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
  206. # Apply shift if needed for next-token prediction
  207. vq_mask_tokens = vq_masks
  208. vq_mask_labels = vq_masks
  209. if add_shift and len(tokens) > 0:
  210. tokens = tokens[:-1]
  211. labels = labels[1:]
  212. vq_masks = vq_masks[:-1]
  213. vq_mask_tokens = vq_mask_tokens[:-1]
  214. vq_mask_labels = vq_mask_labels[1:]
  215. audio_masks = audio_masks[:-1]
  216. # Ignore specified tokens
  217. for i in ignore_loss_token_ids:
  218. if i is not None:
  219. labels[labels == i] = -100
  220. return EncodedMessage(
  221. tokens=tokens,
  222. labels=labels,
  223. vq_parts=vq_parts,
  224. vq_mask_tokens=vq_mask_tokens,
  225. vq_mask_labels=vq_mask_labels,
  226. vq_require_losses=vq_require_losses,
  227. audio_parts=audio_parts,
  228. audio_masks=audio_masks,
  229. metadata=self.metadata,
  230. )
    def encode_for_inference(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        num_codebooks: int,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """
        Encode the sequence into the stacked codebook layout used at inference.

        Args:
            tokenizer: The tokenizer to use.
            num_codebooks: Number of VQ codebooks; output has num_codebooks + 1 rows.

        Returns:
            (values, audio_masks, audio_parts): values is (num_codebooks + 1, T)
            with row 0 holding the main token ids and rows 1.. holding the raw
            VQ codes at semantic positions; audio_masks and audio_parts are
            None when the sequence contains no VQ or audio content.
        """
        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens
        # Use int32 for prompt cache to save memory, convert to model dtype later if needed
        # Or keep as input_ids (long)
        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.long)
        values[0] = tokens
        # Fast path: no multimodal content at all.
        if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
            encoded.audio_parts is None or len(encoded.audio_parts) == 0
        ):
            return values, None, None
        audio_parts = None
        audio_masks = None
        if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
            vq_parts = encoded.vq_parts
            # vq_parts is a list of tensors from different VQPart segments;
            # concatenated, they line up with the True positions of
            # encoded.vq_mask_tokens below.
            if len(vq_parts) > 1:
                all_vq_codes = torch.cat(
                    vq_parts, dim=1
                )  # Shape: (C, Total_Semantic_Tokens)
            else:
                all_vq_codes = vq_parts[0]
            # Values[0] is already the Main Token ID (Semantic Begin + Code)
            # Values[1:] should be the codes themselves
            values[1:, encoded.vq_mask_tokens] = all_vq_codes.to(dtype=torch.long)
        if encoded.audio_parts is not None and len(encoded.audio_parts) > 0:
            # NOTE(review): concatenation on dim 0 assumes per-part features
            # stack along the time axis — confirm feature layout.
            audio_parts = torch.cat(encoded.audio_parts, dim=0)
            audio_masks = encoded.audio_masks[None, :]
        return values, audio_masks, audio_parts
  268. def visualize(
  269. self: "ContentSequence",
  270. tokenizer: FishTokenizer,
  271. ignore_loss_tokens: list[str] = [],
  272. merge_semantic_tokens: bool = False,
  273. ):
  274. """
  275. Visualize the encoded sequence with color-coded tokens.
  276. Blue/cyan tokens contribute to loss, green tokens do not.
  277. """
  278. encoded = self.encode(
  279. tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
  280. )
  281. # Colors for alternating tokens
  282. colors = {
  283. "blue": "\033[94m", # Light blue
  284. "cyan": "\033[96m", # Cyan
  285. "green": "\033[92m", # Light green
  286. "dark_green": "\033[32m", # Dark green
  287. }
  288. blue_idx = 0
  289. green_idx = 0
  290. def print_in_blue(x):
  291. nonlocal blue_idx
  292. color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
  293. print(f"{color}{x}\033[0m", end="")
  294. blue_idx += 1
  295. def print_in_green(x):
  296. nonlocal green_idx
  297. color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
  298. print(f"{color}{x}\033[0m", end="")
  299. green_idx += 1
  300. def print_semantic_token(x, count):
  301. val = f"[<|semantic|>x{count}]"
  302. if x == -100:
  303. print_in_green(val)
  304. else:
  305. print_in_blue(val)
  306. count_semantic_tokens = 0
  307. semantic_label = None
  308. for tok, lab in zip(encoded.tokens, encoded.labels):
  309. token_id = int(tok.item())
  310. if merge_semantic_tokens:
  311. if (
  312. tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id
  313. and (semantic_label is None or semantic_label == lab)
  314. ):
  315. count_semantic_tokens += 1
  316. semantic_label = lab
  317. continue
  318. elif count_semantic_tokens > 0:
  319. print_semantic_token(semantic_label, count_semantic_tokens)
  320. count_semantic_tokens = 0
  321. semantic_label = None
  322. # Use HF decode
  323. val = tokenizer.decode([token_id])
  324. # Simple fallback for visualization if decode returns empty or weird stuff for special tokens
  325. if not val:
  326. val = f"<{token_id}>"
  327. if lab == -100:
  328. print_in_green(val)
  329. else:
  330. print_in_blue(val)
  331. if merge_semantic_tokens and count_semantic_tokens > 0:
  332. print_semantic_token(semantic_label, count_semantic_tokens)
  333. print()