# content_sequence.py
from dataclasses import dataclass, field
from typing import List, Literal, Union

import numpy as np
import torch

from fish_speech.tokenizer import (
    IM_END_TOKEN,
    MODALITY_TOKENS,
    FishTokenizer,
)
  10. def restore_ndarray(obj, to_tensor: bool = False):
  11. if isinstance(obj, dict) and "__ndarray__" in obj:
  12. obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])
  13. if to_tensor and isinstance(obj, np.ndarray):
  14. obj = torch.from_numpy(obj.copy())
  15. return obj
@dataclass
class BasePart:
    """Common base for all content parts of a ContentSequence."""

    # Discriminator tag; concrete subclasses force the correct value
    # in their __post_init__.
    type: Literal["text", "vq", "audio"] | None = None
    # True if this part's tokens should receive labels (i.e. contribute
    # to the training loss) during encoding.
    cal_loss: bool = False
@dataclass(kw_only=True)
class VQPart(BasePart):
    """A span of VQ (semantic) codes."""

    type = "vq"
    # Quantized codes; may arrive as a serialized ndarray dict and is
    # normalized to a torch.Tensor in __post_init__.
    # NOTE(review): downstream encoding indexes codes[0] for the main token
    # row, which suggests shape (codebooks, T) — confirm with producers.
    codes: torch.Tensor

    def __post_init__(self: "VQPart"):
        # Re-assert the discriminator and decode any serialized payload.
        self.type = "vq"
        self.codes = restore_ndarray(self.codes, to_tensor=True)
@dataclass(kw_only=True)
class TextPart(BasePart):
    """A text span, given either as a raw string or as pre-tokenized ids."""

    type = "text"
    # Raw text; tokenized at encode time when `tokens` is absent.
    text: str | None = None
    # Pre-computed token ids; used verbatim when present.
    tokens: list[int] | None = None

    def __post_init__(self: "TextPart"):
        self.type = "text"
        # At least one of the two representations must be supplied.
        if self.text is None and self.tokens is None:
            raise ValueError("Either text or tokens must be provided")
@dataclass(kw_only=True)
class AudioPart(BasePart):
    """A span of continuous audio features."""

    type = "audio"
    # Audio feature tensor; may arrive as a serialized ndarray dict and is
    # normalized to a torch.Tensor in __post_init__.
    features: torch.Tensor

    def __post_init__(self: "AudioPart"):
        # Re-assert the discriminator and decode any serialized payload.
        self.type = "audio"
        self.features = restore_ndarray(self.features, to_tensor=True)
