Explore the Code

Rename config & smaller schedule

Lengyue, 2 years ago
parent
commit
938eca21a6
3 files changed, with 12 additions and 15 deletions
  1. 0 0
      speech_lm/configs/llama_pretrain.yaml
  2. 11 14
      speech_lm/configs/whisper_vq.yaml
  3. 1 1
      speech_lm/train.py

+ 0 - 0
speech_lm/configs/pretrain.yaml → speech_lm/configs/llama_pretrain.yaml


+ 11 - 14
speech_lm/configs/whisper_vq.yaml

@@ -11,7 +11,6 @@ trainer:
   accelerator: gpu
   strategy: 
     _target_: lightning.fabric.strategies.DDPStrategy
-    find_unused_parameters: true
     static_graph: true
 
   devices: auto
@@ -40,31 +39,29 @@ model:
 schedule:
   batch_size: 64
   micro_batch_size: 64
-  max_steps: 1000000
+  max_steps: 10000
   save_interval: 2000
   gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
   clip_grad_norm: 2.0
-
-train_dataset:
-  _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
-  filelist: filelists/whisper-vq.train.train.filelist
-
-valid_dataset:
-  _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
-  filelist: filelists/whisper-vq.train.test.filelist
+  log_interval: 10
+  eval_interval: 2000
 
 train_dataloader:
   _target_: torch.utils.data.DataLoader
-  dataset: ${dataset}
+  dataset:
+    _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
+    filelist: filelists/whisper-vq.train.train.filelist
   batch_size: ${schedule.micro_batch_size}
-  num_workers: 4
+  num_workers: 8
   collate_fn:
     _target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
 
 valid_dataloader:
   _target_: torch.utils.data.DataLoader
-  dataset: ${dataset}
-  batch_size: ${schedule.micro_batch_size}
+  dataset:
+    _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
+    filelist: filelists/whisper-vq.train.test.filelist
+  batch_size: 32
   num_workers: 4
   collate_fn:
     _target_: speech_lm.datasets.whisper_vq.WhisperVQCollator

+ 1 - 1
speech_lm/train.py

@@ -211,7 +211,7 @@ def train(
             last_batch_time = time.time()
 
 
-@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
+@hydra.main(version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml")
 def main(cfg: DictConfig):
     log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")