فهرست منبع

Update pretrain config

Lengyue 2 سال پیش
والد
کامیت
a86e61a7cd

+ 7 - 7
fish_speech/configs/text2semantic_finetune.yaml

@@ -2,17 +2,17 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_400m_finetune
+project: text2semantic_400m_finetune_spk
 max_length: 4096
-# ckpt_path: results/text2semantic_400m_pretrain/checkpoints/step_000065000.ckpt
-# resume_weights_only: true
+ckpt_path: checkpoints/text2semantic-400m-v0.2-4k.pth
+resume_weights_only: true
 
 # Lightning Trainer
 trainer:
   accumulate_grad_batches: 2
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
-  max_steps: 10000
+  max_steps: 1000
   precision: bf16-true
   limit_val_batches: 10
 
@@ -63,7 +63,7 @@ model:
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
-    lr: 1e-4
+    lr: 1e-5
     weight_decay: 0.1
     betas: [0.9, 0.95]
     eps: 1e-5
@@ -74,11 +74,11 @@ model:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 2000
+      num_warmup_steps: 100
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.1
 
 # Callbacks
 callbacks:
   model_checkpoint:
-    every_n_train_steps: 1000
+    every_n_train_steps: 200

+ 0 - 84
fish_speech/configs/text2semantic_finetune_spk.yaml

@@ -1,84 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: text2semantic_400m_finetune_spk
-max_length: 4096
-ckpt_path: checkpoints/text2semantic-400m-v0.2-4k.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accumulate_grad_batches: 2
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  max_steps: 1000
-  precision: bf16-true
-  limit_val_batches: 10
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-v1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
-val_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
-data:
-  _target_: fish_speech.datasets.text.TextDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.text2semantic.TextToSemantic
-
-  model:
-    # ~ 130M parameters, for debug purpose
-    _target_: fish_speech.models.text2semantic.llama.Transformer
-    config:
-      _target_: fish_speech.models.text2semantic.llama.ModelArgs
-      max_seq_len: 4096
-      vocab_size: 36408
-      n_layer: 24
-      n_head: 16
-      dim: 1024
-      rope_base: 10000
-      norm_eps: 1e-5
-      num_codebooks: 4  # single codebook
-      codebook_size: 168 # codebook size 160 + 2 special tokens
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 1e-5
-    weight_decay: 0.1
-    betas: [0.9, 0.95]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 100
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.1
-
-# Callbacks
-callbacks:
-  model_checkpoint:
-    every_n_train_steps: 200

+ 6 - 2
fish_speech/configs/text2semantic_pretrain.yaml

@@ -2,8 +2,10 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_400m_pretrain
-max_length: 1024
+project: text2semantic_400m_pretrain_0.3
+max_length: 4096
+# ckpt_path: checkpoints/text2semantic-400m-v0.2-4k.pth
+# resume_weights_only: true
 
 # Lightning Trainer
 trainer:
@@ -24,11 +26,13 @@ train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  use_speaker: false
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   tokenizer: ${tokenizer}
   max_length: ${max_length}
+  use_speaker: false
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule

+ 1 - 1
fish_speech/datasets/text.py

@@ -272,7 +272,7 @@ class AutoAugTextDataset(IterableDataset):
             final_text.append(text)
             final_semantic.append(sentence.semantics)
 
-        if self.use_speaker is not None:
+        if self.use_speaker:
             final_text = [f"[SPK: {response.name}]"] + final_text
 
         final_text = "[INST] " + " ".join(final_text) + " [/INST]"