|
|
@@ -149,15 +149,27 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
server: str = "localhost:50051",
|
|
|
seed: int = 42,
|
|
|
phones_prob: float = 0.3,
|
|
|
+ repetition_prob: float = 0.1,
|
|
|
max_length: int = 1024,
|
|
|
tokenizer: AutoTokenizer = None,
|
|
|
):
|
|
|
+ """
|
|
|
+ Args:
|
|
|
+ server: gRPC server address
|
|
|
+ seed: random seed
|
|
|
+ phones_prob: probability of using phones
|
|
|
+ repetition_prob: probability, on each draw, of reusing the last sample without removing it, so the same sentence can be emitted again
|
|
|
+ max_length: max length of the text
|
|
|
+ tokenizer: AutoTokenizer instance kept on the dataset as self.tokenizer (defaults to None)
|
|
|
+ """
|
|
|
+
|
|
|
super().__init__()
|
|
|
|
|
|
self.seed = seed
|
|
|
self.phones_prob = phones_prob
|
|
|
self.max_length = max_length
|
|
|
self.tokenizer = tokenizer
|
|
|
+ self.repetition_prob = repetition_prob
|
|
|
|
|
|
# Read all lines, and group by speaker
|
|
|
self.channel = grpc.insecure_channel(server)
|
|
|
@@ -215,7 +227,12 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
|
|
|
samples = list(response.samples)
|
|
|
while remaining_tokens > 0 and len(samples) > 0:
|
|
|
- sentence = samples.pop()
|
|
|
+ if random.random() < self.repetition_prob:
|
|
|
+ # Repeat the same sentence
|
|
|
+ sentence = samples[-1]
|
|
|
+ else:
|
|
|
+ sentence = samples.pop()
|
|
|
+
|
|
|
text, length = self.tokenize_sentence(
|
|
|
sentence.text, sentence.phones, mode=mode
|
|
|
)
|