semantic.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. import random
  2. from dataclasses import dataclass
  3. from itertools import chain
  4. from pathlib import Path
  5. from random import Random
  6. from typing import Optional, Union
  7. import numpy as np
  8. import pyarrow.parquet as pq
  9. import torch
  10. import torch.nn.functional as F
  11. from datasets.download.streaming_download_manager import xopen
  12. from huggingface_hub import HfApi
  13. from lightning import LightningDataModule
  14. from torch.distributed import get_rank, get_world_size, is_initialized
  15. from torch.utils.data import DataLoader, IterableDataset, get_worker_info
  16. from transformers import AutoTokenizer
  17. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  18. from fish_speech.datasets.protos.text_data_pb2 import SampledData
  19. from fish_speech.datasets.protos.text_data_stream import read_pb_stream
  20. from fish_speech.text.clean import clean_text
  21. from fish_speech.utils import RankedLogger
  22. from fish_speech.utils.braceexpand import braceexpand
# Module-level logger shared by everything in this file. It is constructed with
# rank_zero_only=True (presumably so only rank 0 emits messages in distributed
# runs -- confirm against RankedLogger).
log = RankedLogger(__name__, rank_zero_only=True)
  24. def split_by_rank_worker(files):
  25. # We need to know the total number of devices
  26. # to split the data properly
  27. total_devices = 1
  28. if is_initialized():
  29. total_devices = get_world_size()
  30. worker_info = get_worker_info()
  31. if worker_info is not None:
  32. total_devices *= worker_info.num_workers
  33. if len(files) < total_devices:
  34. # Repeat the files N times to match the number of devices
  35. files = files * (total_devices // len(files) + 1)
  36. # DDP
  37. if is_initialized():
  38. files = files[get_rank() :: get_world_size()]
  39. # Split by worker
  40. if worker_info is not None:
  41. files = files[worker_info.id :: worker_info.num_workers]
  42. return files
  43. class AutoTextSemanticInstructionDataset(IterableDataset):
  44. """
  45. Auto Augment Dataset by Speaker
  46. 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
  47. 2. Automatically normalize the text
  48. For interactive mode, we use the following format (multiple sequences):
  49. <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
  50. For non-interactive mode, we use the following format (one long sequence):
  51. <s> [INST] text [/INST] ... </s>
  52. """
  53. def __init__(
  54. self,
  55. proto_files: list[str],
  56. seed: int = 42,
  57. interactive_prob: float = 0.5,
  58. max_length: int = 1024,
  59. tokenizer: AutoTokenizer = None,
  60. use_speaker: bool | float = True,
  61. causal: bool = True,
  62. num_codebooks: Optional[int] = None,
  63. skip_text_prob: float = 0.0,
  64. ):
  65. """
  66. Args:
  67. proto_files: proto buf files if using local data
  68. seed: random seed
  69. interactive_prob: probability to use interactive mode
  70. max_length: max length of the text
  71. tokenizer: tokenizer
  72. use_speaker: include speaker information in the prompt
  73. causal: use causal sampling when using local data, disable will lead to random sampling
  74. num_codebooks: number of codebooks, if None, it will be automatically detected
  75. skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
  76. """
  77. super().__init__()
  78. assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
  79. self.seed = seed
  80. self.max_length = max_length
  81. self.tokenizer = tokenizer
  82. self.interactive_prob = interactive_prob
  83. self.use_speaker = use_speaker
  84. self.proto_files = proto_files
  85. self.causal = causal
  86. self.num_codebooks = num_codebooks
  87. self.skip_text_prob = skip_text_prob
  88. self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
  89. self.groups = None
  90. def init_mock_data_server(self):
  91. if self.groups is not None:
  92. return
  93. # Expand the proto files
  94. expanded_proto_files = []
  95. for filename in self.proto_files:
  96. for i in braceexpand(filename):
  97. i = Path(i)
  98. if i.is_file():
  99. expanded_proto_files.append(i)
  100. elif i.is_dir():
  101. expanded_proto_files.extend(i.rglob("*.proto"))
  102. expanded_proto_files.extend(i.rglob("*.protos"))
  103. else:
  104. raise ValueError(f"{i} is not a file or directory")
  105. expanded_proto_files = sorted(expanded_proto_files)
  106. Random(self.seed).shuffle(expanded_proto_files)
  107. self.groups = []
  108. shard_proto_files = split_by_rank_worker(expanded_proto_files)
  109. log.info(
  110. f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
  111. )
  112. count = 0
  113. for filename in shard_proto_files:
  114. with open(filename, "rb") as f:
  115. for text_data in read_pb_stream(f):
  116. self.groups.append(text_data)
  117. count += 1
  118. log.info(f"Read total {count} groups of data")
  119. # Shuffle the lines
  120. Random(self.seed).shuffle(self.groups)
  121. self.group_weights = [len(i.sentences) for i in self.groups]
  122. def __iter__(self):
  123. while True:
  124. yield self.augment()
  125. def tokenize_sentence(self, sentence: str):
  126. sentence = clean_text(sentence)
  127. tokens = self.tokenizer.encode(
  128. f"{sentence}",
  129. max_length=10**6,
  130. add_special_tokens=False,
  131. truncation=False,
  132. )
  133. return sentence, len(tokens)
  134. def sample_data(self):
  135. if self.groups is None:
  136. self.init_mock_data_server()
  137. # Shuffle unique lines, estimate that each sample is at least 20 tokens
  138. num_samples = self.max_length // 20
  139. # choice group based on their number of samples
  140. group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
  141. if self.causal:
  142. # Sample in order
  143. if num_samples >= len(group.sentences):
  144. samples = group.sentences
  145. else:
  146. begin = random.randint(0, len(group.sentences) - num_samples)
  147. samples = group.sentences[begin : begin + num_samples]
  148. else:
  149. samples = random.choices(
  150. group.sentences, k=min(num_samples, len(group.sentences))
  151. )
  152. return SampledData(
  153. source=group.source,
  154. name=group.name,
  155. samples=samples,
  156. )
  157. def augment(self):
  158. final_text, final_semantic = [], []
  159. response = self.sample_data()
  160. if len(response.samples) == 0:
  161. # Invalid group
  162. return None
  163. samples = list(response.samples)
  164. idx = 0
  165. use_interactive = random.random() < self.interactive_prob
  166. if use_interactive is False:
  167. # Random sample based on speaker using a truncated normal distribution
  168. a = torch.tensor([0], dtype=torch.float32)
  169. torch.nn.init.trunc_normal_(
  170. a,
  171. mean=self.max_length // 2,
  172. std=self.max_length // 4,
  173. a=10,
  174. b=self.max_length,
  175. )
  176. remaining_tokens = a.long().item() - 4
  177. else:
  178. remaining_tokens = self.max_length
  179. # Use speaker
  180. if isinstance(self.use_speaker, float):
  181. use_speaker = random.random() < self.use_speaker
  182. else:
  183. use_speaker = self.use_speaker
  184. all_tokens, all_labels = [], []
  185. while remaining_tokens > 0 and len(samples) > 0:
  186. sentence = samples.pop(0)
  187. text = random.choice(sentence.texts)
  188. text, length = self.tokenize_sentence(text)
  189. remaining_tokens -= length + len(sentence.semantics[0].values)
  190. if use_interactive is False:
  191. final_text.append(text)
  192. final_semantic.append(sentence.semantics)
  193. else:
  194. # For interactive mode, we only apply speaker for the first sentence
  195. # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
  196. tokens, labels = self.pack_sentences(
  197. sentences=[text],
  198. semantics=[sentence.semantics],
  199. speaker=response.name if use_speaker else None,
  200. skip_text=random.random() < self.skip_text_prob,
  201. )
  202. all_tokens.append(tokens)
  203. all_labels.append(labels)
  204. idx += 1
  205. if use_interactive is False:
  206. tokens, labels = self.pack_sentences(
  207. final_text,
  208. semantics=final_semantic,
  209. speaker=response.name if use_speaker else None,
  210. )
  211. all_tokens.append(tokens)
  212. all_labels.append(labels)
  213. tokens = torch.cat(all_tokens, dim=1)
  214. labels = torch.cat(all_labels, dim=1)
  215. # Verify that the length is correct
  216. assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
  217. data = {"tokens": tokens, "labels": labels}
  218. return data
  219. def pack_sentences(
  220. self,
  221. sentences: list[str],
  222. semantics: list,
  223. speaker: Optional[str] = None,
  224. skip_text: bool = False,
  225. ):
  226. if speaker is None:
  227. speaker = "assistant"
  228. cated_sentences = " ".join(sentences)
  229. if skip_text:
  230. cated_sentences = "<|skip_text|>"
  231. final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
  232. final_text = final_text + f"<|im_start|>{speaker}\n"
  233. encoded = self.tokenizer.encode(
  234. final_text,
  235. add_special_tokens=False,
  236. truncation=False,
  237. max_length=10**6,
  238. )
  239. semantic_length = sum([len(i[0].values) for i in semantics])
  240. prompt_length = len(encoded)
  241. num_codebooks = (
  242. len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
  243. )
  244. # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
  245. tokens = (
  246. encoded
  247. + [self.semantic_token_id] * semantic_length
  248. + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
  249. )
  250. # Codebook bos/padding: 0, eos: 1
  251. codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
  252. for segment in semantics:
  253. for book_idx, book in zip(range(num_codebooks), segment):
  254. for j in book.values:
  255. codes[book_idx].append(int(j) + 1)
  256. for book in codes:
  257. book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
  258. tokens = [tokens] + codes
  259. tokens = torch.tensor(tokens, dtype=torch.long)
  260. labels = tokens.clone()
  261. if skip_text:
  262. # If text is not provided, the sentence is used for condition only, all labels are -100
  263. torch.fill_(labels, -100)
  264. return tokens, labels
  265. # Mask out the <s> tokens for semantic, predict semantic tokens only
  266. # Since we don't mask out the input tokens, the language modeling still works
  267. labels[1:, :prompt_length] = -100
  268. tokens = tokens[:, :-1]
  269. labels = labels[:, 1:]
  270. # Verify the padding is correct, and the last token is eos
  271. assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
  272. assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
  273. return tokens, labels
  274. @dataclass
  275. class TextDataCollator:
  276. tokenizer: AutoTokenizer
  277. max_length: int = 1024
  278. def __call__(self, examples):
  279. if "negative_tokens" in examples:
  280. positive_examples = []
  281. negative_examples = []
  282. for i in examples:
  283. positive_examples.append(
  284. {
  285. "tokens": i["tokens"],
  286. "labels": i["labels"],
  287. }
  288. )
  289. negative_examples.append(
  290. {
  291. "tokens": i["negative_tokens"],
  292. "labels": i["negative_labels"],
  293. }
  294. )
  295. examples = positive_examples + negative_examples
  296. return self.batchify(examples)
  297. def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
  298. tokens, attention_masks, labels = [], [], []
  299. # Calculate the max length
  300. max_tokens_length = 0
  301. for example in examples:
  302. max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
  303. max_tokens_length = min(max_tokens_length, self.max_length)
  304. for example in examples:
  305. _tokens = example[tokens_key][:, :max_tokens_length]
  306. _labels = example[labels_key][:, :max_tokens_length]
  307. _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
  308. tokens_length = _tokens.size(1)
  309. _attention_mask[:tokens_length] = False
  310. assert tokens_length == _labels.size(
  311. 1
  312. ), f"{tokens_length} != {_labels.size(1)}"
  313. if tokens_length < max_tokens_length:
  314. _tokens = F.pad(
  315. _tokens,
  316. (0, max_tokens_length - tokens_length),
  317. value=self.tokenizer.eos_token_id,
  318. )
  319. _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
  320. _labels = F.pad(
  321. _labels, (0, max_tokens_length - _labels.size(1)), value=-100
  322. )
  323. tokens.append(_tokens)
  324. attention_masks.append(_attention_mask)
  325. labels.append(_labels)
  326. tokens = torch.stack(tokens, dim=0)
  327. attention_masks = torch.stack(attention_masks, dim=0)
  328. labels = torch.stack(labels, dim=0)
  329. return {
  330. "inputs": tokens,
  331. "attention_masks": attention_masks,
  332. "labels": labels,
  333. }
  334. class InterleaveDataset(IterableDataset):
  335. def __init__(
  336. self,
  337. datasets: list[IterableDataset],
  338. probabilities: list[float],
  339. seed: int = 42,
  340. ):
  341. super().__init__()
  342. self.datasets = datasets
  343. self.probabilities = probabilities
  344. self.seed = seed
  345. def __iter__(self):
  346. rng = np.random.default_rng(self.seed)
  347. dataset_iterators = [iter(dataset) for dataset in self.datasets]
  348. while True:
  349. # Random choice one
  350. dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
  351. dataset_iterator = dataset_iterators[dataset_idx]
  352. try:
  353. yield next(dataset_iterator)
  354. except StopIteration:
  355. # Exhausted, create a new iterator
  356. dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
  357. yield next(dataset_iterators[dataset_idx])
  358. class SemanticDataModule(LightningDataModule):
  359. def __init__(
  360. self,
  361. train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
  362. val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
  363. batch_size: int = 32,
  364. tokenizer: AutoTokenizer = None,
  365. max_length: int = 1024,
  366. num_workers: int = 4,
  367. ):
  368. super().__init__()
  369. self.train_dataset = train_dataset
  370. self.val_dataset = val_dataset
  371. self.batch_size = batch_size
  372. self.tokenizer = tokenizer
  373. self.max_length = max_length
  374. self.num_workers = num_workers
  375. def train_dataloader(self):
  376. return DataLoader(
  377. self.train_dataset,
  378. batch_size=self.batch_size,
  379. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  380. num_workers=self.num_workers,
  381. persistent_workers=True,
  382. )
  383. def val_dataloader(self):
  384. return DataLoader(
  385. self.val_dataset,
  386. batch_size=self.batch_size,
  387. collate_fn=TextDataCollator(self.tokenizer, self.max_length),
  388. num_workers=self.num_workers,
  389. persistent_workers=True,
  390. )
  391. if __name__ == "__main__":
  392. from tqdm import tqdm
  393. ds = AutoTextSemanticInstructionDataset(
  394. ["data/protos"],
  395. tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
  396. use_speaker=False,
  397. interactive_prob=1.0,
  398. skip_text_prob=0.5,
  399. )
  400. for i in ds:
  401. print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
  402. # i["labels"][0][i["labels"][0] == -100] = 0
  403. # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
  404. break