
Optimize dataset & multi codebook training

Lengyue 2 years ago
Parent
Commit
55ed0f91af

+ 1 - 1
fish_speech/configs/base.yaml

@@ -14,7 +14,7 @@ trainer:
   default_root_dir: ${paths.run_dir}
   accelerator: gpu
   num_nodes: 1
-  devices: 8
+  devices: auto
   strategy:
     _target_: lightning.pytorch.strategies.DDPStrategy
 
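Switching devices from a hard-coded 8 to auto lets Lightning claim however many GPUs are visible, so the base config no longer assumes an 8-GPU node. Roughly what this config instantiates (a sketch of the equivalent direct Lightning call):

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DDPStrategy

    # devices="auto" uses every accelerator CUDA can see, instead of exactly 8.
    trainer = Trainer(
        accelerator="gpu",
        devices="auto",
        num_nodes=1,
        strategy=DDPStrategy(),
    )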

+ 5 - 2
fish_speech/configs/text2semantic.yaml

@@ -3,6 +3,7 @@ defaults:
   - _self_
 
 project: text2semantic_400m
+max_length: 1024
 
 # Lightning Trainer
 trainer:
@@ -22,19 +23,21 @@ tokenizer:
 train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
+  max_length: ${max_length}
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
+  max_length: ${max_length}
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 16
+  batch_size: 8
   tokenizer: ${tokenizer}
-  max_length: 1024
+  max_length: ${max_length}
 
 # Model Configuration
 model:
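Hoisting max_length to the top level and referencing it as ${max_length} keeps the datasets and the DataModule in sync: change it once, or override it on the Hydra command line, and every consumer follows. A minimal OmegaConf demonstration of the interpolation:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        """
    max_length: 1024
    train_dataset:
      max_length: ${max_length}
    data:
      max_length: ${max_length}
    """
    )
    cfg.max_length = 2048                 # change it once...
    print(cfg.train_dataset.max_length)   # ...and both consumers see 2048

With Hydra this also means a single CLI override such as max_length=2048 propagates to the datasets and the data module alike.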

+ 77 - 0
fish_speech/configs/text2semantic_multi.yaml

@@ -0,0 +1,77 @@
+defaults:
+  - base
+  - _self_
+
+project: text2semantic_400m_multi
+max_length: 1024
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 2
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  max_steps: 1_000_000
+  precision: bf16-true
+  limit_val_batches: 10
+
+# Tokenizer Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-v1
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+val_dataset:
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 16
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+
+  model:
+    # ~400M parameters
+    _target_: fish_speech.models.text2semantic.llama.Transformer
+    config:
+      _target_: fish_speech.models.text2semantic.llama.ModelArgs
+      max_seq_len: 4096
+      vocab_size: 36408
+      n_layer: 24
+      n_head: 16
+      dim: 1024
+      rope_base: 10000
+      norm_eps: 1e-5
+      num_codebooks: 4  # four parallel codebooks (multi-codebook training)
+      codebook_size: 168 # 160 codebook entries + 8 special tokens
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 3e-4
+    weight_decay: 0.1
+    betas: [0.9, 0.95]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 2000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1
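The lr_lambda target is partially applied and handed to torch's LambdaLR, so it only has to map a step index to a multiplier of the base learning rate. A minimal sketch of what fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda plausibly computes, inferred from its name and arguments (the repo's exact implementation may differ):

    import math

    def cosine_warmup_lr_lambda(step, *, num_warmup_steps, num_training_steps, final_lr_ratio):
        # Hypothetical re-implementation for illustration only.
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)  # linear warmup 0 -> 1
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
        return final_lr_ratio + (1.0 - final_lr_ratio) * cosine  # decay 1 -> final_lr_ratio

With the values above, the multiplier ramps to 1.0 over 2000 steps, then follows a cosine down to 0.1 of the base LR at step 1,000,000.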

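The headline change in this config is the multi-codebook head: num_codebooks: 4 with codebook_size: 168 means every semantic position is predicted as four parallel token streams, one per codebook. A hypothetical shape walkthrough (illustrative names, not the repo's API):

    import torch
    import torch.nn.functional as F

    num_codebooks, codebook_size, seq_len = 4, 168, 32
    # Ground-truth codes: one index per codebook per position.
    codes = torch.randint(0, codebook_size, (num_codebooks, seq_len))
    # The model emits one logit vector per codebook per position.
    logits = torch.randn(num_codebooks, seq_len, codebook_size)
    # Cross-entropy over all codebooks and positions at once.
    loss = F.cross_entropy(logits.reshape(-1, codebook_size), codes.reshape(-1))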
+ 18 - 1
fish_speech/datasets/text.py

@@ -149,15 +149,27 @@ class AutoAugTextDataset(IterableDataset):
         server: str = "localhost:50051",
         seed: int = 42,
         phones_prob: float = 0.3,
+        repetition_prob: float = 0.1,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
     ):
+        """
+        Args:
+            server: gRPC server address
+            seed: random seed
+            phones_prob: probability to use phones
+            repetition_prob: probability to repeat the same sentence
+            max_length: max length of the text
+            tokenizer: tokenizer
+        """
+
         super().__init__()
 
         self.seed = seed
         self.phones_prob = phones_prob
         self.max_length = max_length
         self.tokenizer = tokenizer
+        self.repetition_prob = repetition_prob
 
         # Read all lines, and group by speaker
         self.channel = grpc.insecure_channel(server)
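For reference, a hypothetical instantiation with the defaults documented above (assumes the gRPC data server is already listening on localhost:50051):

    from transformers import AutoTokenizer
    from fish_speech.datasets.text import AutoAugTextDataset

    tokenizer = AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1")
    dataset = AutoAugTextDataset(
        server="localhost:50051",
        seed=42,
        phones_prob=0.3,
        repetition_prob=0.1,   # new in this commit
        max_length=1024,
        tokenizer=tokenizer,
    )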
@@ -215,7 +227,12 @@ class AutoAugTextDataset(IterableDataset):
 
         samples = list(response.samples)
         while remaining_tokens > 0 and len(samples) > 0:
-            sentence = samples.pop()
+            if random.random() < self.repetition_prob:
+                # Repeat the same sentence
+                sentence = samples[-1]
+            else:
+                sentence = samples.pop()
+
             text, length = self.tokenize_sentence(
                 sentence.text, sentence.phones, mode=mode
             )
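Since a repetition reads samples[-1] without popping, the same sentence can recur on consecutive iterations, and the list is only consumed on non-repeat draws; the loop still terminates because remaining_tokens shrinks either way. A self-contained sketch of just this sampling behavior:

    import random

    samples = ["s1", "s2", "s3", "s4"]
    picked = []
    while samples:  # the real loop also stops once remaining_tokens runs out
        if random.random() < 0.1:
            picked.append(samples[-1])   # repeat: read without consuming
        else:
            picked.append(samples.pop())
    print(picked)  # reversed order, with an occasional adjacent duplicate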