Lengyue пре 2 година
родитељ
комит
23b3c98c67
1 измењених фајлова са 23 додато и 10 уклоњено
  1. 23 10
      fish_speech/datasets/text.py

+ 23 - 10
fish_speech/datasets/text.py

@@ -196,6 +196,7 @@ class AutoAugTextDataset(IterableDataset):
         use_data_server: bool = True,
         proto_files: str = "data",
         causual: bool = True,
+        mix_text_phone_prob: float = 0.5,
     ):
         """
         Args:
@@ -210,10 +211,16 @@ class AutoAugTextDataset(IterableDataset):
             use_data_server: use data server or local data
             proto_files: proto buf files if using local data
             causual: use causual sampling when using local data, disable will lead to random sampling
+            mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
         """
 
         super().__init__()
 
+        assert 0 <= phones_prob <= 1, "phones_prob must be in [0, 1]"
+        assert 0 <= repetition_prob <= 1, "repetition_prob must be in [0, 1]"
+        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+        assert 0 <= mix_text_phone_prob <= 1, "mix_text_phone_prob must be in [0, 1]"
+
         self.seed = seed
         self.phones_prob = phones_prob
         self.max_length = max_length
@@ -224,6 +231,7 @@ class AutoAugTextDataset(IterableDataset):
         self.use_data_server = use_data_server
         self.proto_files = proto_files
         self.causual = causual
+        self.mix_text_phone_prob = mix_text_phone_prob
 
         if use_data_server is True:
             self.channel = grpc.insecure_channel(server)
@@ -307,8 +315,12 @@ class AutoAugTextDataset(IterableDataset):
     def augment(self):
         # 50% to pure text or pure phones
         mode = "sample"
-        if random.random() < 0.5:
-            mode = random.choice(["text", "phones"])
+        if random.random() > self.mix_text_phone_prob:
+            mode = random.choices(
+                ["text", "phones"],
+                weights=[1 - self.phones_prob, self.phones_prob],
+                k=1,
+            )[0]
 
         # Random sample based on speaker using a truncated normal distribution
         a = torch.tensor([0], dtype=torch.float32)
@@ -558,20 +570,21 @@ class TextDataModule(LightningDataModule):
 if __name__ == "__main__":
     from tqdm import tqdm
 
-    # ds = AutoAugTextDataset(
-    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
-    #     use_speaker=True,
-    #     interactive_prob=1.0,
-    # )
-
     ds = AutoAugTextDataset(
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
         use_speaker=True,
         interactive_prob=1.0,
-        use_data_server=False,
-        proto_files=["data/wenet-speech.protos"],
+        phones_prob=1.0,
     )
 
+    # ds = AutoAugTextDataset(
+    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
+    #     use_speaker=True,
+    #     interactive_prob=1.0,
+    #     use_data_server=False,
+    #     proto_files=["data/wenet-speech.protos"],
+    # )
+
     dm = TextDataModule(
         train_dataset=ds,
         val_dataset=ds,