|
|
@@ -11,7 +11,6 @@ trainer:
|
|
|
accelerator: gpu
|
|
|
strategy:
|
|
|
_target_: lightning.fabric.strategies.DDPStrategy
|
|
|
- find_unused_parameters: true
|
|
|
static_graph: true
|
|
|
|
|
|
devices: auto
|
|
|
@@ -40,31 +39,29 @@ model:
|
|
|
schedule:
|
|
|
batch_size: 64
|
|
|
micro_batch_size: 64
|
|
|
- max_steps: 1000000
|
|
|
+ max_steps: 10000
|
|
|
save_interval: 2000
|
|
|
gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
|
|
|
clip_grad_norm: 2.0
|
|
|
-
|
|
|
-train_dataset:
|
|
|
- _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
|
|
|
- filelist: filelists/whisper-vq.train.train.filelist
|
|
|
-
|
|
|
-valid_dataset:
|
|
|
- _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
|
|
|
- filelist: filelists/whisper-vq.train.test.filelist
|
|
|
+ log_interval: 10
|
|
|
+ eval_interval: 2000
|
|
|
|
|
|
train_dataloader:
|
|
|
_target_: torch.utils.data.DataLoader
|
|
|
- dataset: ${dataset}
|
|
|
+ dataset:
|
|
|
+ _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
|
|
|
+ filelist: filelists/whisper-vq.train.train.filelist
|
|
|
batch_size: ${schedule.micro_batch_size}
|
|
|
- num_workers: 4
|
|
|
+ num_workers: 8
|
|
|
collate_fn:
|
|
|
_target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
|
|
|
|
|
|
valid_dataloader:
|
|
|
_target_: torch.utils.data.DataLoader
|
|
|
- dataset: ${dataset}
|
|
|
- batch_size: ${schedule.micro_batch_size}
|
|
|
+ dataset:
|
|
|
+ _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
|
|
|
+ filelist: filelists/whisper-vq.train.test.filelist
|
|
|
+ batch_size: 32
|
|
|
num_workers: 4
|
|
|
collate_fn:
|
|
|
_target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
|