Explorar el Código

Optimize iterative training efficiency

Lengyue hace 1 año
padre
commit
d9f42acda4
Se han modificado 1 ficheros con 14 adiciones y 11 borrados
  1. 14 11
      fish_speech/datasets/text.py

+ 14 - 11
fish_speech/datasets/text.py

@@ -295,17 +295,6 @@ class AutoAugTextDataset(IterableDataset):
         )
 
     def augment(self):
-        # Random sample based on speaker using a truncated normal distribution
-        a = torch.tensor([0], dtype=torch.float32)
-        torch.nn.init.trunc_normal_(
-            a,
-            mean=self.max_length // 2,
-            std=self.max_length // 4,
-            a=10,
-            b=self.max_length,
-        )
-        remaining_tokens = a.long().item() - 4
-
         final_text, final_semantic = [], []
         response = self.sample_data()
         if len(response.samples) == 0:
@@ -316,6 +305,20 @@ class AutoAugTextDataset(IterableDataset):
         idx = 0
         use_interactive = random.random() < self.interactive_prob
 
+        if use_interactive is False:
+            # Random sample based on speaker using a truncated normal distribution
+            a = torch.tensor([0], dtype=torch.float32)
+            torch.nn.init.trunc_normal_(
+                a,
+                mean=self.max_length // 2,
+                std=self.max_length // 4,
+                a=10,
+                b=self.max_length,
+            )
+            remaining_tokens = a.long().item() - 4
+        else:
+            remaining_tokens = self.max_length
+
         all_tokens, all_labels = [], []
         while remaining_tokens > 0 and len(samples) > 0:
             sentence = samples.pop(0)