Ver Fonte

Accelerate training with fsdp

Lengyue há 2 anos
pai
commit
2733ed600f
3 ficheiros alterados com 48 adições e 6 exclusões
  1. 1 1
      requirements.txt
  2. 31 4
      speech_lm/configs/pretrain.yaml
  3. 16 1
      speech_lm/train.py

+ 1 - 1
requirements.txt

@@ -2,7 +2,7 @@ transformers>=4.34.0
 datasets>=2.14.5
 bitsandbytes>=0.41.1
 peft>=0.5.0
-lightning>=2.0.9.post0
+lightning>=2.1.0
 hydra-core>=1.3.2
 tensorboard>=2.14.1
 natsort>=8.4.0

+ 31 - 4
speech_lm/configs/pretrain.yaml

@@ -9,8 +9,34 @@ hydra:
 trainer:
   _target_: lightning.fabric.Fabric
   accelerator: gpu
-  strategy: ddp
-  devices: auto
+  strategy:
+    _target_: lightning.fabric.strategies.FSDPStrategy
+    sync_module_states: true
+    use_orig_params: true
+    cpu_offload: false
+    mixed_precision:
+      _target_: torch.distributed.fsdp.MixedPrecision
+      param_dtype: 
+        _target_: hydra.utils.get_object
+        path: torch.bfloat16
+      reduce_dtype:
+        _target_: hydra.utils.get_object
+        path: torch.bfloat16
+      buffer_dtype:
+        _target_: hydra.utils.get_object
+        path: torch.bfloat16
+      cast_forward_inputs: true
+    sharding_strategy: SHARD_GRAD_OP
+    auto_wrap_policy:
+      _target_: torch.distributed.fsdp.wrap.transformer_auto_wrap_policy
+      _partial_: true
+      transformer_layer_cls:
+        - _target_: hydra.utils.get_class
+          path: transformers.models.llama.modeling_llama.LlamaDecoderLayer
+    activation_checkpointing_policy: ${trainer.strategy.auto_wrap_policy}
+    state_dict_type: full
+  num_nodes: 1
+  devices: 8
   precision: bf16-mixed
   loggers:
     _target_: pytorch_lightning.loggers.TensorBoardLogger
@@ -32,10 +58,11 @@ tokenizer:
 # 3e12 / 1024 / 512 / 8 = 715255
 schedule:
   max_length: 1024
-  batch_size: 512
-  micro_batch_size: 2
+  batch_size: 512  # 128 * 4 = 512
+  micro_batch_size: 64
   max_steps: 715255
   save_interval: 2000
+  log_interval: 10
   gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
   clip_grad_norm: 1.0
 

+ 16 - 1
speech_lm/train.py

@@ -1,4 +1,5 @@
 from pathlib import Path
+import time
 
 import hydra
 import torch
@@ -36,6 +37,7 @@ def train(
     bar.update(global_step)
     accumulate_steps = 0
     optimizer.zero_grad()
+    start_time = time.time()
 
     while global_step < cfg.schedule.max_steps:
         for batch in dataloader:
@@ -78,6 +80,18 @@ def train(
             global_step += 1
             bar.update(1)
 
+            if global_step % cfg.schedule.log_interval == 0:
+                step_time = (time.time() - start_time) / cfg.schedule.log_interval
+                log.info(
+                    f"[{global_step}/{cfg.schedule.max_steps}] loss: {loss:.4f} "
+                    + f"step time: {step_time:.2f}s "
+                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
+                    + f"grad_norm: {grad_norm:.2f} "
+                    + f"ETA: {step_time * (cfg.schedule.max_steps - global_step):.2f}s"
+                )
+
+                start_time = time.time()
+
             if global_step % cfg.schedule.save_interval == 0:
                 fabric.save(
                     Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
@@ -118,7 +132,8 @@ def main(cfg: DictConfig):
     log.info(f"Scheduler: {scheduler}")
 
     log.info(f"Setup fabric model & dataset")
-    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler)
+    model = fabric.setup_module(model)
+    optimizer = fabric.setup_optimizers(optimizer)
 
     # Build state
     global_step = 0