Ver Fonte

optimize InterleaveDataset

Lengyue há 2 anos
pai
commit
e01dd330a1
2 ficheiros alterados com 42 adições e 5 exclusões
  1. 1 1
      speech_lm/configs/pretrain.yaml
  2. 41 4
      speech_lm/datasets/cultura_x.py

+ 1 - 1
speech_lm/configs/pretrain.yaml

@@ -40,7 +40,7 @@ schedule:
   clip_grad_norm: 1.0
 
 dataset:
-  _target_: datasets.interleave_datasets
+  _target_: speech_lm.datasets.cultura_x.InterleaveDataset
   datasets:
     - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
       lang: 'en'

+ 41 - 4
speech_lm/datasets/cultura_x.py

@@ -8,6 +8,8 @@ from random import Random
 from logging import getLogger
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer
+import numpy as np
+
 
 SUBSETS = {
     "en": "en_part_{00000..03071}",
@@ -36,11 +38,11 @@ class CulturaXDataset(IterableDataset):
         total_devices = 1
         if is_initialized():
             total_devices = get_world_size()
-        
+
         worker_info = get_worker_info()
         if worker_info is not None:
             total_devices *= worker_info.num_workers
-        
+
         if len(files) < total_devices:
             # Repeat the files N times to match the number of devices
             files = files * (total_devices // len(files) + 1)
@@ -60,7 +62,10 @@ class CulturaXDataset(IterableDataset):
         random.shuffle(files)
 
         for filename in files:
-            yield from self.parse_data(filename)
+            try:
+                yield from self.parse_data(filename)
+            except:
+                log.exception(f"Failed to parse {filename}")
 
     def parse_data(self, filename: str):
         fname = hf_hub_download(
@@ -114,10 +119,42 @@ class CulutreXCollator:
         return data
 
 
+class InterleaveDataset(IterableDataset):
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        probabilities: list[float],
+        seed: int = 42,
+    ):
+        super().__init__()
+
+        self.datasets = datasets
+        self.probabilities = probabilities
+        self.seed = seed
+
+    def __iter__(self):
+        rng = np.random.default_rng(self.seed)
+        dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+        while True:
+            # Random choice one
+            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+            dataset_iterator = dataset_iterators[dataset_idx]
+
+            try:
+                yield next(dataset_iterator)
+            except StopIteration:
+                # Exhausted, create a new iterator
+                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+                yield next(dataset_iterators[dataset_idx])
+
+
 if __name__ == "__main__":
     from torch.utils.data import DataLoader
 
-    dataset = CulturaXDataset("en")
+    dataset_en = CulturaXDataset("en")
+    dataset_ja = CulturaXDataset("ja")
+    dataset = InterleaveDataset([dataset_en, dataset_ja], [0.5, 0.5])
     collator = CulutreXCollator(AutoTokenizer.from_pretrained("gpt2"))
 
     for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):