# text.py

import random
from dataclasses import dataclass
from itertools import chain
from random import Random
from typing import Optional, Union

import numpy as np
import pyarrow.parquet as pq
from datasets.download.streaming_download_manager import xopen
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

from fish_speech.utils import RankedLogger
from fish_speech.utils.braceexpand import braceexpand

log = RankedLogger(__name__, rank_zero_only=True)


class TextDataset(IterableDataset):
    """Stream text rows from parquet shards of a Hugging Face dataset repository."""

    def __init__(
        self,
        files: Optional[Union[list[str], str]] = None,
        prefix: Optional[str] = None,
        seed: int = 42,
        parquet_batch_size: int = 10000,
        repo: str = "uonlp/CulturaX",
    ):
        super().__init__()
        self.seed = seed
        self.parquet_batch_size = parquet_batch_size
        self.repo = repo

        if files is None and prefix is None:
            raise ValueError("Either files or prefix must be specified")

        if prefix is not None:
            files = HfApi().list_repo_files(repo, repo_type="dataset")
            files = [f for f in files if f.startswith(prefix)]
            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
        else:
            if isinstance(files, str):
                files = [files]

            files = list(chain.from_iterable(map(braceexpand, files)))
            log.info(f"Expanded {len(files)} files in {repo}")

        # Sort for a stable order, then shuffle deterministically with the seed
        self.files = sorted(files)
        Random(seed).shuffle(self.files)

    def get_data_splits(self, files):
        # We need to know the total number of devices
        # to split the data properly
        total_devices = 1
        if is_initialized():
            total_devices = get_world_size()

        worker_info = get_worker_info()
        if worker_info is not None:
            total_devices *= worker_info.num_workers

        if len(files) < total_devices:
            # Repeat the files N times to match the number of devices
            files = files * (total_devices // len(files) + 1)

        # DDP
        if is_initialized():
            files = files[get_rank() :: get_world_size()]

        # Split by worker
        if worker_info is not None:
            files = files[worker_info.id :: worker_info.num_workers]

        return files

    def __iter__(self):
        files = self.get_data_splits(self.files)
        random.shuffle(files)

        for filename in files:
            try:
                yield from self.parse_data(filename)
            except Exception as e:
                log.exception(f"Failed to parse {filename}: {e}")

    def parse_data(self, filename: str):
        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"

        with xopen(url, mode="rb") as stream:
            parquet_file = pq.ParquetFile(stream)

            for batch in parquet_file.iter_batches(
                batch_size=self.parquet_batch_size, columns=["text"]
            ):
                # In-batch shuffling
                texts = [{"text": text.as_py()} for text in batch["text"]]
                random.shuffle(texts)
                yield from texts
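

# A minimal sketch of how TextDataset.get_data_splits shards files across
# DDP ranks and dataloader workers (the file names below are hypothetical):
#
#   files = [f"en/en_part_{i:05d}.parquet" for i in range(8)]
#   # With world_size=2 and num_workers=2, total_devices = 4.
#   # Rank 0 keeps files[0::2] -> indices 0, 2, 4, 6.
#   # Worker 1 on that rank keeps every 2nd of those -> indices 2, 6.
#
# Each (rank, worker) pair therefore streams a disjoint subset of shards.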


@dataclass
class TextDataCollator:
    tokenizer: AutoTokenizer
    max_length: int = 512

    def __call__(self, examples):
        texts = [i["text"] for i in examples]

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        data = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        data["labels"] = data["input_ids"].clone()
        data["labels"][data["attention_mask"] == 0] = -100

        return data
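

# A sketch of the collator's output, assuming any Hugging Face tokenizer
# (the "gpt2" name is only an illustration):
#
#   collator = TextDataCollator(AutoTokenizer.from_pretrained("gpt2"), max_length=8)
#   batch = collator([{"text": "hello"}, {"text": "hello world"}])
#
# batch["labels"] is a copy of batch["input_ids"] with padded positions
# (attention_mask == 0) set to -100, so they are ignored by the loss.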


class InterleaveDataset(IterableDataset):
    """Interleave multiple iterable datasets with the given sampling probabilities."""

    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Randomly pick which dataset to draw the next sample from
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
            dataset_iterator = dataset_iterators[dataset_idx]

            try:
                yield next(dataset_iterator)
            except StopIteration:
                # Exhausted, create a new iterator and restart that dataset
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(dataset_iterators[dataset_idx])
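

# Usage sketch: because InterleaveDataset restarts exhausted children, it
# yields samples forever; bound consumption explicitly when iterating it
# outside a trainer, e.g. with itertools.islice (ds_a and ds_b stand in for
# any IterableDataset instances):
#
#   from itertools import islice
#   mixed = InterleaveDataset([ds_a, ds_b], probabilities=[0.9, 0.1])
#   for sample in islice(mixed, 1000):
#       ...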


class TextDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Union[TextDataset, InterleaveDataset],
        val_dataset: Optional[Union[TextDataset, InterleaveDataset]] = None,
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        if self.val_dataset is None:
            return None

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    dm = TextDataModule(
        InterleaveDataset(
            datasets=[
                TextDataset(
                    prefix="en/en_part_",
                ),
                TextDataset(
                    prefix="zh/zh_part_",
                ),
                TextDataset(
                    prefix="ja/ja_part_",
                ),
            ],
            probabilities=[0.8, 0.1, 0.1],
        ),
        TextDataset(
            files="ja/ja_part_{00000..00159}",
        ),
        batch_size=2,
        tokenizer=AutoTokenizer.from_pretrained("bert-base-multilingual-cased"),
    )

    for batch in dm.train_dataloader():
        print(batch)
        break