@@ -10,32 +10,9 @@ trainer:
   _target_: lightning.fabric.Fabric
   accelerator: gpu
   strategy:
-    _target_: lightning.fabric.strategies.FSDPStrategy
-    sync_module_states: true
-    use_orig_params: true
-    cpu_offload: false
-    mixed_precision:
-      _target_: torch.distributed.fsdp.MixedPrecision
-      param_dtype:
-        _target_: hydra.utils.get_object
-        path: torch.bfloat16
-      reduce_dtype:
-        _target_: hydra.utils.get_object
-        path: torch.bfloat16
-      buffer_dtype:
-        _target_: hydra.utils.get_object
-        path: torch.bfloat16
-      cast_forward_inputs: true
-    sharding_strategy: SHARD_GRAD_OP
-    auto_wrap_policy:
-      _target_: torch.distributed.fsdp.wrap.transformer_auto_wrap_policy
-      _partial_: true
-      transformer_layer_cls:
-        - _target_: hydra.utils.get_class
-          path: transformers.models.llama.modeling_llama.LlamaDecoderLayer
-    activation_checkpointing_policy: ${trainer.strategy.auto_wrap_policy}
-    state_dict_type: full
-  num_nodes: 1
+    _target_: lightning.fabric.strategies.DDPStrategy
+    static_graph: true
+  num_nodes: 4
   devices: 8
   precision: bf16-mixed
   loggers:
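
For reference, a minimal sketch of how a `trainer` block like the one above is turned into a live `Fabric` object via `hydra.utils.instantiate`; the entrypoint, `config_path="config"`, and `config_name="train"` are assumptions for illustration, not taken from this diff:

```python
# Minimal sketch, assuming a Hydra entrypoint; the config path and file
# name are illustrative, not taken from this diff.
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="config", config_name="train", version_base=None)
def main(cfg: DictConfig) -> None:
    # hydra.utils.instantiate resolves _target_ recursively, so the nested
    # DDPStrategy is built first and passed into lightning.fabric.Fabric.
    fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()

if __name__ == "__main__":
    main()
```

Note the tradeoff in this hunk: unlike FSDP, DDP keeps a full model replica on every GPU, so the switch only works while the model plus optimizer state fit in a single device's memory.
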
@@ -56,12 +33,14 @@ tokenizer:
 
 # Say we want a 3 trillion seen token schedule
 # 3e12 / 1024 / 512 / 8 = 715255
+# But we use a 100k-step schedule here to save time
+# That is roughly a 210 billion seen token schedule (1024 * 64 * 8 * 4 * 100000)
 schedule:
   max_length: 1024
-  batch_size: 512 # 128 * 4 = 512
-  micro_batch_size: 64
-  max_steps: 715255
-  save_interval: 2000
+  batch_size: 64 # per device; 64 * 8 devices * 4 nodes = 2048 sequences per step
+  micro_batch_size: 8
+  max_steps: 100000
+  save_interval: 5000
   log_interval: 10
   gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
   clip_grad_norm: 1.0
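
As a sanity check on the arithmetic above, here is a small back-of-the-envelope script; treating `batch_size` as per-device is an assumption, but it is the reading consistent with the original `3e12 / 1024 / 512 / 8 = 715255` calculation:

```python
# Back-of-the-envelope check of the schedule above. Reading batch_size as
# per-device matches the original 3e12 / 1024 / 512 / 8 step count.
max_length = 1024        # tokens per sequence
batch_size = 64          # sequences per device per optimizer step
micro_batch_size = 8     # sequences per forward/backward pass
devices = 8
num_nodes = 4
max_steps = 100_000

grad_accum = batch_size // micro_batch_size           # mirrors the ${eval:...} resolver
tokens_per_step = max_length * batch_size * devices * num_nodes
total_tokens = tokens_per_step * max_steps

print(grad_accum)             # 8
print(tokens_per_step)        # 2097152, about 2.1M tokens per step
print(f"{total_tokens:.3e}")  # 2.097e+11, i.e. ~210B seen tokens
```
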
@@ -82,7 +61,7 @@ train_dataloader:
   _target_: torch.utils.data.DataLoader
   dataset: ${dataset}
   batch_size: ${schedule.micro_batch_size}
-  num_workers: 4
+  num_workers: 8
   collate_fn:
     _target_: speech_lm.datasets.cultura_x.CulutreXCollator
     tokenizer: ${tokenizer}
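
The collator entry follows the standard `collate_fn` contract of `torch.utils.data.DataLoader`: any callable that maps a list of dataset items to a batch. A minimal sketch of that pattern is below; the class name, field types, and padding behavior are assumptions, not the actual `CulutreXCollator` implementation:

```python
# Sketch of the collate_fn contract used above; the name TextCollator and
# the padding behavior are assumptions, not the real CulutreXCollator.
from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase

@dataclass
class TextCollator:
    tokenizer: PreTrainedTokenizerBase
    max_length: int = 1024

    def __call__(self, batch: list[str]):
        # Tokenize and pad a list of raw strings into fixed-size tensors.
        return self.tokenizer(
            batch,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
```

Because the config passes `tokenizer: ${tokenizer}` into the collator, Hydra wires the shared tokenizer instance in at instantiation time rather than constructing a new one per dataloader worker.
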