|
@@ -1,5 +1,5 @@
|
|
|
paths:
|
|
paths:
|
|
|
- run_dir: results/pretrain
|
|
|
|
|
|
|
+ run_dir: results/hubert-vq
|
|
|
checkpoint_dir: ${paths.run_dir}/checkpoints
|
|
checkpoint_dir: ${paths.run_dir}/checkpoints
|
|
|
|
|
|
|
|
hydra:
|
|
hydra:
|
|
@@ -9,7 +9,11 @@ hydra:
|
|
|
trainer:
|
|
trainer:
|
|
|
_target_: lightning.fabric.Fabric
|
|
_target_: lightning.fabric.Fabric
|
|
|
accelerator: gpu
|
|
accelerator: gpu
|
|
|
- strategy: ddp
|
|
|
|
|
|
|
+ strategy:
|
|
|
|
|
+ _target_: lightning.fabric.strategies.DDPStrategy
|
|
|
|
|
+ find_unused_parameters: true
|
|
|
|
|
+ static_graph: true
|
|
|
|
|
+
|
|
|
devices: auto
|
|
devices: auto
|
|
|
precision: bf16-mixed
|
|
precision: bf16-mixed
|
|
|
loggers:
|
|
loggers:
|
|
@@ -28,8 +32,8 @@ model:
|
|
|
vq_loss_weight: 1.0
|
|
vq_loss_weight: 1.0
|
|
|
|
|
|
|
|
schedule:
|
|
schedule:
|
|
|
- batch_size: 128
|
|
|
|
|
- micro_batch_size: 128
|
|
|
|
|
|
|
+ batch_size: 32
|
|
|
|
|
+ micro_batch_size: 32
|
|
|
max_steps: 100000
|
|
max_steps: 100000
|
|
|
save_interval: 2000
|
|
save_interval: 2000
|
|
|
gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
|
|
gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
|