Quellcode durchsuchen

prebuild group weights

Lengyue vor 1 Jahr
Ursprung
Commit
5b8a5f50b1
1 geänderte Datei mit 2 neuen und 3 gelöschten Zeilen
  1. 2 3
      fish_speech/datasets/text.py

+ 2 - 3
fish_speech/datasets/text.py

@@ -250,6 +250,7 @@ class AutoAugTextDataset(IterableDataset):
 
         # Shuffle the lines
         Random(self.seed).shuffle(self.groups)
+        self.group_weights = [len(i.sentences) for i in self.groups]
 
     def __iter__(self):
         while True:
@@ -273,9 +274,7 @@ class AutoAugTextDataset(IterableDataset):
         num_samples = self.max_length // 20
 
         # Choose a group, weighted by its number of samples
-        group = random.choices(
-            self.groups, weights=[len(i.sentences) for i in self.groups], k=1
-        )[0]
+        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
 
         if self.causual:
             # Sample in order