Przeglądaj źródła

Support skip prompt text training

Lengyue 1 rok temu
rodzic
commit
a13ac22b2d

+ 2 - 0
fish_speech/configs/text2semantic_pretrain.yaml

@@ -30,6 +30,7 @@ train_dataset:
   num_codebooks: ${model.model.config.num_codebooks}
   use_speaker: false
   interactive_prob: 0.5
+  skip_text_prob: 0.1
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
@@ -40,6 +41,7 @@ val_dataset:
   num_codebooks: ${model.model.config.num_codebooks}
   use_speaker: false
   interactive_prob: 0.5
+  skip_text_prob: 0.1
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule

+ 18 - 3
fish_speech/datasets/text.py

@@ -5,7 +5,6 @@ from pathlib import Path
 from random import Random
 from typing import Optional, Union
 
-import grpc
 import numpy as np
 import pyarrow.parquet as pq
 import torch
@@ -27,6 +26,7 @@ log = RankedLogger(__name__, rank_zero_only=True)
 
 CODEBOOK_PAD_TOKEN_ID = 0
 CODEBOOK_EOS_TOKEN_ID = 1
+SKIP_TEXT_STRING = "<|skip_text|>"
 
 
 def split_by_rank_worker(files):
@@ -182,6 +182,7 @@ class AutoAugTextDataset(IterableDataset):
         causual: bool = True,
         use_negative_samples: bool = False,
         num_codebooks: Optional[int] = None,
+        skip_text_prob: float = 0.0,
     ):
         """
         Args:
@@ -194,6 +195,7 @@ class AutoAugTextDataset(IterableDataset):
             causual: use causal sampling when using local data; disabling it will lead to random sampling
             use_negative_samples: generate negative samples
             num_codebooks: number of codebooks, if None, it will be automatically detected
+            skip_text_prob: probability of skipping the text (audio only); this only applies to interactive mode
         """
 
         super().__init__()
@@ -209,6 +211,7 @@ class AutoAugTextDataset(IterableDataset):
         self.causual = causual
         self.use_negative_samples = use_negative_samples
         self.num_codebooks = num_codebooks
+        self.skip_text_prob = skip_text_prob
 
         self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
         self.groups = None
@@ -344,6 +347,7 @@ class AutoAugTextDataset(IterableDataset):
                     semantics=[sentence.semantics],
                     speaker=response.name if use_speaker else None,
                     add_bos=idx == 0,
+                    skip_text=random.random() < self.skip_text_prob,
                 )
 
                 all_tokens.append(tokens)
@@ -442,14 +446,19 @@ class AutoAugTextDataset(IterableDataset):
     def pack_sentences(
         self,
         sentences: list[str],
-        semantics=list,
+        semantics: list,
         speaker: Optional[str] = None,
         add_bos: bool = True,
+        skip_text: bool = False,
     ):
         if speaker is None:
             speaker = "assistant"
 
-        final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
+        cated_sentences = " ".join(sentences)
+        if skip_text:
+            cated_sentences = SKIP_TEXT_STRING
+
+        final_text = "<|im_start|>user<|im_sep|>" + cated_sentences + "<|im_end|>"
         final_text = final_text + f"<|im_start|>{speaker}<|im_sep|>"
 
         encoded = self.tokenizer.encode(
@@ -496,6 +505,11 @@ class AutoAugTextDataset(IterableDataset):
         tokens = torch.tensor(tokens, dtype=torch.long)
         labels = tokens.clone()
 
+        if skip_text:
+            # If text is not provided, the sentence is used for conditioning only, so all labels are -100
+            torch.fill_(labels, -100)
+            return tokens, labels
+
         # Mask out the <s> tokens for semantic, predict semantic tokens only
         # Since we don't mask out the input tokens, the language modeling still works
         labels[1:, : (prompt_length + bos_bias)] = -100
@@ -663,6 +677,7 @@ if __name__ == "__main__":
         use_speaker=False,
         interactive_prob=1.0,
         use_negative_samples=False,
+        skip_text_prob=0.5,
     )
 
     for i in ds: