| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- import bisect
- import random
- from typing import Iterable
- from torch.utils.data import Dataset, IterableDataset
- class ConcatRepeatDataset(Dataset):
- datasets: list[Dataset]
- cumulative_sizes: list[int]
- repeats: list[int]
- @staticmethod
- def cumsum(sequence, repeats):
- r, s = [], 0
- for dataset, repeat in zip(sequence, repeats):
- l = len(dataset) * repeat
- r.append(l + s)
- s += l
- return r
- def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
- super().__init__()
- self.datasets = list(datasets)
- self.repeats = repeats
- assert len(self.datasets) > 0, "datasets should not be an empty iterable"
- assert len(self.datasets) == len(
- repeats
- ), "datasets and repeats should have the same length"
- for d in self.datasets:
- assert not isinstance(
- d, IterableDataset
- ), "ConcatRepeatDataset does not support IterableDataset"
- self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
- def __len__(self):
- return self.cumulative_sizes[-1]
- def __getitem__(self, idx):
- dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
- if dataset_idx == 0:
- sample_idx = idx
- else:
- sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
- dataset = self.datasets[dataset_idx]
- return dataset[sample_idx % len(dataset)]
- class ConcatWeightedIterableDataset(IterableDataset):
- datasets: list[IterableDataset]
- weights: list[float]
- def __init__(self, datasets: Iterable[IterableDataset], weights: list[float]):
- super().__init__()
- total_weight = sum(weights)
- self.weights = [w / total_weight for w in weights]
- self.datasets = list(datasets)
- assert len(self.datasets) > 0, "datasets should not be an empty iterable"
- assert len(self.datasets) == len(
- weights
- ), "datasets and repeats should have the same length"
- for d in self.datasets:
- assert isinstance(
- d, IterableDataset
- ), "ConcatRepeatIterableDataset only supports IterableDataset"
- def __iter__(self):
- all_datasets = [iter(dataset) for dataset in self.datasets]
- ids = list(range(len(self.datasets)))
- while True:
- chosen_dataset = random.choices(ids, self.weights)[0]
- try:
- yield next(all_datasets[chosen_dataset])
- except StopIteration:
- all_datasets[chosen_dataset] = iter(self.datasets[chosen_dataset])
- yield next(all_datasets[chosen_dataset])
|