@dataclass(kw_only=True)
class EncodedMessage:
    """Output of ContentSequence.encode: token/label tensors plus masks.

    kw_only=True lets the non-default fields (vq_parts, audio_parts) legally
    follow defaulted ones.
    """

    # Input token ids (1D, long).
    tokens: torch.Tensor
    # Per-position labels; -100 marks positions excluded from the loss.
    labels: torch.Tensor
    # Bool masks over the (possibly shifted) token / label sequences marking
    # VQ (semantic) positions.
    vq_mask_tokens: torch.Tensor | None = None
    vq_mask_labels: torch.Tensor | None = None
    # Raw VQ code tensors, one entry per VQPart.
    vq_parts: list[torch.Tensor]
    # Per-VQPart bool flags: whether that part contributes to the loss.
    vq_require_losses: torch.Tensor | None = None
    # Raw audio feature tensors, one entry per AudioPart.
    audio_parts: list[torch.Tensor]
    # Bool mask over token positions that carry audio features.
    audio_masks: torch.Tensor | None = None
    # Metadata propagated from the ContentSequence.
    metadata: dict | None = None
  54. @dataclass
  55. class ContentSequence:
  56. """
  57. Flexible sequence of content parts that supports interleaved multimodal format.
  58. Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|>
  59. """
  60. parts: list[BasePart] = field(default_factory=list)
  61. modality: Literal["text", "voice", "interleave"] | None = None
  62. metadata: dict | None = None
  63. def __init__(
  64. self: "ContentSequence",
  65. parts: list[BasePart | dict] | None = None,
  66. modality: Literal["text", "voice", "interleave"] | None = None,
  67. metadata: dict | None = None,
  68. ):
  69. self.modality = modality
  70. self.metadata = metadata or {}
  71. fixed_parts = []
  72. for part in parts or []:
  73. if isinstance(part, dict):
  74. if part["type"] == "vq":
  75. part = VQPart(**part)
  76. elif part["type"] == "audio":
  77. part = AudioPart(**part)
  78. elif part["type"] == "text":
  79. part = TextPart(**part)
  80. else:
  81. raise ValueError(f"Unsupported part type: {part['type']}")
  82. fixed_parts.append(part)
  83. self.parts = fixed_parts
  84. # If modality is specified, add it at the beginning if it's not already there
  85. if self.modality and not (
  86. len(self.parts) > 0
  87. and isinstance(self.parts[0], dict) is False
  88. and isinstance(self.parts[0], TextPart)
  89. and self.parts[0].text is not None
  90. and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality])
  91. ):
  92. modality_token = MODALITY_TOKENS[self.modality]
  93. self.parts.insert(0, TextPart(text=modality_token))
  94. def append(
  95. self: "ContentSequence",
  96. part_or_parts: Union[BasePart, List[BasePart]],
  97. add_end: bool = False,
  98. speaker: Union[str, int] | None = None,
  99. ):
  100. """
  101. Append a part or list of parts to the sequence.
  102. Args:
  103. part_or_parts: A single part or list of parts to add
  104. add_end: Whether to add the IM_END_TOKEN after these parts
  105. speaker: Optional speaker identifier (name or ID) to add before the parts
  106. """
  107. # Convert single part to list
  108. parts_to_add = (
  109. [part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts
  110. )
  111. # Add speaker token if specified
  112. if speaker is not None:
  113. speaker_token = f"<|speaker:{speaker}|>"
  114. self.parts.append(TextPart(text=speaker_token))
  115. # Add all the parts
  116. self.parts.extend(parts_to_add)
  117. # Add end token if requested
  118. if add_end:
  119. self.parts.append(
  120. TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss)
  121. )
  122. def encode(
  123. self: "ContentSequence",
  124. tokenizer: FishTokenizer,
  125. add_shift: bool = True,
  126. ignore_loss_tokens: list[str] = [],
  127. ) -> EncodedMessage:
  128. """
  129. Encode the sequence parts into tokens for the model.
  130. Args:
  131. tokenizer: The tokenizer to use
  132. add_shift: Whether to shift tokens for next-token prediction
  133. ignore_loss_tokens: List of token strings to ignore when calculating loss
  134. Returns:
  135. EncodedMessage with tensors ready for the model
  136. """
  137. all_tokens = []
  138. all_labels = []
  139. # Multi-modal elements
  140. vq_parts = []
  141. vq_masks = []
  142. vq_require_losses = []
  143. audio_parts = []
  144. audio_masks = []
  145. # Optimization: Batch conversion for ignore tokens
  146. ignore_loss_token_ids = []
  147. if ignore_loss_tokens:
  148. # Use the wrapper method which uses convert_tokens_to_ids
  149. ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
  150. for part in self.parts:
  151. if isinstance(part, TextPart):
  152. if part.tokens is None:
  153. assert part.text is not None
  154. # Optimization: Explicitly disable special tokens (BOS/EOS)
  155. # because we are constructing the sequence manually
  156. tokens = tokenizer.encode(part.text, add_special_tokens=False)
  157. else:
  158. tokens = part.tokens
  159. tokens = torch.tensor(tokens, dtype=torch.long)
  160. elif isinstance(part, VQPart):
  161. # Critical Optimization: Vectorized mapping
  162. # Instead of loop lookup: [tokenizer.semantic_id_to_token_id[i] for i in codes]
  163. # We use arithmetic offset: code + semantic_begin_id
  164. # This assumes semantic tokens are contiguous in the vocab (DualAR requirement)
  165. curr_codes = part.codes.clone().to(torch.int)
  166. # Use int64 (long) for token IDs to avoid overflow or type mismatch in embedding
  167. tokens = (curr_codes[0] + tokenizer.semantic_begin_id).to(torch.long)
  168. vq_parts.append(curr_codes)
  169. vq_require_losses.append(part.cal_loss)
  170. else:
  171. raise ValueError(f"Unsupported part type: {type(part)}")
  172. all_tokens.append(tokens)
  173. # Set masks for different part types
  174. if isinstance(part, VQPart):
  175. vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
  176. audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
  177. elif isinstance(part, AudioPart):
  178. vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
  179. audio_mask = torch.ones_like(tokens, dtype=torch.bool)
  180. audio_mask[0] = False # Skip start token
  181. audio_mask[-1] = False # Skip end token
  182. audio_masks.append(audio_mask)
  183. else:
  184. vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
  185. audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
  186. # Set labels based on whether we want to calculate loss for this part
  187. if part.cal_loss and not isinstance(part, AudioPart):
  188. all_labels.append(tokens.clone())
  189. else:
  190. all_labels.append(torch.full_like(tokens, -100))
  191. # Concatenate all tensors
  192. if not all_tokens:
  193. # Handle empty case safely
  194. tokens = torch.empty(0, dtype=torch.long)
  195. labels = torch.empty(0, dtype=torch.long)
  196. vq_masks = torch.empty(0, dtype=torch.bool)
  197. audio_masks = torch.empty(0, dtype=torch.bool)
  198. else:
  199. tokens = torch.cat(all_tokens, dim=0)
  200. labels = torch.cat(all_labels, dim=0)
  201. vq_masks = torch.cat(vq_masks, dim=0)
  202. audio_masks = torch.cat(audio_masks, dim=0)
  203. vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
  204. # Apply shift if needed for next-token prediction
  205. vq_mask_tokens = vq_masks
  206. vq_mask_labels = vq_masks
  207. if add_shift and len(tokens) > 0:
  208. tokens = tokens[:-1]
  209. labels = labels[1:]
  210. vq_masks = vq_masks[:-1]
  211. vq_mask_tokens = vq_mask_tokens[:-1]
  212. vq_mask_labels = vq_mask_labels[1:]
  213. audio_masks = audio_masks[:-1]
  214. # Ignore specified tokens
  215. for i in ignore_loss_token_ids:
  216. if i is not None:
  217. labels[labels == i] = -100
  218. return EncodedMessage(
  219. tokens=tokens,
  220. labels=labels,
  221. vq_parts=vq_parts,
  222. vq_mask_tokens=vq_mask_tokens,
  223. vq_mask_labels=vq_mask_labels,
  224. vq_require_losses=vq_require_losses,
  225. audio_parts=audio_parts,
  226. audio_masks=audio_masks,
  227. metadata=self.metadata,
  228. )
  229. def encode_for_inference(
  230. self: "ContentSequence",
  231. tokenizer: FishTokenizer,
  232. num_codebooks: int,
  233. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  234. encoded = self.encode(tokenizer, add_shift=False)
  235. tokens = encoded.tokens
  236. # Use int32 for prompt cache to save memory, convert to model dtype later if needed
  237. # Or keep as input_ids (long)
  238. values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.long)
  239. values[0] = tokens
  240. if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
  241. encoded.audio_parts is None or len(encoded.audio_parts) == 0
  242. ):
  243. return values, None, None
  244. audio_parts = None
  245. audio_masks = None
  246. if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
  247. vq_parts = encoded.vq_parts
  248. # List[Tensor(1, T)] -> Tensor(1, Total_T) -> Tensor(1, Total_T)
  249. # Ensure we are handling the list concatenation correctly
  250. if len(vq_parts) > 1:
  251. # We need to be careful here: vq_parts is a list of tensors from different VQPart segments
  252. # They correspond to encoded.vq_mask_tokens
  253. # Since we just want to fill the 'values' tensor at the right positions:
  254. all_vq_codes = torch.cat(vq_parts, dim=1) # Shape: (C, Total_Semantic_Tokens)
  255. else:
  256. all_vq_codes = vq_parts[0]
  257. # Values[0] is already the Main Token ID (Semantic Begin + Code)
  258. # Values[1:] should be the codes themselves
  259. values[1:, encoded.vq_mask_tokens] = all_vq_codes.to(dtype=torch.long)
  260. if encoded.audio_parts is not None and len(encoded.audio_parts) > 0:
  261. audio_parts = torch.cat(encoded.audio_parts, dim=0)
  262. audio_masks = encoded.audio_masks[None, :]
  263. return values, audio_masks, audio_parts
  264. def visualize(
  265. self: "ContentSequence",
  266. tokenizer: FishTokenizer,
  267. ignore_loss_tokens: list[str] = [],
  268. merge_semantic_tokens: bool = False,
  269. ):
  270. """
  271. Visualize the encoded sequence with color-coded tokens.
  272. Blue/cyan tokens contribute to loss, green tokens do not.
  273. """
  274. encoded = self.encode(
  275. tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
  276. )
  277. # Colors for alternating tokens
  278. colors = {
  279. "blue": "\033[94m", # Light blue
  280. "cyan": "\033[96m", # Cyan
  281. "green": "\033[92m", # Light green
  282. "dark_green": "\033[32m", # Dark green
  283. }
  284. blue_idx = 0
  285. green_idx = 0
  286. def print_in_blue(x):
  287. nonlocal blue_idx
  288. color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
  289. print(f"{color}{x}\033[0m", end="")
  290. blue_idx += 1
  291. def print_in_green(x):
  292. nonlocal green_idx
  293. color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
  294. print(f"{color}{x}\033[0m", end="")
  295. green_idx += 1
  296. def print_semantic_token(x, count):
  297. val = f"[<|semantic|>x{count}]"
  298. if x == -100:
  299. print_in_green(val)
  300. else:
  301. print_in_blue(val)
  302. count_semantic_tokens = 0
  303. semantic_label = None
  304. for tok, lab in zip(encoded.tokens, encoded.labels):
  305. token_id = int(tok.item())
  306. if merge_semantic_tokens:
  307. if (
  308. tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id
  309. and (semantic_label is None or semantic_label == lab)
  310. ):
  311. count_semantic_tokens += 1
  312. semantic_label = lab
  313. continue
  314. elif count_semantic_tokens > 0:
  315. print_semantic_token(semantic_label, count_semantic_tokens)
  316. count_semantic_tokens = 0
  317. semantic_label = None
  318. # Use HF decode
  319. val = tokenizer.decode([token_id])
  320. # Simple fallback for visualization if decode returns empty or weird stuff for special tokens
  321. if not val:
  322. val = f"<{token_id}>"
  323. if lab == -100:
  324. print_in_green(val)
  325. else:
  326. print_in_blue(val)
  327. if merge_semantic_tokens and count_semantic_tokens > 0:
  328. print_semantic_token(semantic_label, count_semantic_tokens)
  329. print()