Lengyue 2 года назад
Родитель
Commit
7ae872018f
2 измененных файлов с 23 добавлено и 13 удалено
  1. 4 5
      fish_speech/datasets/text.py
  2. 19 8
      tools/llama/generate.py

+ 4 - 5
fish_speech/datasets/text.py

@@ -149,7 +149,7 @@ class AutoAugTextDataset(IterableDataset):
         server: str = "localhost:50051",
         seed: int = 42,
         phones_prob: float = 0.3,
-        repetition_prob: float = 0.1,
+        repetition_prob: float = 0.0,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
     ):
@@ -200,10 +200,9 @@ class AutoAugTextDataset(IterableDataset):
 
     def augment(self):
         # 50% to pure text or pure phones
-        # mode = "sample"
-        # if random.random() < 0.5:
-        #     mode = random.choice(["text", "phones"])
-        mode = "phones"
+        mode = "sample"
+        if random.random() < 0.5:
+            mode = random.choice(["text", "phones"])
 
         # Random sample based on speaker using a truncated normal distribution
         a = torch.tensor([0], dtype=torch.float32)

+ 19 - 8
tools/llama/generate.py

@@ -254,18 +254,26 @@ def generate(
 
 
 def encode_tokens(
-    tokenizer, string, bos=True, device="cuda", prompt_string=None, prompt_tokens=None
+    tokenizer,
+    string,
+    bos=True,
+    device="cuda",
+    prompt_string=None,
+    prompt_tokens=None,
+    use_g2p=False,
 ):
     if prompt_string is not None:
         string = prompt_string + " " + string
 
-    prompt = g2p(string)
-    prompt = [
-        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
-        for _, i in prompt
-    ]
-    prompt = " ".join(prompt)
-    string = f"[INST] {prompt} [/INST]"
+    if use_g2p:
+        prompt = g2p(prompt)
+        prompt = [
+            (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
+            for _, i in prompt
+        ]
+        string = " ".join(prompt)
+
+    string = f"[INST] {string} [/INST]"
 
     tokens = tokenizer.encode(
         string,
@@ -359,6 +367,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 @click.option("--config-name", type=str, default="text2semantic_finetune")
 @click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
 @click.option("--compile/--no-compile", default=False)
+@click.option("--use-g2p/--no-g2p", default=True)
 @click.option("--seed", type=int, default=42)
 def main(
     text: str,
@@ -374,6 +383,7 @@ def main(
     config_name: str,
     tokenizer: str,
     compile: bool,
+    use_g2p: bool,
     seed: int,
 ) -> None:
     device = "cuda"
@@ -400,6 +410,7 @@ def main(
         prompt_tokens=prompt_tokens,
         bos=True,
         device=device,
+        use_g2p=use_g2p,
     )
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")