Selaa lähdekoodia

optimize config & improve stability

Lengyue 1 vuosi sitten
vanhempi
commit
ab61b60f42

+ 1 - 2
fish_speech/configs/text2semantic_finetune.yaml

@@ -80,5 +80,4 @@ model:
 # Callbacks
 callbacks:
   model_checkpoint:
-    every_n_train_steps: 10
-    # ${trainer.val_check_interval}
+    every_n_train_steps: ${trainer.val_check_interval}

+ 4 - 88
fish_speech/datasets/semantic.py

@@ -76,7 +76,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
         tokenizer: AutoTokenizer = None,
         use_speaker: bool | float = True,
         causal: bool = True,
-        use_negative_samples: bool = False,
         num_codebooks: Optional[int] = None,
         skip_text_prob: float = 0.0,
     ):
@@ -89,7 +88,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
             tokenizer: tokenizer
             use_speaker: include speaker information in the prompt
             causal: use causal sampling when using local data, disable will lead to random sampling
-            use_negative_samples: generate negative samples
             num_codebooks: number of codebooks, if None, it will be automatically detected
             skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
         """
@@ -105,7 +103,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
         self.use_speaker = use_speaker
         self.proto_files = proto_files
         self.causal = causal
-        self.use_negative_samples = use_negative_samples
         self.num_codebooks = num_codebooks
         self.skip_text_prob = skip_text_prob
 
@@ -242,7 +239,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
                     sentences=[text],
                     semantics=[sentence.semantics],
                     speaker=response.name if use_speaker else None,
-                    add_bos=idx == 0,
                     skip_text=random.random() < self.skip_text_prob,
                 )
 
@@ -256,7 +252,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
                 final_text,
                 semantics=final_semantic,
                 speaker=response.name if use_speaker else None,
-                add_bos=True,
             )
             all_tokens.append(tokens)
             all_labels.append(labels)
@@ -267,84 +262,15 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
         # Verify that the length is correct
         assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
 
-        # Verify bos token
-        assert tokens[0, 0] == self.tokenizer.bos_token_id
-
         data = {"tokens": tokens, "labels": labels}
 
-        if self.use_negative_samples:
-            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
-            data.update(negative_samples)
-
         return data
 
-    def generate_negative_samples(self, all_tokens, all_labels):
-        new_tokens, new_labels = [], []
-
-        for tokens, labels in zip(all_tokens, all_labels):
-            # If all codebooks are not -100, we find where it starts
-            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
-            assert (labels[1:, start:] != -100).all()  # This shouldn't happen
-
-            mode = random.choice(["repeat", "lost", "noise"])
-            begin = random.randint(start, labels.size(1) - 1)
-            end = random.randint(begin, labels.size(1) - 1)
-
-            if mode == "repeat":
-                tokens = torch.cat(
-                    [
-                        tokens[:, :begin],
-                        tokens[:, begin:end],
-                        tokens[:, begin:end],
-                        tokens[:, end:],
-                    ],
-                    dim=1,
-                )
-                labels = torch.cat(
-                    [
-                        labels[:, :begin],
-                        labels[:, begin:end],
-                        labels[:, begin:end],
-                        labels[:, end:],
-                    ],
-                    dim=1,
-                )
-            elif mode == "lost":
-                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
-                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
-            elif mode == "noise":
-                middle_tokens, middle_labels = (
-                    tokens[:, begin:end],
-                    labels[:, begin:end],
-                )
-                random_order0 = torch.randperm(middle_tokens.size(1))
-                random_order1 = torch.randperm(middle_tokens.size(1))
-                middle_tokens = middle_tokens[:, random_order0]
-                middle_labels = middle_labels[:, random_order1]
-                tokens = torch.cat(
-                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
-                )
-                labels = torch.cat(
-                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
-                )
-
-            new_tokens.append(tokens)
-            new_labels.append(labels)
-
-        tokens = torch.cat(new_tokens, dim=1)
-        labels = torch.cat(new_labels, dim=1)
-
-        # Verify that the length is correct
-        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
-
-        return {"negative_tokens": tokens, "negative_labels": labels}
-
     def pack_sentences(
         self,
         sentences: list[str],
         semantics: list,
         speaker: Optional[str] = None,
-        add_bos: bool = True,
         skip_text: bool = False,
     ):
         if speaker is None:
@@ -369,8 +295,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
             len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
         )
 
-        bos_bias = 1 if add_bos else 0
-
         # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
         tokens = (
             encoded
@@ -378,14 +302,8 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
             + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
         )
 
-        if add_bos:
-            tokens = [self.tokenizer.bos_token_id] + tokens
-
         # Codebook bos/padding: 0, eos: 1
-        codes = [
-            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
-            for _ in range(num_codebooks)
-        ]
+        codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
         for segment in semantics:
             for book_idx, book in zip(range(num_codebooks), segment):
                 for j in book.values:
@@ -406,14 +324,13 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
 
         # Mask out the <s> tokens for semantic, predict semantic tokens only
         # Since we don't mask out the input tokens, the language modeling still works
-        labels[1:, : (prompt_length + bos_bias)] = -100
+        labels[1:, :prompt_length] = -100
 
         tokens = tokens[:, :-1]
         labels = labels[:, 1:]
 
         # Verify the padding is correct, and the last token is eos
-        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
-        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
         assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
 
         return tokens, labels
@@ -564,12 +481,11 @@ class SemanticDataModule(LightningDataModule):
 if __name__ == "__main__":
     from tqdm import tqdm
 
-    ds = AutoAugTextDataset(
+    ds = AutoTextSemanticInstructionDataset(
         ["data/protos"],
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
         use_speaker=False,
         interactive_prob=1.0,
-        use_negative_samples=False,
         skip_text_prob=0.5,
     )
 

+ 3 - 1
tools/llama/generate.py

@@ -96,7 +96,9 @@ def decode_one_token_ar(
     codebooks = [
         sample(
             x.logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
             **sampling_kwargs,
         )[0]
     ]