text.py
import random
from dataclasses import dataclass
from itertools import chain
from random import Random
from typing import Optional, Union

import grpc
import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest
from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
from fish_speech.text.parser import clean_text
from fish_speech.text.symbols import pad as pad_symbol
from fish_speech.text.symbols import pu_symbols
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)

CODEBOOK_BOS_TOKEN_ID = 0
CODEBOOK_EOS_TOKEN_ID = 1


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
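
# Example: with 2 DDP ranks and 2 dataloader workers per rank, total_devices is 4,
# so a file list ["a", "b", "c", "d", "e", "f"] is first sliced by rank
# (rank 0 keeps ["a", "c", "e"]) and then by worker within that rank
# (worker 0 keeps ["a", "e"]), giving each (rank, worker) pair a disjoint shard.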


class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo
        self.max_length = max_length
        self.tokenizer = tokenizer

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            # 30% chance to model phones instead of raw text
            if random.random() < 0.3:
                text = " ".join(
                    [
                        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
                        for i in text
                    ]
                )

            # Encode without adding special tokens
            tokens = self.tokenizer.encode(
                text,
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )

            # Randomly crop a window of at most self.max_length tokens
            if len(tokens) > self.max_length:
                start = random.randint(0, len(tokens) - self.max_length)
                tokens = tokens[start : start + self.max_length - 1]

            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )

            # Pad dims
            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)

            tokens = torch.concat(
                [
                    torch.tensor([tokens], dtype=torch.long),
                    placeholder_multi_codebook,
                ],
                dim=0,
            )
            labels = tokens.clone()
            tokens = tokens[:, :-1]
            labels = labels[:, 1:]
            labels[1:] = -100  # remove all placeholders

            yield {"tokens": tokens, "labels": labels}

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts
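
# Example usage (illustrative; the "en/" prefix is a placeholder for any CulturaX
# subset prefix, and the tokenizer checkpoint matches the one used in __main__ below):
#
#   tokenizer = AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1")
#   dataset = StreamTextDataset(prefix="en/", tokenizer=tokenizer, max_length=1024)
#   sample = next(iter(dataset))
#   # sample["tokens"] and sample["labels"] are (1 + 4 codebook rows, seq_len) LongTensors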


class AutoAugTextDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text
    3. Mix text and phones

    For interactive mode, we use the following format (multiple sequences):
    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>

    For non-interactive mode, we use the following format (one long sequence):
    <s> [INST] text [/INST] ... </s>

    A concrete sketch of the packed token layout follows the class definition.
    """

    def __init__(
        self,
        server: str = "localhost:50051",
        seed: int = 42,
        phones_prob: float = 0.3,
        repetition_prob: float = 0.0,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool = True,
    ):
        """
        Args:
            server: gRPC server address
            seed: random seed
            phones_prob: probability to use phones
            repetition_prob: probability to repeat the same sentence
            interactive_prob: probability to use interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: whether to prepend the [SPK: ...] tag
        """

        super().__init__()

        self.seed = seed
        self.phones_prob = phones_prob
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.repetition_prob = repetition_prob
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker

        # Read all lines, and group by speaker
        self.channel = grpc.insecure_channel(server)
        self.stub = DataServiceStub(self.channel)

    def __iter__(self):
        while True:
            yield self.augment()

    def tokenize_sentence(
        self, sentence: str, phones: list[str], mode: str = "sample"
    ):
        if (
            mode == "sample" and (random.random() < self.phones_prob)
        ) or mode == "phones":
            sentence = " ".join(
                [
                    (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
                    for i in phones
                ]
            )
        else:
            sentence = clean_text(sentence)

        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def augment(self):
        # 50% chance to force pure text or pure phones; otherwise sample per sentence
        mode = "sample"
        if random.random() < 0.5:
            mode = random.choice(["text", "phones"])

        # Randomly sample a target length from a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=10,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []

        # Sample lines from the server; estimate that each sample is at least 20 tokens
        request = SampleDataRequest(num_samples=self.max_length // 20)
        response = self.stub.SampleData(request)
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob
        all_tokens, all_labels = [], []

        while remaining_tokens > 0 and len(samples) > 0:
            if random.random() < self.repetition_prob:
                # Repeat the same sentence
                sentence = samples[-1]
            else:
                sentence = samples.pop()

            text, length = self.tokenize_sentence(
                sentence.text, sentence.phones, mode=mode
            )
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if use_interactive is False:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply speaker for the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if (self.use_speaker and idx == 0) else None,
                    add_bos=idx == 0,
                )
                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if use_interactive is False:
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if self.use_speaker else None,
                add_bos=True,
            )
        else:
            tokens = torch.cat(all_tokens, dim=1)
            labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        # Verify only one <s> token
        assert (tokens[:, 0] == self.tokenizer.bos_token_id).sum() == 1

        return {"tokens": tokens, "labels": labels}

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        speaker: Optional[str] = None,
        add_bos: bool = True,
    ):
        if speaker is not None:
            sentences = [f"[SPK: {speaker}]"] + sentences

        final_text = "[INST] " + " ".join(sentences) + " [/INST]"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in semantics])
        bos_bias = 1 if add_bos else 0

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        tokens = (
            encoded
            + [self.tokenizer.pad_token_id] * semantic_length
            + [self.tokenizer.eos_token_id]
        )
        if add_bos:
            tokens = [self.tokenizer.bos_token_id] + tokens

        # Codebook bos/padding: 0, eos: 1
        codes = [
            [CODEBOOK_BOS_TOKEN_ID] * (len(encoded) + bos_bias)
            for _ in range(len(semantics[0]))
        ]
        for segment in semantics:
            for book_idx, book in enumerate(segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 2)

        for book in codes:
            book.append(CODEBOOK_EOS_TOKEN_ID)

        tokens = [tokens] + codes
        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        # Mask out the <s> tokens for semantic, predict semantic tokens only
        # Since we don't mask out the input tokens, the language modeling still works
        labels[1:, : (len(encoded) + bos_bias)] = -100

        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
        assert (tokens[1:, : len(encoded) + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
        assert labels[0, -1] == self.tokenizer.eos_token_id
        assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()

        return tokens, labels
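
# Packed layout produced by pack_sentences, before the final one-position shift into
# inputs/labels (illustrative; assumes 2 codebooks and a 3-token "[INST] ... [/INST]"
# encoding, while real calls use len(semantics[0]) codebook rows):
#
#   row 0 (text):      <s>  t1  t2  t3  <pad>  <pad>  </s>
#   row 1 (codebook):  0    0   0   0   c1+2   c2+2   1
#   row 2 (codebook):  0    0   0   0   c1+2   c2+2   1
#
# Semantic codes are offset by 2 so that 0 and 1 stay reserved for the codebook
# BOS/padding and EOS markers; labels over the text region of rows 1..N are set
# to -100 so the loss only covers the semantic tokens.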


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        tokens, attention_masks, labels = [], [], []

        for example in examples:
            _tokens = example["tokens"][:, : self.max_length]
            _labels = example["labels"][:, : self.max_length]
            _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < self.max_length:
                _tokens = F.pad(
                    _tokens,
                    (0, self.max_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_EOS_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, self.max_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
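
# Collated batch shapes (illustrative, assuming a batch of B examples, 4 codebook
# rows as in StreamTextDataset, and max_length L):
#   inputs:          (B, 5, L) LongTensor, text tokens in row 0 plus one row per codebook
#   attention_masks: (B, L) BoolTensor, False over real tokens and True over padding
#   labels:          (B, 5, L) LongTensor with padded positions set to -100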


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly pick one dataset according to the sampling probabilities
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    from tqdm import tqdm

    ds = AutoAugTextDataset(
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
        use_speaker=True,
        interactive_prob=1.0,
    )

    dm = TextDataModule(
        train_dataset=ds,
        val_dataset=ds,
        tokenizer=ds.tokenizer,
        batch_size=2,
        max_length=1024,
        num_workers=0,
    )

    for batch in tqdm(dm.train_dataloader()):
        print(batch)
        break