Преглед изворни кода

Optimize data loading & gradient accumulation

Lengyue пре 2 година
родитељ
комит
900fc9dec9
2 измењених фајлова са 3 додато и 6 уклоњено
  1. 1 4
      speech_lm/configs/pretrain.yaml
  2. 2 2
      speech_lm/train.py

+ 1 - 4
speech_lm/configs/pretrain.yaml

@@ -48,10 +48,7 @@ dataloader:
   batch_size: ${schedule.micro_batch_size}
   num_workers: 4
   collate_fn:
-    _target_: transformers.DataCollatorWithPadding
-    tokenizer: ${tokenizer}
-    max_length: ${schedule.max_length}
-    padding: max_length
+    _target_: transformers.DefaultDataCollator
 
 optimizer:
   _target_: torch.optim.AdamW

+ 2 - 2
speech_lm/train.py

@@ -1,4 +1,3 @@
-import logging
 from pathlib import Path
 
 import hydra
@@ -40,9 +39,11 @@ def train(
 
     while global_step < cfg.schedule.max_steps:
         for batch in dataloader:
+            # Accumulate gradients
             is_accumulating = (
                 accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
             )
+            accumulate_steps += 1
 
             # Train one step
             with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -50,7 +51,6 @@ def train(
                 fabric.backward(loss)
 
             if is_accumulating:
-                accumulate_steps += 1
                 continue
 
             # Perform gradient clipping