Explorar el Código

Update dataloader, loss tracker, and config

Lengyue hace 2 años
padre
commit
83582d1e89

+ 8 - 4
fish_speech/configs/text2semantic_pretrain.yaml

@@ -3,11 +3,11 @@ defaults:
   - _self_
 
 project: text2semantic_400m_pretrain_1.0
-max_length: 2048
+max_length: 1024
 
 # Lightning Trainer
 trainer:
-  accumulate_grad_batches: 2
+  accumulate_grad_batches: 1
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
   max_steps: 1_000_000
@@ -25,19 +25,23 @@ train_dataset:
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   use_speaker: false
+  phones_prob: 1.0
+  interactive_prob: 0.5
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   use_speaker: false
+  phones_prob: 1.0
+  interactive_prob: 0.5
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 16
+  batch_size: 32
   tokenizer: ${tokenizer}
   max_length: ${max_length}
 
@@ -57,7 +61,7 @@ model:
       dim: 1024
       rope_base: 10000
       norm_eps: 1e-5
-      num_codebooks: 8  # single codebook
+      num_codebooks: 4  # number of residual codebooks (was 8; the old "single codebook" note was stale)
       codebook_size: 264 # 264 = 256 codebook entries + 8 special tokens (NOTE: previous comment said "+ 2", which does not match 264 — verify against the tokenizer)
 
   optimizer:

+ 11 - 4
fish_speech/datasets/text.py

@@ -404,6 +404,11 @@ class AutoAugTextDataset(IterableDataset):
             sentences = [f"[SPK: {speaker}]"] + sentences
 
         final_text = "[INST] " + " ".join(sentences) + " [/INST]"
+
+        for segment in semantics:
+            for j in segment[0].values:
+                final_text += f" <s:{int(j)}>"
+
         encoded = self.tokenizer.encode(
             final_text,
             add_special_tokens=False,
@@ -411,12 +416,14 @@ class AutoAugTextDataset(IterableDataset):
             max_length=10**6,
         )
         semantic_length = sum([len(i[0].values) for i in semantics])
+        prompt_length = len(encoded) - semantic_length
+
         bos_bias = 1 if add_bos else 0
 
         # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
         tokens = (
             encoded
-            + [self.tokenizer.pad_token_id] * semantic_length
+            # + [self.tokenizer.pad_token_id] * semantic_length
             + [self.tokenizer.eos_token_id]
         )
 
@@ -425,7 +432,7 @@ class AutoAugTextDataset(IterableDataset):
 
         # Codebook bos/padding: 0, eos: 1
         codes = [
-            [CODEBOOK_BOS_TOKEN_ID] * (len(encoded) + bos_bias)
+            [CODEBOOK_BOS_TOKEN_ID] * (prompt_length + bos_bias)
             for _ in range(len(semantics[0]))
         ]
         for segment in semantics:
@@ -443,14 +450,14 @@ class AutoAugTextDataset(IterableDataset):
 
         # Mask out the <s> tokens for semantic, predict semantic tokens only
         # Since we don't mask out the input tokens, the language modeling still works
-        labels[1:, : (len(encoded) + bos_bias)] = -100
+        labels[1:, : (prompt_length + bos_bias)] = -100
 
         tokens = tokens[:, :-1]
         labels = labels[:, 1:]
 
         # Verify the padding is correct, and the last token is eos
         assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
-        assert (tokens[1:, : len(encoded) + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
+        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
         assert labels[0, -1] == self.tokenizer.eos_token_id
         assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()
 

+ 26 - 3
fish_speech/models/text2semantic/lit_module.py

@@ -100,7 +100,7 @@ class TextToSemantic(L.LightningModule):
 
         # Generate labels
         labels = batch["labels"]
-        loss = F.cross_entropy(
+        base_loss = F.cross_entropy(
             outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
             labels[:, 0].reshape(-1),
             ignore_index=-100,
@@ -108,14 +108,16 @@ class TextToSemantic(L.LightningModule):
 
         # If we have a codebook, add the loss
         if self.model.config.num_codebooks != 0:
-            codebook_labels = labels[:, 1:].mT
+            codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
             semantic_loss = F.cross_entropy(
                 outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
                 codebook_labels.reshape(-1),
                 ignore_index=-100,
             )
 
-            loss = loss + semantic_loss
+            loss = base_loss + semantic_loss
+        else:
+            loss = base_loss
 
         self.log(
             f"{stage}/loss",
@@ -126,6 +128,25 @@ class TextToSemantic(L.LightningModule):
             logger=True,
         )
 
+        if self.model.config.num_codebooks != 0:
+            self.log(
+                f"{stage}/base_loss",
+                base_loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+
+            self.log(
+                f"{stage}/semantic_loss",
+                semantic_loss,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+            )
+
         # Top-5 accuracy
         if self.model.config.num_codebooks == 0:
             _, indices = outputs.token_logits.topk(5, dim=-1)
@@ -135,6 +156,8 @@ class TextToSemantic(L.LightningModule):
             accuracy = correct / (labels[:, 0] != -100).sum()
         else:
             _, indices = outputs.codebook_logits.topk(5, dim=-1)
+            # print(codebook_labels[0, :10], torch.argmax(outputs.codebook_logits[0, :10], dim=-1))
+            # print(codebook_labels[codebook_labels != -100][:10], indices[codebook_labels != -100][:10])
             correct = indices.eq(codebook_labels.unsqueeze(-1))
             correct[codebook_labels == -100] = 0
             correct = correct.sum()

+ 1 - 1
tools/llama/generate.py

@@ -298,7 +298,7 @@ def encode_tokens(
     tokens = torch.tensor([tokens], dtype=torch.int, device=device)
 
     # Codebooks
-    zeros = torch.zeros((8, tokens.size(1)), dtype=torch.int, device=device)
+    zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
     prompt = torch.cat((tokens, zeros), dim=0)
 
     if prompt_tokens is None: