|
|
@@ -250,6 +250,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
|
|
|
# Shuffle the lines
|
|
|
Random(self.seed).shuffle(self.groups)
|
|
|
+ self.group_weights = [len(i.sentences) for i in self.groups]
|
|
|
|
|
|
def __iter__(self):
|
|
|
while True:
|
|
|
@@ -273,9 +274,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
num_samples = self.max_length // 20
|
|
|
|
|
|
# choice group based on their number of samples
|
|
|
- group = random.choices(
|
|
|
- self.groups, weights=[len(i.sentences) for i in self.groups], k=1
|
|
|
- )[0]
|
|
|
+ group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
|
|
|
|
|
|
if self.causual:
|
|
|
# Sample in order
|