Selaa lähdekoodia

Fix logger in config

Lengyue 2 vuotta sitten
vanhempi
commit
026d419e9a
2 muutettua tiedostoa jossa 41 lisäystä ja 16 poistoa
  1. 10 7
      speech_lm/configs/pretrain.yaml
  2. 31 9
      speech_lm/train.py

+ 10 - 7
speech_lm/configs/pretrain.yaml

@@ -13,10 +13,10 @@ trainer:
   devices: auto
   precision: bf16-mixed
   loggers:
-    - _target_: pytorch_lightning.loggers.TensorBoardLogger
-      save_dir: ${paths.run_dir}
-      name: tensorboard
-      version: null
+    _target_: pytorch_lightning.loggers.TensorBoardLogger
+    save_dir: ${paths.run_dir}
+    name: tensorboard
+    version: null
 
 model:
   _target_: transformers.AutoModelForCausalLM.from_pretrained
@@ -32,9 +32,12 @@ tokenizer:
 # 3e12 / 1024 / 512 / 8 ≈ 715255 (rounded down)
 schedule:
   max_length: 1024
-  batch_size: 16
+  batch_size: 512
+  micro_batch_size: 2
   max_steps: 715255
-  save_every: 2000
+  save_interval: 2000
+  gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
+  clip_grad_norm: 1.0
 
 dataloader:
   _target_: torch.utils.data.DataLoader
@@ -42,7 +45,7 @@ dataloader:
     _target_: speech_lm.dataset.build_dataset
     tokenizer: ${tokenizer}
     max_length: ${schedule.max_length}
-  batch_size: ${schedule.batch_size}
+  batch_size: ${schedule.micro_batch_size}
   num_workers: 4
   collate_fn:
     _target_: transformers.DataCollatorWithPadding

+ 31 - 9
speech_lm/train.py

@@ -35,25 +35,47 @@ def train(
 ):
     bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
     bar.update(global_step)
+    accumulate_steps = 0
+    optimizer.zero_grad()
 
     while global_step < cfg.schedule.max_steps:
         for batch in dataloader:
-            # Train loop
-            optimizer.zero_grad()
-            loss = model(**batch).loss
-            fabric.backward(loss)
+            is_accumulating = (
+                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
+            )
+
+            # Train one step
+            with fabric.no_backward_sync(model, enabled=is_accumulating):
+                loss = model(**batch).loss
+                fabric.backward(loss)
+
+            if is_accumulating:
+                accumulate_steps += 1
+                continue
+
+            # Perform gradient clipping
+            grad_norm = fabric.clip_gradients(
+                model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
+            )
+
+            # Update
             optimizer.step()
+            optimizer.zero_grad()
             scheduler.step()
 
-            fabric.log_dict({
-                "train/loss": loss,
-                "train/lr": optimizer.param_groups[0]["lr"],
-            }, step=global_step)
+            fabric.log_dict(
+                {
+                    "train/loss": loss,
+                    "train/lr": optimizer.param_groups[0]["lr"],
+                    "train/grad_norm": grad_norm,
+                },
+                step=global_step,
+            )
 
             global_step += 1
             bar.update(1)
 
-            if global_step % cfg.schedule.save_steps == 0:
+            if global_step % cfg.schedule.save_interval == 0:
                 fabric.save(
                     Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                     {