|
|
@@ -0,0 +1,52 @@
|
|
|
+import bisect
|
|
|
+from typing import Iterable
|
|
|
+
|
|
|
+from torch.utils.data import Dataset, IterableDataset
|
|
|
+
|
|
|
+
|
|
|
+class ConcatRepeatDataset(Dataset):
|
|
|
+ datasets: list[Dataset]
|
|
|
+ cumulative_sizes: list[int]
|
|
|
+ repeats: list[int]
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def cumsum(sequence, repeats):
|
|
|
+ r, s = [], 0
|
|
|
+ for dataset, repeat in zip(sequence, repeats):
|
|
|
+ l = len(dataset) * repeat
|
|
|
+ r.append(l + s)
|
|
|
+ s += l
|
|
|
+ return r
|
|
|
+
|
|
|
+ def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
|
|
|
+ super().__init__()
|
|
|
+
|
|
|
+ self.datasets = list(datasets)
|
|
|
+ self.repeats = repeats
|
|
|
+
|
|
|
+ assert len(self.datasets) > 0, "datasets should not be an empty iterable"
|
|
|
+ assert len(self.datasets) == len(
|
|
|
+ repeats
|
|
|
+ ), "datasets and repeats should have the same length"
|
|
|
+
|
|
|
+ for d in self.datasets:
|
|
|
+ assert not isinstance(
|
|
|
+ d, IterableDataset
|
|
|
+ ), "ConcatDataset does not support IterableDataset"
|
|
|
+
|
|
|
+ self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
|
|
|
+
|
|
|
+ def __len__(self):
|
|
|
+ return self.cumulative_sizes[-1]
|
|
|
+
|
|
|
+ def __getitem__(self, idx):
|
|
|
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
|
|
+
|
|
|
+ if dataset_idx == 0:
|
|
|
+ sample_idx = idx
|
|
|
+ else:
|
|
|
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
|
|
+
|
|
|
+ dataset = self.datasets[dataset_idx]
|
|
|
+
|
|
|
+ return dataset[sample_idx % len(dataset)]
|