|
@@ -1,4 +1,5 @@
|
|
|
import bisect
|
|
import bisect
|
|
|
|
|
+import random
|
|
|
from typing import Iterable
|
|
from typing import Iterable
|
|
|
|
|
|
|
|
from torch.utils.data import Dataset, IterableDataset
|
|
from torch.utils.data import Dataset, IterableDataset
|
|
@@ -32,7 +33,7 @@ class ConcatRepeatDataset(Dataset):
|
|
|
for d in self.datasets:
|
|
for d in self.datasets:
|
|
|
assert not isinstance(
|
|
assert not isinstance(
|
|
|
d, IterableDataset
|
|
d, IterableDataset
|
|
|
- ), "ConcatDataset does not support IterableDataset"
|
|
|
|
|
|
|
+ ), "ConcatRepeatDataset does not support IterableDataset"
|
|
|
|
|
|
|
|
self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
|
|
self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
|
|
|
|
|
|
|
@@ -50,3 +51,38 @@ class ConcatRepeatDataset(Dataset):
|
|
|
dataset = self.datasets[dataset_idx]
|
|
dataset = self.datasets[dataset_idx]
|
|
|
|
|
|
|
|
return dataset[sample_idx % len(dataset)]
|
|
return dataset[sample_idx % len(dataset)]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ConcatWeightedIterableDataset(IterableDataset):
|
|
|
|
|
+ datasets: list[IterableDataset]
|
|
|
|
|
+ weights: list[float]
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self, datasets: Iterable[IterableDataset], weights: list[float]):
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+
|
|
|
|
|
+ total_weight = sum(weights)
|
|
|
|
|
+ self.weights = [w / total_weight for w in weights]
|
|
|
|
|
+ self.datasets = list(datasets)
|
|
|
|
|
+
|
|
|
|
|
+ assert len(self.datasets) > 0, "datasets should not be an empty iterable"
|
|
|
|
|
+ assert len(self.datasets) == len(
|
|
|
|
|
+ weights
|
|
|
|
|
+ ), "datasets and repeats should have the same length"
|
|
|
|
|
+
|
|
|
|
|
+ for d in self.datasets:
|
|
|
|
|
+ assert isinstance(
|
|
|
|
|
+ d, IterableDataset
|
|
|
|
|
+ ), "ConcatRepeatIterableDataset only supports IterableDataset"
|
|
|
|
|
+
|
|
|
|
|
+ def __iter__(self):
|
|
|
|
|
+ all_datasets = [iter(dataset) for dataset in self.datasets]
|
|
|
|
|
+ ids = list(range(len(self.datasets)))
|
|
|
|
|
+
|
|
|
|
|
+ while True:
|
|
|
|
|
+ chosen_dataset = random.choices(ids, self.weights)[0]
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ yield next(all_datasets[chosen_dataset])
|
|
|
|
|
+ except StopIteration:
|
|
|
|
|
+ all_datasets[chosen_dataset] = iter(self.datasets[chosen_dataset])
|
|
|
|
|
+ yield next(all_datasets[chosen_dataset])
|