|
|
@@ -295,17 +295,6 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
)
|
|
|
|
|
|
def augment(self):
|
|
|
- # Random sample based on speaker using a truncated normal distribution
|
|
|
- a = torch.tensor([0], dtype=torch.float32)
|
|
|
- torch.nn.init.trunc_normal_(
|
|
|
- a,
|
|
|
- mean=self.max_length // 2,
|
|
|
- std=self.max_length // 4,
|
|
|
- a=10,
|
|
|
- b=self.max_length,
|
|
|
- )
|
|
|
- remaining_tokens = a.long().item() - 4
|
|
|
-
|
|
|
final_text, final_semantic = [], []
|
|
|
response = self.sample_data()
|
|
|
if len(response.samples) == 0:
|
|
|
@@ -316,6 +305,20 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
idx = 0
|
|
|
use_interactive = random.random() < self.interactive_prob
|
|
|
|
|
|
+ if use_interactive is False:
|
|
|
+ # Random sample based on speaker using a truncated normal distribution
|
|
|
+ a = torch.tensor([0], dtype=torch.float32)
|
|
|
+ torch.nn.init.trunc_normal_(
|
|
|
+ a,
|
|
|
+ mean=self.max_length // 2,
|
|
|
+ std=self.max_length // 4,
|
|
|
+ a=10,
|
|
|
+ b=self.max_length,
|
|
|
+ )
|
|
|
+ remaining_tokens = a.long().item() - 4
|
|
|
+ else:
|
|
|
+ remaining_tokens = self.max_length
|
|
|
+
|
|
|
all_tokens, all_labels = [], []
|
|
|
while remaining_tokens > 0 and len(samples) > 0:
|
|
|
sentence = samples.pop(0)
|