# concat_repeat.py — dataset concatenation helpers (repeat-based and weighted-sampling).
  1. import bisect
  2. import random
  3. from typing import Iterable
  4. from torch.utils.data import Dataset, IterableDataset
  5. class ConcatRepeatDataset(Dataset):
  6. datasets: list[Dataset]
  7. cumulative_sizes: list[int]
  8. repeats: list[int]
  9. @staticmethod
  10. def cumsum(sequence, repeats):
  11. r, s = [], 0
  12. for dataset, repeat in zip(sequence, repeats):
  13. l = len(dataset) * repeat
  14. r.append(l + s)
  15. s += l
  16. return r
  17. def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
  18. super().__init__()
  19. self.datasets = list(datasets)
  20. self.repeats = repeats
  21. assert len(self.datasets) > 0, "datasets should not be an empty iterable"
  22. assert len(self.datasets) == len(
  23. repeats
  24. ), "datasets and repeats should have the same length"
  25. for d in self.datasets:
  26. assert not isinstance(
  27. d, IterableDataset
  28. ), "ConcatRepeatDataset does not support IterableDataset"
  29. self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
  30. def __len__(self):
  31. return self.cumulative_sizes[-1]
  32. def __getitem__(self, idx):
  33. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  34. if dataset_idx == 0:
  35. sample_idx = idx
  36. else:
  37. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  38. dataset = self.datasets[dataset_idx]
  39. return dataset[sample_idx % len(dataset)]
  40. class ConcatWeightedIterableDataset(IterableDataset):
  41. datasets: list[IterableDataset]
  42. weights: list[float]
  43. def __init__(self, datasets: Iterable[IterableDataset], weights: list[float]):
  44. super().__init__()
  45. total_weight = sum(weights)
  46. self.weights = [w / total_weight for w in weights]
  47. self.datasets = list(datasets)
  48. assert len(self.datasets) > 0, "datasets should not be an empty iterable"
  49. assert len(self.datasets) == len(
  50. weights
  51. ), "datasets and repeats should have the same length"
  52. for d in self.datasets:
  53. assert isinstance(
  54. d, IterableDataset
  55. ), "ConcatRepeatIterableDataset only supports IterableDataset"
  56. def __iter__(self):
  57. all_datasets = [iter(dataset) for dataset in self.datasets]
  58. ids = list(range(len(self.datasets)))
  59. while True:
  60. chosen_dataset = random.choices(ids, self.weights)[0]
  61. try:
  62. yield next(all_datasets[chosen_dataset])
  63. except StopIteration:
  64. all_datasets[chosen_dataset] = iter(self.datasets[chosen_dataset])
  65. yield next(all_datasets[chosen_dataset])