Lengyue 2 года назад
Родитель
Commit
9920873cf2
4 изменённых файла с 16 добавлено и 35 удалено
  1. 1 1
      dockerfile
  2. 10 31
      speech_lm/configs/pretrain.yaml
  3. 5 0
      speech_lm/train.py
  4. 0 3
      train.sh

+ 1 - 1
dockerfile

@@ -7,7 +7,7 @@ RUN apt-get update && apt-get install -y git curl build-essential ffmpeg libsm6
     apt-get clean && rm -rf /var/lib/apt/lists/*
 
 # Install s5cmd
-RUN curl -L https://github.com/peak/s5cmd/releases/download/v2.1.0-beta.1/s5cmd_2.1.0-beta.1_Linux-64bit.tar.gz | tar xvz -C /tmp && \
+RUN curl -L https://github.com/peak/s5cmd/releases/download/v2.2.0/s5cmd_2.2.0_Linux-64bit.tar.gz | tar xvz -C /tmp && \
     mv /tmp/s5cmd /usr/local/bin/s5cmd && s5cmd --help
 
 # Install code server and zsh

+ 10 - 31
speech_lm/configs/pretrain.yaml

@@ -10,32 +10,9 @@ trainer:
   _target_: lightning.fabric.Fabric
   accelerator: gpu
   strategy:
-    _target_: lightning.fabric.strategies.FSDPStrategy
-    sync_module_states: true
-    use_orig_params: true
-    cpu_offload: false
-    mixed_precision:
-      _target_: torch.distributed.fsdp.MixedPrecision
-      param_dtype: 
-        _target_: hydra.utils.get_object
-        path: torch.bfloat16
-      reduce_dtype:
-        _target_: hydra.utils.get_object
-        path: torch.bfloat16
-      buffer_dtype:
-        _target_: hydra.utils.get_object
-        path: torch.bfloat16
-      cast_forward_inputs: true
-    sharding_strategy: SHARD_GRAD_OP
-    auto_wrap_policy:
-      _target_: torch.distributed.fsdp.wrap.transformer_auto_wrap_policy
-      _partial_: true
-      transformer_layer_cls:
-        - _target_: hydra.utils.get_class
-          path: transformers.models.llama.modeling_llama.LlamaDecoderLayer
-    activation_checkpointing_policy: ${trainer.strategy.auto_wrap_policy}
-    state_dict_type: full
-  num_nodes: 1
+    _target_: lightning.fabric.strategies.DDPStrategy
+    static_graph: true
+  num_nodes: 4
   devices: 8
   precision: bf16-mixed
   loggers:
@@ -56,12 +33,14 @@ tokenizer:
 
 # Say we want a 3 trillion seen token schedule
 # 3e12 / 1024 / 512 / 8 = 715255
+# But we use a 100k steps schedule here to save time
+# This is a 300 billion seen token schedule
 schedule:
   max_length: 1024
-  batch_size: 512  # 128 * 4 = 512
-  micro_batch_size: 64
-  max_steps: 715255
-  save_interval: 2000
+  batch_size: 64  # 128 * 4 = 512
+  micro_batch_size: 8
+  max_steps: 100000
+  save_interval: 5000
   log_interval: 10
   gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
   clip_grad_norm: 1.0
@@ -82,7 +61,7 @@ train_dataloader:
   _target_: torch.utils.data.DataLoader
   dataset: ${dataset}
   batch_size: ${schedule.micro_batch_size}
-  num_workers: 4
+  num_workers: 8
   collate_fn:
     _target_: speech_lm.datasets.cultura_x.CulutreXCollator
     tokenizer: ${tokenizer}

+ 5 - 0
speech_lm/train.py

@@ -138,8 +138,13 @@ def train(
                 optimizer,
                 max_norm=cfg.schedule.clip_grad_norm,
                 norm_type=2.0,
+                error_if_nonfinite=True,
             )
 
+            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                log.warning(f"Gradient norm is {grad_norm}, skipping update")
+                optimizer.zero_grad()
+
             # We can't average gradients across multiple steps
             trackers["grad_norm"].append(float(grad_norm))
 

+ 0 - 3
train.sh

@@ -1,3 +0,0 @@
-docker run --rm -it --gpus all \
-    --ipc=host --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 \
-    -v $(pwd):/exp speech-llm-train