# concat_repeat.py — concatenate datasets with per-dataset repeat counts.
import bisect
from typing import Iterable
from torch.utils.data import Dataset, IterableDataset
  4. class ConcatRepeatDataset(Dataset):
  5. datasets: list[Dataset]
  6. cumulative_sizes: list[int]
  7. repeats: list[int]
  8. @staticmethod
  9. def cumsum(sequence, repeats):
  10. r, s = [], 0
  11. for dataset, repeat in zip(sequence, repeats):
  12. l = len(dataset) * repeat
  13. r.append(l + s)
  14. s += l
  15. return r
  16. def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
  17. super().__init__()
  18. self.datasets = list(datasets)
  19. self.repeats = repeats
  20. assert len(self.datasets) > 0, "datasets should not be an empty iterable"
  21. assert len(self.datasets) == len(
  22. repeats
  23. ), "datasets and repeats should have the same length"
  24. for d in self.datasets:
  25. assert not isinstance(
  26. d, IterableDataset
  27. ), "ConcatDataset does not support IterableDataset"
  28. self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
  29. def __len__(self):
  30. return self.cumulative_sizes[-1]
  31. def __getitem__(self, idx):
  32. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  33. if dataset_idx == 0:
  34. sample_idx = idx
  35. else:
  36. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  37. dataset = self.datasets[dataset_idx]
  38. return dataset[sample_idx % len(dataset)]