|
@@ -8,6 +8,8 @@ from random import Random
|
|
|
from logging import getLogger
|
|
from logging import getLogger
|
|
|
from huggingface_hub import hf_hub_download
|
|
from huggingface_hub import hf_hub_download
|
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoTokenizer
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+
|
|
|
|
|
|
|
|
SUBSETS = {
|
|
SUBSETS = {
|
|
|
"en": "en_part_{00000..03071}",
|
|
"en": "en_part_{00000..03071}",
|
|
@@ -36,11 +38,11 @@ class CulturaXDataset(IterableDataset):
|
|
|
total_devices = 1
|
|
total_devices = 1
|
|
|
if is_initialized():
|
|
if is_initialized():
|
|
|
total_devices = get_world_size()
|
|
total_devices = get_world_size()
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
worker_info = get_worker_info()
|
|
worker_info = get_worker_info()
|
|
|
if worker_info is not None:
|
|
if worker_info is not None:
|
|
|
total_devices *= worker_info.num_workers
|
|
total_devices *= worker_info.num_workers
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
if len(files) < total_devices:
|
|
if len(files) < total_devices:
|
|
|
# Repeat the files N times to match the number of devices
|
|
# Repeat the files N times to match the number of devices
|
|
|
files = files * (total_devices // len(files) + 1)
|
|
files = files * (total_devices // len(files) + 1)
|
|
@@ -60,7 +62,10 @@ class CulturaXDataset(IterableDataset):
|
|
|
random.shuffle(files)
|
|
random.shuffle(files)
|
|
|
|
|
|
|
|
for filename in files:
|
|
for filename in files:
|
|
|
- yield from self.parse_data(filename)
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ yield from self.parse_data(filename)
|
|
|
|
|
+ except:
|
|
|
|
|
+ log.exception(f"Failed to parse {filename}")
|
|
|
|
|
|
|
|
def parse_data(self, filename: str):
|
|
def parse_data(self, filename: str):
|
|
|
fname = hf_hub_download(
|
|
fname = hf_hub_download(
|
|
@@ -114,10 +119,42 @@ class CulutreXCollator:
|
|
|
return data
|
|
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+class InterleaveDataset(IterableDataset):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ datasets: list[IterableDataset],
|
|
|
|
|
+ probabilities: list[float],
|
|
|
|
|
+ seed: int = 42,
|
|
|
|
|
+ ):
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+
|
|
|
|
|
+ self.datasets = datasets
|
|
|
|
|
+ self.probabilities = probabilities
|
|
|
|
|
+ self.seed = seed
|
|
|
|
|
+
|
|
|
|
|
+ def __iter__(self):
|
|
|
|
|
+ rng = np.random.default_rng(self.seed)
|
|
|
|
|
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
|
|
|
|
|
+
|
|
|
|
|
+ while True:
|
|
|
|
|
+ # Random choice one
|
|
|
|
|
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
|
|
|
|
|
+ dataset_iterator = dataset_iterators[dataset_idx]
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ yield next(dataset_iterator)
|
|
|
|
|
+ except StopIteration:
|
|
|
|
|
+ # Exhausted, create a new iterator
|
|
|
|
|
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
|
|
|
|
|
+ yield next(dataset_iterators[dataset_idx])
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
- dataset = CulturaXDataset("en")
|
|
|
|
|
|
|
+ dataset_en = CulturaXDataset("en")
|
|
|
|
|
+ dataset_ja = CulturaXDataset("ja")
|
|
|
|
|
+ dataset = InterleaveDataset([dataset_en, dataset_ja], [0.5, 0.5])
|
|
|
collator = CulutreXCollator(AutoTokenizer.from_pretrained("gpt2"))
|
|
collator = CulutreXCollator(AutoTokenizer.from_pretrained("gpt2"))
|
|
|
|
|
|
|
|
for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):
|
|
for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):
|