Преглед изворни кода

Handle adaptive number of codebooks

Lengyue пре 2 година
родитељ
комит
5707699dfd

+ 0 - 0
.project-root


+ 2 - 0
fish_speech/configs/text2semantic_finetune.yaml

@@ -26,11 +26,13 @@ train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  num_codebooks: ${model.model.config.num_codebooks}
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  num_codebooks: ${model.model.config.num_codebooks}
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule

+ 2 - 0
fish_speech/configs/text2semantic_finetune_lora.yaml

@@ -27,11 +27,13 @@ train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  num_codebooks: ${model.model.config.num_codebooks}
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  num_codebooks: ${model.model.config.num_codebooks}
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule

+ 4 - 2
fish_speech/configs/text2semantic_pretrain.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_pretrain_400m_8_codebooks
+project: text2semantic_pretrain_400m_4_codebooks
 max_length: 2048
 
 # Lightning Trainer
@@ -24,6 +24,7 @@ train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  num_codebooks: ${model.model.config.num_codebooks}
   use_speaker: false
   phones_prob: 0.5
   interactive_prob: 0.5
@@ -32,6 +33,7 @@ val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  num_codebooks: ${model.model.config.num_codebooks}
   use_speaker: false
   phones_prob: 0.5
   interactive_prob: 0.5
@@ -61,7 +63,7 @@ model:
       dim: 1024
       rope_base: 10000
       norm_eps: 1e-5
-      num_codebooks: 8  # single codebook
+      num_codebooks: 4  # single codebook
       codebook_size: 264 # codebook size 256 + 2 special tokens
       dropout: 0.1
       neft_alpha: 10

+ 2 - 1
fish_speech/configs/text2semantic_sft.yaml

@@ -24,6 +24,7 @@ train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  num_codebooks: ${model.model.config.num_codebooks}
   use_speaker: true
   phones_prob: 0.5
   interactive_prob: 0.5
@@ -33,6 +34,7 @@ val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  num_codebooks: ${model.model.config.num_codebooks}
   use_speaker: true
   phones_prob: 0.5
   interactive_prob: 0.5
@@ -50,7 +52,6 @@ data:
 # Model Configuration
 model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
-  use_dpo: true
 
   model:
     # ~ 130M parameters, for debug purpose

+ 9 - 4
fish_speech/datasets/text.py

@@ -198,6 +198,7 @@ class AutoAugTextDataset(IterableDataset):
         causual: bool = True,
         mix_text_phone_prob: float = 0.5,
         use_negative_samples: bool = False,
+        num_codebooks: Optional[int] = None,
     ):
         """
         Args:
@@ -214,6 +215,7 @@ class AutoAugTextDataset(IterableDataset):
             causual: use causual sampling when using local data, disable will lead to random sampling
             mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
             use_negative_samples: generate negative samples
+            num_codebooks: number of codebooks, if None, it will be automatically detected
         """
 
         super().__init__()
@@ -235,6 +237,7 @@ class AutoAugTextDataset(IterableDataset):
         self.causual = causual
         self.mix_text_phone_prob = mix_text_phone_prob
         self.use_negative_samples = use_negative_samples
+        self.num_codebooks = num_codebooks
 
         if use_data_server is True:
             self.channel = grpc.insecure_channel(server)
@@ -484,7 +487,9 @@ class AutoAugTextDataset(IterableDataset):
         )
         semantic_length = sum([len(i[0].values) for i in semantics])
         prompt_length = len(encoded)
-        num_codebooks = len(semantics[0])
+        num_codebooks = (
+            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+        )
 
         bos_bias = 1 if add_bos else 0
 
@@ -505,7 +510,7 @@ class AutoAugTextDataset(IterableDataset):
             for i in range(num_codebooks)
         ]
         for segment in semantics:
-            for book_idx, book in enumerate(segment):
+            for book_idx, book in zip(range(num_codebooks), segment):
                 for j in book.values:
                     codes[book_idx].append(int(j) + 2)
 
@@ -520,8 +525,7 @@ class AutoAugTextDataset(IterableDataset):
 
         # Mask out the <s> tokens for semantic, predict semantic tokens only
         # Since we don't mask out the input tokens, the language modeling still works
-        # labels[1:, : (prompt_length + bos_bias)] = -100
-        labels[:, : (prompt_length + bos_bias)] = -100
+        labels[1:, : (prompt_length + bos_bias)] = -100
 
         tokens = tokens[:, :-1]
         labels = labels[:, 1:]
@@ -677,6 +681,7 @@ if __name__ == "__main__":
         interactive_prob=1.0,
         phones_prob=1.0,
         use_negative_samples=False,
+        num_codebooks=4,
     )
 
     # ds = AutoAugTextDataset(

+ 9 - 0
fish_speech/train.py

@@ -1,7 +1,9 @@
+import os
 from typing import Optional
 
 import hydra
 import lightning as L
+import pyrootutils
 import torch
 from lightning import Callback, LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.loggers import Logger
@@ -9,6 +11,13 @@ from omegaconf import DictConfig, OmegaConf
 
 import fish_speech.utils as utils
 
+os.environ.pop("SLURM_NTASKS", None)
+os.environ.pop("SLURM_JOB_NAME", None)
+os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
 torch.backends.cudnn.allow_tf32 = True