import random
from dataclasses import dataclass
from itertools import chain
from random import Random
from typing import Optional, Union

import numpy as np
import pyarrow.parquet as pq
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


class TextDataset(IterableDataset):
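    """Stream text rows from parquet shards of a Hugging Face dataset.

    Shards are given either explicitly via ``files`` (brace patterns such as
    ``"ja/ja_part_{00000..00159}"`` are expanded) or discovered on the Hub by
    ``prefix``. The file list is shuffled with a fixed ``seed`` so that every
    rank derives the same order before sharding.
    """
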
    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
    ):
        super().__init__()

        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [
                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
            ]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]
            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Sort for a stable order, then shuffle deterministically so every
        # rank derives the same shard assignment
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def get_data_splits(self, files):
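        """Return the subset of ``files`` this process should stream.

        The list is sharded twice: striped across DDP ranks, then across
        DataLoader workers, so no two consumers read the same shard. If there
        are fewer files than consumers, the list is repeated to cover them.
        """
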
        # The total number of consumers is the DDP world size
        # times the number of DataLoader workers per rank
        total_devices = 1
        if is_initialized():
            total_devices = get_world_size()

        worker_info = get_worker_info()
        if worker_info is not None:
            total_devices *= worker_info.num_workers

        if len(files) < total_devices:
            # Repeat the files N times to match the number of devices
            files = files * (total_devices // len(files) + 1)

        # Stripe shards across DDP ranks
        if is_initialized():
            files = files[get_rank() :: get_world_size()]

        # Then stripe across DataLoader workers
        if worker_info is not None:
            files = files[worker_info.id :: worker_info.num_workers]

        return files

    def __iter__(self):
        files = self.get_data_splits(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
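        """Stream one parquet shard over HTTP, yielding ``{"text": ...}`` rows."""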
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts


@dataclass
class TextDataCollator:
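    """Tokenize a batch of examples for causal-LM training.

    Labels are a copy of ``input_ids`` with padded positions set to -100,
    the index that PyTorch's cross-entropy loss ignores by default.
    """
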
    tokenizer: AutoTokenizer
    max_length: int = 512

    def __call__(self, examples):
        texts = [i["text"] for i in examples]

        # Some tokenizers (e.g. GPT-style) ship without a pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        data = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Mask padding positions out of the loss
        data["labels"] = data["input_ids"].clone()
        data["labels"][data["attention_mask"] == 0] = -100

        return data


class InterleaveDataset(IterableDataset):
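    """Endlessly interleave several iterable datasets.

    Each step samples a source dataset according to ``probabilities``;
    an exhausted iterator is restarted, so the combined stream never ends.
    """
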
    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Sample a source dataset according to the configured probabilities
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])


class TextDataModule(LightningDataModule):
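    """Lightning data module wiring the streaming datasets to dataloaders
    that batch and tokenize text with :class:`TextDataCollator`."""
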
    def __init__(
        self,
        train_dataset: Union[TextDataset, InterleaveDataset],
        val_dataset: Union[TextDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    dm = TextDataModule(
        train_dataset=InterleaveDataset(
            datasets=[
                TextDataset(
                    prefix="en/en_part_",
                ),
                TextDataset(
                    prefix="zh/zh_part_",
                ),
                TextDataset(
                    prefix="ja/ja_part_",
                ),
            ],
            probabilities=[0.8, 0.1, 0.1],
        ),
        val_dataset=TextDataset(
            files="ja/ja_part_{00000..00159}",
        ),
        batch_size=2,
        tokenizer=AutoTokenizer.from_pretrained("bert-base-multilingual-cased"),
    )

    for batch in dm.train_dataloader():
        print(batch)
        break