| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- import bisect
- import random
- from typing import Iterable
- from torch.utils.data import Dataset, IterableDataset
class ConcatRepeatDataset(Dataset):
    """Concatenate map-style datasets, repeating each one a given number of times.

    The combined index space is ``repeats[0]`` consecutive copies of
    ``datasets[0]``, followed by ``repeats[1]`` copies of ``datasets[1]``, and
    so on. For example, datasets ``[A, B]`` with repeats ``[2, 1]`` yields
    ``A[0], ..., A[-1], A[0], ..., A[-1], B[0], ..., B[-1]``.

    Args:
        datasets: Non-empty iterable of map-style datasets. ``IterableDataset``
            instances are rejected.
        repeats: One non-negative repeat count per dataset.
    """

    datasets: list[Dataset]
    # cumulative_sizes[i] == total number of samples contributed by datasets 0..i
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        """Return running totals of ``len(dataset) * repeat`` over paired items.

        Args:
            sequence: Datasets supporting ``len()``.
            repeats: Repeat count for each dataset (paired with ``zip``).

        Returns:
            List whose i-th entry is the summed (repeated) size of items 0..i.
        """
        totals, running = [], 0
        for dataset, repeat in zip(sequence, repeats):
            running += len(dataset) * repeat
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset], repeats: Iterable[int]):
        super().__init__()
        self.datasets = list(datasets)
        # Copy so that later mutation of the caller's list cannot desynchronise
        # ``self.repeats`` from ``self.cumulative_sizes``; also accepts any iterable.
        self.repeats = list(repeats)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            self.repeats
        ), "datasets and repeats should have the same length"
        # A negative repeat would make cumulative_sizes non-monotonic and break
        # the bisect lookup in __getitem__.
        assert all(
            rep >= 0 for rep in self.repeats
        ), "repeats should be non-negative"
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        """Return the total number of samples across all repeated datasets."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at combined index ``idx``.

        Negative indices count from the end of the combined range, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        position = idx
        if position < 0:
            # Without this, bisect_right() returns 0 and the modulo below would
            # silently wrap a negative index into the first dataset.
            position += total
        if not 0 <= position < total:
            raise IndexError(
                f"index {idx} is out of range for ConcatRepeatDataset of length {total}"
            )
        # First dataset whose cumulative size exceeds the position; zero-size
        # entries share their predecessor's cumulative size and are skipped.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, position)
        offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
        sample_idx = position - offset
        dataset = self.datasets[dataset_idx]
        # sample_idx ranges over ``repeat`` back-to-back copies; fold it into one.
        return dataset[sample_idx % len(dataset)]
|