# text.py

import random
from dataclasses import dataclass
from itertools import chain
from random import Random
from typing import Optional, Union

import grpc
import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest
from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
from fish_speech.text.parser import clean_text
from fish_speech.text.symbols import pad as pad_symbol
from fish_speech.text.symbols import pu_symbols
from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


def split_by_rank_worker(files):
    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # Split by worker
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
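

# Illustrative example: with 2 DDP ranks and 2 dataloader workers per rank,
# total_devices == 4. Given files == ["a", "b", "c", "d"], rank 0 keeps
# ["a", "c"] and its worker 0 then keeps ["a"], so every (rank, worker) pair
# streams a disjoint subset of the shards; shorter file lists are repeated
# first so that no device ends up empty.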


class StreamTextDataset(IterableDataset):
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo
        self.max_length = max_length
        self.tokenizer = tokenizer

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Get sharded files
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def __iter__(self):
        files = split_by_rank_worker(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        for data in self.parse_data_internal(filename):
            text = data["text"]

            # 30% of the time, model phones instead of raw text
            if random.random() < 0.3:
                text = " ".join(
                    [
                        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
                        for i in text
                    ]
                )

            # encode
            tokens = self.tokenizer.encode(
                text,
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )

            # Randomly crop a window of at most self.max_length tokens
            if len(tokens) > self.max_length:
                start = random.randint(0, len(tokens) - self.max_length)
                tokens = tokens[start : start + self.max_length - 1]

            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )

            # Pad dims
            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)

            tokens = torch.concat(
                [
                    torch.tensor([tokens], dtype=torch.long),
                    placeholder_multi_codebook,
                ],
                dim=0,
            )
            labels = tokens.clone()
            tokens = tokens[:, :-1]
            labels = labels[:, 1:]
            labels[1:] = -100  # remove all placeholders

            yield {"tokens": tokens, "labels": labels}
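
    # Note: each yielded "tokens"/"labels" tensor has shape (5, T - 1) after
    # the one-position shift above; row 0 carries the text tokens and rows
    # 1-4 are zero placeholders for the codebooks, whose labels are -100.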

    def parse_data_internal(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts


class AutoAugTextDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text
    3. Mix text and phones
    """

    def __init__(
        self,
        server: str = "localhost:50051",
        seed: int = 42,
        phones_prob: float = 0.3,
        repetition_prob: float = 0.0,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool = True,
    ):
        """
        Args:
            server: gRPC server address
            seed: random seed
            phones_prob: probability to use phones
            repetition_prob: probability to repeat the same sentence
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: whether to prepend a [SPK: name] tag to the prompt
        """

        super().__init__()

        self.seed = seed
        self.phones_prob = phones_prob
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.repetition_prob = repetition_prob
        self.use_speaker = use_speaker

        # Lines are read and grouped by speaker on the gRPC data service
        self.channel = grpc.insecure_channel(server)
        self.stub = DataServiceStub(self.channel)

    def __iter__(self):
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
        if (
            mode == "sample" and (random.random() < self.phones_prob)
        ) or mode == "phones":
            sentence = " ".join(
                [
                    (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
                    for i in phones
                ]
            )
        else:
            sentence = clean_text(sentence)

        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)
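
    # For illustration: in "phones" mode (or with probability phones_prob in
    # "sample" mode) a phone list like ["n", "i", "3"] becomes
    # "<p:n> <p:i> <p:3>", with punctuation and the pad symbol left as-is;
    # otherwise the sentence is normalized via clean_text.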

    def augment(self):
        # 50% chance to use pure text or pure phones for the whole sample
        mode = "sample"
        if random.random() < 0.5:
            mode = random.choice(["text", "phones"])

        # Draw a target token budget from a truncated normal distribution
        a = torch.tensor([0], dtype=torch.float32)
        torch.nn.init.trunc_normal_(
            a,
            mean=self.max_length // 2,
            std=self.max_length // 4,
            a=10,
            b=self.max_length,
        )
        remaining_tokens = a.long().item() - 4

        final_text, final_semantic = [], []

        # Sample lines from the server, estimating that each sample is at least 20 tokens
        request = SampleDataRequest(num_samples=self.max_length // 20)
        response = self.stub.SampleData(request)
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        while remaining_tokens > 0 and len(samples) > 0:
            if random.random() < self.repetition_prob:
                # Repeat the same sentence
                sentence = samples[-1]
            else:
                sentence = samples.pop()

            text, length = self.tokenize_sentence(
                sentence.text, sentence.phones, mode=mode
            )
            remaining_tokens -= length + len(sentence.semantics[0].values)

            final_text.append(text)
            final_semantic.append(sentence.semantics)

        if self.use_speaker:
            final_text = [f"[SPK: {response.name}]"] + final_text

        final_text = "[INST] " + " ".join(final_text) + " [/INST]"
        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in final_semantic])

        # Single codebook
        if len(final_semantic[0]) == 1:
            semantic_tokens = [f"<s:{j}>" for i in final_semantic for j in i[0].values]
            tokenized = self.tokenizer.encode(
                " ".join(semantic_tokens),
                add_special_tokens=False,
                truncation=False,
                max_length=10**6,
            )

            # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
            tokens = (
                [self.tokenizer.bos_token_id]
                + encoded
                + tokenized
                + [self.tokenizer.eos_token_id]
            )
            tokens = torch.tensor([tokens], dtype=torch.long)
            labels = tokens.clone()
            labels[0, : len(encoded) + 1] = -100  # Mask out the <s> and query tokens
        else:
            # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
            tokens = (
                [self.tokenizer.bos_token_id]
                + encoded
                + [self.tokenizer.pad_token_id] * semantic_length
                + [self.tokenizer.eos_token_id]
            )
            codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
            for segment in final_semantic:
                for book_idx, book in enumerate(segment):
                    for j in book.values:
                        codes[book_idx].append(int(j) + 2)

            for book in codes:
                book.append(1)

            tokens = [tokens] + codes
            tokens = torch.tensor(tokens, dtype=torch.long)
            labels = tokens.clone()

            # Mask out the <s> tokens for semantic
            labels[1:, : len(encoded) + 1] = -100

        return {
            "tokens": tokens[:, :-1],
            "labels": labels[:, 1:],
        }
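

# Packing layout used by AutoAugTextDataset.augment (multi-codebook branch):
# for E prompt tokens and S semantic frames, the matrix has shape
# (1 + num_codebooks, E + S + 2) before the final one-position shift:
#   row 0:      <s>  prompt tokens (E)   <pad> * S             </s>
#   rows 1..n:  0 * (E + 1)              codebook values + 2   1
# The +2 offset keeps 0 and 1 reserved as pad/end codes in the codebook rows.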


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 1024

    def __call__(self, examples):
        tokens, attention_masks, labels = [], [], []

        for example in examples:
            _tokens = example["tokens"][:, : self.max_length]
            _labels = example["labels"][:, : self.max_length]

            # True marks padded positions, False marks real tokens
            _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
            _attention_mask[: _tokens.size(1)] = False

            assert _tokens.size(1) == _labels.size(
                1
            ), f"{_tokens.size(1)} != {_labels.size(1)}"

            if _tokens.size(1) < self.max_length:
                _tokens = F.pad(
                    _tokens,
                    (0, self.max_length - _tokens.size(1)),
                    value=self.tokenizer.eos_token_id,
                )
                _labels = F.pad(
                    _labels, (0, self.max_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
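

# For reference: with batch size B the collator returns "inputs" and "labels"
# of shape (B, num_rows, max_length) and "attention_masks" of shape
# (B, max_length), where True marks padded positions.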


class InterleaveDataset(IterableDataset):
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Random choice one
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
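

# Usage sketch (with hypothetical speech_ds / text_ds datasets): mix two
# sources roughly 9:1.
#
#   interleaved = InterleaveDataset(
#       datasets=[speech_ds, text_ds],
#       probabilities=[0.9, 0.1],
#       seed=42,
#   )
#
# Each step draws a dataset index according to `probabilities` and restarts
# any exhausted iterator, so iteration never terminates on its own.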


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    import json

    from tqdm import tqdm

    ds = AutoAugTextDataset(
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
        use_speaker=True,
    )

    # ds = StreamTextDataset(
    #     prefix="en/",
    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
    # )

    dm = TextDataModule(
        train_dataset=ds,
        val_dataset=ds,
        tokenizer=ds.tokenizer,
        batch_size=2,
        max_length=1024,
        num_workers=0,
    )

    for batch in tqdm(dm.train_dataloader()):
        print(batch)
        break