# concat_repeat.py
import bisect
import random
from typing import Iterable

from torch.utils.data import Dataset, IterableDataset
  5. class ConcatRepeatDataset(Dataset):
  6. datasets: list[Dataset]
  7. cumulative_sizes: list[int]
  8. repeats: list[int]
  9. @staticmethod
  10. def cumsum(sequence, repeats):
  11. r, s = [], 0
  12. for dataset, repeat in zip(sequence, repeats):
  13. l = len(dataset) * repeat
  14. r.append(l + s)
  15. s += l
  16. return r
  17. def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
  18. super().__init__()
  19. self.datasets = list(datasets)
  20. self.repeats = repeats
  21. assert len(self.datasets) > 0, "datasets should not be an empty iterable"
  22. assert len(self.datasets) == len(
  23. repeats
  24. ), "datasets and repeats should have the same length"
  25. for d in self.datasets:
  26. assert not isinstance(
  27. d, IterableDataset
  28. ), "ConcatRepeatDataset does not support IterableDataset"
  29. self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
  30. def __len__(self):
  31. return self.cumulative_sizes[-1]
  32. def __getitem__(self, idx):
  33. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  34. if dataset_idx == 0:
  35. sample_idx = idx
  36. else:
  37. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  38. dataset = self.datasets[dataset_idx]
  39. return dataset[sample_idx % len(dataset)]