|
@@ -196,6 +196,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
use_data_server: bool = True,
|
|
use_data_server: bool = True,
|
|
|
proto_files: str = "data",
|
|
proto_files: str = "data",
|
|
|
causual: bool = True,
|
|
causual: bool = True,
|
|
|
|
|
+ mix_text_phone_prob: float = 0.5,
|
|
|
):
|
|
):
|
|
|
"""
|
|
"""
|
|
|
Args:
|
|
Args:
|
|
@@ -210,10 +211,16 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
use_data_server: use data server or local data
|
|
use_data_server: use data server or local data
|
|
|
proto_files: proto buf files if using local data
|
|
proto_files: proto buf files if using local data
|
|
|
causual: use causual sampling when using local data, disable will lead to random sampling
|
|
causual: use causual sampling when using local data, disable will lead to random sampling
|
|
|
|
|
+ mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
+ assert 0 <= phones_prob <= 1, "phones_prob must be in [0, 1]"
|
|
|
|
|
+ assert 0 <= repetition_prob <= 1, "repetition_prob must be in [0, 1]"
|
|
|
|
|
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
|
|
|
|
|
+ assert 0 <= mix_text_phone_prob <= 1, "mix_text_phone_prob must be in [0, 1]"
|
|
|
|
|
+
|
|
|
self.seed = seed
|
|
self.seed = seed
|
|
|
self.phones_prob = phones_prob
|
|
self.phones_prob = phones_prob
|
|
|
self.max_length = max_length
|
|
self.max_length = max_length
|
|
@@ -224,6 +231,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
self.use_data_server = use_data_server
|
|
self.use_data_server = use_data_server
|
|
|
self.proto_files = proto_files
|
|
self.proto_files = proto_files
|
|
|
self.causual = causual
|
|
self.causual = causual
|
|
|
|
|
+ self.mix_text_phone_prob = mix_text_phone_prob
|
|
|
|
|
|
|
|
if use_data_server is True:
|
|
if use_data_server is True:
|
|
|
self.channel = grpc.insecure_channel(server)
|
|
self.channel = grpc.insecure_channel(server)
|
|
@@ -307,8 +315,12 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
def augment(self):
|
|
def augment(self):
|
|
|
# 50% to pure text or pure phones
|
|
# 50% to pure text or pure phones
|
|
|
mode = "sample"
|
|
mode = "sample"
|
|
|
- if random.random() < 0.5:
|
|
|
|
|
- mode = random.choice(["text", "phones"])
|
|
|
|
|
|
|
+ if random.random() > self.mix_text_phone_prob:
|
|
|
|
|
+ mode = random.choices(
|
|
|
|
|
+ ["text", "phones"],
|
|
|
|
|
+ weights=[1 - self.phones_prob, self.phones_prob],
|
|
|
|
|
+ k=1,
|
|
|
|
|
+ )[0]
|
|
|
|
|
|
|
|
# Random sample based on speaker using a truncated normal distribution
|
|
# Random sample based on speaker using a truncated normal distribution
|
|
|
a = torch.tensor([0], dtype=torch.float32)
|
|
a = torch.tensor([0], dtype=torch.float32)
|
|
@@ -558,20 +570,21 @@ class TextDataModule(LightningDataModule):
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
- # ds = AutoAugTextDataset(
|
|
|
|
|
- # tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
|
|
- # use_speaker=True,
|
|
|
|
|
- # interactive_prob=1.0,
|
|
|
|
|
- # )
|
|
|
|
|
-
|
|
|
|
|
ds = AutoAugTextDataset(
|
|
ds = AutoAugTextDataset(
|
|
|
tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
use_speaker=True,
|
|
use_speaker=True,
|
|
|
interactive_prob=1.0,
|
|
interactive_prob=1.0,
|
|
|
- use_data_server=False,
|
|
|
|
|
- proto_files=["data/wenet-speech.protos"],
|
|
|
|
|
|
|
+ phones_prob=1.0,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+ # ds = AutoAugTextDataset(
|
|
|
|
|
+ # tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
|
|
+ # use_speaker=True,
|
|
|
|
|
+ # interactive_prob=1.0,
|
|
|
|
|
+ # use_data_server=False,
|
|
|
|
|
+ # proto_files=["data/wenet-speech.protos"],
|
|
|
|
|
+ # )
|
|
|
|
|
+
|
|
|
dm = TextDataModule(
|
|
dm = TextDataModule(
|
|
|
train_dataset=ds,
|
|
train_dataset=ds,
|
|
|
val_dataset=ds,
|
|
val_dataset=ds,
|