
Optimize LoRA & add auto DPO training

Lengyue 2 years ago
parent
commit
8093258065

+ 1 - 0
fish_speech/configs/text2semantic_finetune_lora.yaml

@@ -61,6 +61,7 @@ model:
       codebook_size: 168 # codebook size 160 + 2 special tokens
       dropout: 0.1 # For small dataset, dropout helps to prevent overfitting
 
+  save_lora_only: true
   lora_config:
     _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
     r: 8

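For context, a minimal sketch (not part of the commit) of what `save_lora_only` does: the `on_save_checkpoint` hook changed further down strips every key that does not contain "lora" from the checkpoint, so only the loralib adapter weights are written to disk. The key names below are hypothetical.

```python
import torch

# Hypothetical state dict of a LoRA-augmented layer.
state_dict = {
    "model.layer.wqkv.weight": torch.zeros(8, 8),  # frozen base weight, dropped
    "model.layer.wqkv.lora_A": torch.zeros(2, 8),  # adapter weight, kept
    "model.layer.wqkv.lora_B": torch.zeros(8, 2),  # adapter weight, kept
}

# Mirrors the filtering done in on_save_checkpoint when save_lora_only is true.
lora_only = {k: v for k, v in state_dict.items() if "lora" in k}
assert set(lora_only) == {"model.layer.wqkv.lora_A", "model.layer.wqkv.lora_B"}
```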
+ 4 - 1
fish_speech/configs/text2semantic_sft.yaml

@@ -27,6 +27,7 @@ train_dataset:
   use_speaker: true
   phones_prob: 0.5
   interactive_prob: 0.5
+  use_negative_samples: true
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
@@ -35,19 +36,21 @@ val_dataset:
   use_speaker: true
   phones_prob: 0.5
   interactive_prob: 0.5
+  use_negative_samples: true
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 32
+  batch_size: 4
   tokenizer: ${tokenizer}
   max_length: ${max_length}
 
 # Model Configuration
 model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
+  use_dpo: true
 
   model:
     # ~ 130M parameters, for debug purpose

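For orientation, a minimal sketch (not from the commit) of how these flags fit together: with `use_negative_samples: true` each dataset item also carries a corrupted copy of itself, the collator change in `fish_speech/datasets/text.py` below concatenates positives then negatives, and `_step` in the Lightning module recovers the two halves with `chunk(2)`. The effective forward batch is therefore twice `batch_size`, which is likely why it drops from 32 to 4 here.

```python
import torch

batch_size = 4  # value configured above
# Stand-in for the collated batch: positives first, then their negatives,
# mirroring `positive_examples + negative_examples` in the new collator code.
batch = torch.arange(2 * batch_size)

positive, negative = batch.chunk(2)  # how _step splits the logits back apart
assert positive.tolist() == [0, 1, 2, 3]
assert negative.tolist() == [4, 5, 6, 7]
```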
+ 102 - 6
fish_speech/datasets/text.py

@@ -197,6 +197,7 @@ class AutoAugTextDataset(IterableDataset):
         proto_files: str = "data",
         causual: bool = True,
         mix_text_phone_prob: float = 0.5,
+        use_negative_samples: bool = False,
     ):
         """
         Args:
@@ -212,6 +213,7 @@ class AutoAugTextDataset(IterableDataset):
             proto_files: proto buf files if using local data
             causual: use causual sampling when using local data, disable will lead to random sampling
             mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
+            use_negative_samples: generate corrupted negative samples for DPO training
         """
 
         super().__init__()
@@ -232,6 +234,7 @@ class AutoAugTextDataset(IterableDataset):
         self.proto_files = proto_files
         self.causual = causual
         self.mix_text_phone_prob = mix_text_phone_prob
+        self.use_negative_samples = use_negative_samples
 
         if use_data_server is True:
             self.channel = grpc.insecure_channel(server)
@@ -381,9 +384,11 @@ class AutoAugTextDataset(IterableDataset):
                 speaker=None if self.use_speaker else response.name,
                 add_bos=True,
             )
-        else:
-            tokens = torch.cat(all_tokens, dim=1)
-            labels = torch.cat(all_labels, dim=1)
+            all_tokens.append(tokens)
+            all_labels.append(labels)
+
+        tokens = torch.cat(all_tokens, dim=1)
+        labels = torch.cat(all_labels, dim=1)
 
         # Verify that the length is correct
         assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
@@ -391,7 +396,74 @@ class AutoAugTextDataset(IterableDataset):
         # Verify only one <s> token
         assert (tokens[:, 0] == self.tokenizer.bos_token_id).sum() == 1
 
-        return {"tokens": tokens, "labels": labels}
+        data = {"tokens": tokens, "labels": labels}
+
+        if self.use_negative_samples:
+            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
+            data.update(negative_samples)
+
+        return data
+
+    def generate_negative_samples(self, all_tokens, all_labels):
+        new_tokens, new_labels = [], []
+
+        for tokens, labels in zip(all_tokens, all_labels):
+            # Find the first position where the codebook labels are not all -100
+            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
+            assert (labels[1:, start:] != -100).all()  # This shouldn't happen
+
+            mode = random.choice(["repeat", "lost", "noise"])
+            begin = random.randint(start, labels.size(1) - 1)
+            end = random.randint(begin, labels.size(1) - 1)
+
+            if mode == "repeat":
+                tokens = torch.cat(
+                    [
+                        tokens[:, :begin],
+                        tokens[:, begin:end],
+                        tokens[:, begin:end],
+                        tokens[:, end:],
+                    ],
+                    dim=1,
+                )
+                labels = torch.cat(
+                    [
+                        labels[:, :begin],
+                        labels[:, begin:end],
+                        labels[:, begin:end],
+                        labels[:, end:],
+                    ],
+                    dim=1,
+                )
+            elif mode == "lost":
+                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
+                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
+            elif mode == "noise":
+                middle_tokens, middle_labels = (
+                    tokens[:, begin:end],
+                    labels[:, begin:end],
+                )
+                random_order0 = torch.randperm(middle_tokens.size(1))
+                random_order1 = torch.randperm(middle_tokens.size(1))
+                middle_tokens = middle_tokens[:, random_order0]
+                middle_labels = middle_labels[:, random_order1]
+                tokens = torch.cat(
+                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
+                )
+                labels = torch.cat(
+                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
+                )
+
+            new_tokens.append(tokens)
+            new_labels.append(labels)
+
+        tokens = torch.cat(new_tokens, dim=1)
+        labels = torch.cat(new_labels, dim=1)
+
+        # Verify that the length is correct
+        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+        return {"negative_tokens": tokens, "negative_labels": labels}
 
     def pack_sentences(
         self,
@@ -470,10 +542,33 @@ class TextDataCollator:
     max_length: int = 1024
 
     def __call__(self, examples):
+        if "negative_tokens" in examples[0]:
+            positive_examples = []
+            negative_examples = []
+
+            for i in examples:
+                positive_examples.append(
+                    {
+                        "tokens": i["tokens"],
+                        "labels": i["labels"],
+                    }
+                )
+                negative_examples.append(
+                    {
+                        "tokens": i["negative_tokens"],
+                        "labels": i["negative_labels"],
+                    }
+                )
+
+            examples = positive_examples + negative_examples
+
+        return self.batchify(examples)
+
+    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
         tokens, attention_masks, labels = [], [], []
         for example in examples:
-            _tokens = example["tokens"][:, : self.max_length]
-            _labels = example["labels"][:, : self.max_length]
+            _tokens = example[tokens_key][:, : self.max_length]
+            _labels = example[labels_key][:, : self.max_length]
             _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
             tokens_length = _tokens.size(1)
             _attention_mask[:tokens_length] = False
@@ -582,6 +677,7 @@ if __name__ == "__main__":
         use_speaker=True,
         interactive_prob=1.0,
         phones_prob=1.0,
+        use_negative_samples=True,
     )
 
     # ds = AutoAugTextDataset(

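A toy illustration (not part of the commit) of the three corruption modes used by `generate_negative_samples`. Real samples are `(num_codebooks + 1, seq_len)` tensors of token ids; a 1 x 6 tensor keeps the arithmetic visible.

```python
import torch

tokens = torch.tensor([[10, 11, 12, 13, 14, 15]])
begin, end = 2, 4  # corrupt the span tokens[:, 2:4] == [[12, 13]]

# "repeat": duplicate the span, so the semantic stream stutters.
repeat = torch.cat(
    [tokens[:, :begin], tokens[:, begin:end], tokens[:, begin:end], tokens[:, end:]],
    dim=1,
)  # [[10, 11, 12, 13, 12, 13, 14, 15]]

# "lost": drop the span entirely, so content is skipped.
lost = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)  # [[10, 11, 14, 15]]

# "noise": shuffle the span; the real code permutes tokens and labels with two
# independent orders, which also breaks their alignment inside the span.
perm = torch.randperm(end - begin)
noise = torch.cat(
    [tokens[:, :begin], tokens[:, begin:end][:, perm], tokens[:, end:]], dim=1
)
```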
+ 123 - 11
fish_speech/models/text2semantic/lit_module.py

@@ -1,6 +1,5 @@
-import platform
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import lightning as L
 import loralib as lora
@@ -28,6 +27,9 @@ class TextToSemantic(L.LightningModule):
         optimizer: Any,
         lr_scheduler: Any,
         lora_config: Optional[LoraConfig] = None,
+        save_lora_only: bool = False,
+        use_dpo: bool = False,
+        dpo_beta: float = 0.2,
     ):
         super().__init__()
 
@@ -35,6 +37,9 @@ class TextToSemantic(L.LightningModule):
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
         self.lora_config = lora_config
+        self.save_lora_only = save_lora_only
+        self.use_dpo = use_dpo  # We don't support a reference model yet
+        self.dpo_beta = dpo_beta
 
         if self.lora_config is not None:
             self.setup_lora()
@@ -81,10 +86,10 @@ class TextToSemantic(L.LightningModule):
         return self.model(x)
 
     def on_save_checkpoint(self, checkpoint):
-        if self.lora_config is None:
+        if self.lora_config is None or self.save_lora_only is False:
             return
 
-        # Save the LoRA parameters
+        # Save only LoRA parameters
         state_dict = checkpoint["state_dict"]
         for name in list(state_dict.keys()):
             if "lora" not in name:
@@ -102,16 +107,59 @@ class TextToSemantic(L.LightningModule):
             },
         }
 
+    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
+    def get_batch_logps(
+        self,
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        average_log_prob: bool = False,
+    ) -> torch.FloatTensor:
+        """Compute the log probabilities of the given labels under the given logits.
+
+        Args:
+            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
+            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
+            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+        """
+        assert logits.shape[:-1] == labels.shape
+
+        labels = labels.clone()
+        loss_mask = labels != -100
+
+        # dummy token; we'll ignore the losses on these tokens later
+        labels[labels == -100] = 0
+
+        per_token_logps = torch.gather(
+            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
+        ).squeeze(-1)
+
+        if average_log_prob:
+            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        else:
+            return (per_token_logps * loss_mask).sum(-1)
+
     def _step(self, batch, batch_idx, stage: str):
+        # Process positive and negative samples in the same batch to speed up training
         outputs = self.model(
             x=batch["inputs"],
             key_padding_mask=batch["attention_masks"],
         )
+        labels = batch["labels"]
+        token_logits = outputs.token_logits
+        codebook_logits = outputs.codebook_logits
+
+        if self.use_dpo:
+            # First half is positive, second half is negative
+            token_logits, negative_token_logits = token_logits.chunk(2)
+            codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
+            labels, negative_labels = labels.chunk(2)
 
         # Generate labels
-        labels = batch["labels"]
         base_loss = F.cross_entropy(
-            outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
+            token_logits.reshape(-1, token_logits.size(-1)),
             labels[:, 0].reshape(-1),
             ignore_index=-100,
         )
@@ -120,7 +168,7 @@ class TextToSemantic(L.LightningModule):
         if self.model.config.num_codebooks != 0:
             codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
             semantic_loss = F.cross_entropy(
-                outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
+                codebook_logits.reshape(-1, codebook_logits.size(-1)),
                 codebook_labels.reshape(-1),
                 ignore_index=-100,
             )
@@ -129,6 +177,72 @@ class TextToSemantic(L.LightningModule):
         else:
             loss = base_loss
 
+        # If DPO is enabled, add the preference loss
+        if self.use_dpo:
+            negative_codebook_labels = negative_labels[
+                :, 1 : 1 + self.model.config.num_codebooks
+            ].mT
+
+            positive_codebook_logps = self.get_batch_logps(
+                codebook_logits, codebook_labels
+            )
+            negative_codebook_logps = self.get_batch_logps(
+                negative_codebook_logits, negative_codebook_labels
+            )
+
+            # TODO: implement the reference model, avoid screwing up the gradients
+            dpo_loss = -F.logsigmoid(
+                (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
+            ).mean()
+
+            chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
+            rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
+            reward_accuracy = (
+                (positive_codebook_logps > negative_codebook_logps).float().mean()
+            )
+            chosen_rewards, rejected_rewards = (
+                chosen_rewards.mean(),
+                rejected_rewards.mean(),
+            )
+
+            loss = loss + dpo_loss
+
+            self.log(
+                f"{stage}/dpo_loss",
+                dpo_loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+
+            self.log(
+                f"{stage}/chosen_rewards",
+                chosen_rewards,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+
+            self.log(
+                f"{stage}/rejected_rewards",
+                rejected_rewards,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+
+            self.log(
+                f"{stage}/reward_accuracy",
+                reward_accuracy,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+
         self.log(
             f"{stage}/loss",
             loss,
@@ -159,15 +273,13 @@ class TextToSemantic(L.LightningModule):
 
         # Top-5 accuracy
         if self.model.config.num_codebooks == 0:
-            _, indices = outputs.token_logits.topk(5, dim=-1)
+            _, indices = token_logits.topk(5, dim=-1)
             correct = indices.eq(labels[:, 0].unsqueeze(-1))
             correct[labels[:, 0] == -100] = 0
             correct = correct.sum()
             accuracy = correct / (labels[:, 0] != -100).sum()
         else:
-            _, indices = outputs.codebook_logits.topk(5, dim=-1)
-            # print(codebook_labels[0, :10], torch.argmax(outputs.codebook_logits[0, :10], dim=-1))
-            # print(codebook_labels[codebook_labels != -100][:10], indices[codebook_labels != -100][:10])
+            _, indices = codebook_logits.topk(5, dim=-1)
             correct = indices.eq(codebook_labels.unsqueeze(-1))
             correct[codebook_labels == -100] = 0
             correct = correct.sum()

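For reference, the preference term added in `_step` is the DPO objective computed without a reference policy (the TODO notes this limitation), with the summed codebook log-probabilities from `get_batch_logps` standing in for the sequence log-likelihoods:

$$
\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}\left[\log \sigma\!\big(\beta\,(\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x))\big)\right]
$$

where β is `dpo_beta` (0.2 by default), y_w is the positive codebook sequence and y_l the corrupted negative one. The full DPO loss of Rafailov et al. additionally normalizes each term by a frozen reference model,

$$
\mathcal{L}_{\mathrm{DPO}}^{\mathrm{ref}} = -\,\mathbb{E}\left[\log \sigma\!\Big(\beta\Big(\log \tfrac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log \tfrac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\Big)\Big)\right]
$$

which is what implementing the reference model would restore. The logged `chosen_rewards` and `rejected_rewards` are β times the detached positive and negative log-probabilities, and `reward_accuracy` is the fraction of pairs where the positive outranks the negative.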
+ 2 - 0
fish_speech/models/text2semantic/llama.py

@@ -169,6 +169,8 @@ class Transformer(nn.Module):
             )
             x += torch.rand_like(x) * scaled_alpha
 
+        return x
+
     def compute(
         self,
         x: Tensor,