import bisect
import random
from typing import Iterable

from torch.utils.data import Dataset, IterableDataset


class ConcatRepeatDataset(Dataset):
    """Concatenate map-style datasets, virtually repeating each one.

    Dataset ``i`` contributes ``len(datasets[i]) * repeats[i]`` consecutive
    indices. No data is copied: an index inside a dataset's repeated span is
    mapped back onto the underlying dataset with a modulo.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        """Return the running totals of ``len(dataset) * repeat``.

        Args:
            sequence: Sized datasets, in concatenation order.
            repeats: Repeat count for each dataset, paired positionally.

        Returns:
            A list whose ``i``-th entry is the total number of (repeated)
            samples contributed by datasets ``0..i``.
        """
        totals, running = [], 0
        for dataset, repeat in zip(sequence, repeats):
            running += len(dataset) * repeat
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        """Build the concatenation.

        Args:
            datasets: Non-empty iterable of map-style datasets.
            repeats: One repeat count per dataset.

        Raises:
            AssertionError: If ``datasets`` is empty, the lengths differ, or
                any dataset is an ``IterableDataset``.
        """
        super().__init__()
        self.datasets = list(datasets)
        self.repeats = repeats
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        """Return the total number of samples, repeats included."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at global position ``idx``.

        Negative indices count from the end, as for a regular sequence.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            # Without this normalisation a negative index would bisect to the
            # first dataset and the modulo below would return a wrong sample.
            if -idx > total:
                raise IndexError("ConcatRepeatDataset index out of range")
            idx = total + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx >= len(self.datasets):
            raise IndexError("ConcatRepeatDataset index out of range")
        # Offset of this dataset's span inside the global index space.
        start = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
        dataset = self.datasets[dataset_idx]
        # The span covers several repeats; fold it back onto the real length.
        return dataset[(idx - start) % len(dataset)]


class ConcatWeightedIterableDataset(IterableDataset):
    """Endless stream that samples from several iterable datasets by weight.

    Each step picks one dataset with probability proportional to its weight
    (via the global ``random`` module) and yields that dataset's next item.
    An exhausted dataset is restarted, so iteration never terminates on its
    own.
    """

    datasets: list[IterableDataset]
    weights: list[float]

    def __init__(self, datasets: Iterable[IterableDataset], weights: list[float]):
        """Build the weighted mixture.

        Args:
            datasets: Non-empty iterable of ``IterableDataset`` instances.
            weights: One non-negative weight per dataset; they are normalised
                to sum to 1 and stored in ``self.weights``.

        Raises:
            AssertionError: If ``datasets`` is empty, the lengths differ, a
                dataset is not an ``IterableDataset``, or the weights are
                negative or sum to zero.
        """
        super().__init__()
        self.datasets = list(datasets)
        weights = list(weights)
        # Validate before normalising so that bad input produces a clear
        # message instead of a ZeroDivisionError.
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            weights
        ), "datasets and weights should have the same length"
        for d in self.datasets:
            assert isinstance(
                d, IterableDataset
            ), "ConcatWeightedIterableDataset only supports IterableDataset"
        total_weight = sum(weights)
        assert (
            all(w >= 0 for w in weights) and total_weight > 0
        ), "weights should be non-negative and have a positive sum"
        self.weights = [w / total_weight for w in weights]

    def __iter__(self):
        """Yield samples forever, choosing the source dataset at random.

        Raises:
            RuntimeError: If the chosen dataset yields nothing even right
                after being restarted (i.e. it is empty).
        """
        iterators = [iter(dataset) for dataset in self.datasets]
        ids = list(range(len(self.datasets)))
        while True:
            chosen = random.choices(ids, self.weights)[0]
            try:
                sample = next(iterators[chosen])
            except StopIteration:
                # Exhausted: restart this dataset and take its first item.
                iterators[chosen] = iter(self.datasets[chosen])
                try:
                    sample = next(iterators[chosen])
                except StopIteration:
                    # A StopIteration escaping a generator becomes an opaque
                    # RuntimeError (PEP 479); raise a descriptive one instead.
                    raise RuntimeError(
                        f"dataset at index {chosen} is empty and cannot be sampled"
                    ) from None
            yield sample