Explorar el Código

Better train loop tracking

Lengyue hace 2 años
padre
commit
fb23d5df2d
Se ha modificado 1 fichero con 108 adiciones y 14 borrados
  1. 108 14
      speech_lm/train.py

+ 108 - 14
speech_lm/train.py

@@ -1,6 +1,8 @@
+from collections import defaultdict
 import time
 from datetime import timedelta
 from pathlib import Path
+from typing import Optional
 
 import hydra
 import torch
@@ -23,25 +25,81 @@ OmegaConf.register_new_resolver("eval", eval)
 log = RankedLogger(__name__, rank_zero_only=True)
 
 
+def valid(
+    model: LlamaForCausalLM,
+    valid_dataloader: Optional[torch.utils.data.DataLoader],
+    global_step: int,
+    fabric: Fabric,
+    cfg: DictConfig,
+):
+    model.eval()
+    log.info(f"Evaluating at step {global_step}")
+
+    accumulate_infos = None
+
+    for idx, batch in tqdm(enumerate(valid_dataloader), desc="Evaluating"):
+        outputs = model(**batch)
+        loss = outputs.loss
+        metrics = getattr(outputs, "metrics", {})
+        log_info = {
+            "valid/loss": float(loss),
+            **{f"valid/{k}": float(v) for k, v in metrics.items()},
+        }
+
+        fabric.log_dict(
+            log_info,
+            step=global_step + idx,
+        )
+
+        # Update log info
+        if accumulate_infos is None:
+            accumulate_infos = log_info
+        else:
+            assert set(accumulate_infos.keys()) == set(
+                log_info.keys()
+            ), "Log keys changed during evaluation"
+            for k in accumulate_infos.keys():
+                accumulate_infos[k] += log_info[k]
+
+        if idx == getattr(cfg.schedule, "eval_max_batches", None):
+            break
+
+    # Log average
+    items = []
+    for k in accumulate_infos.keys():
+        items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
+    log.info(f"Average: {' | '.join(items)}")
+
+
 def train(
     model: LlamaForCausalLM,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler._LRScheduler,
-    dataloader: torch.utils.data.DataLoader,
+    train_dataloader: torch.utils.data.DataLoader,
+    valid_dataloader: Optional[torch.utils.data.DataLoader],
     global_step: int,
     fabric: Fabric,
     cfg: DictConfig,
 ):
-    model.train()
-
     bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
     bar.update(global_step)
     accumulate_steps = 0
     optimizer.zero_grad()
+
+    # Step time is measured from start_time, so it covers both model time and data loading time
     start_time = time.time()
+    trackers = defaultdict(list)
 
     while global_step < cfg.schedule.max_steps:
-        for batch in dataloader:
+        last_batch_time = time.time()
+        for batch in train_dataloader:
+            # Measure time used by data loading
+            trackers["data_time"].append(time.time() - last_batch_time)
+
+            # Measure time used by the model forward and backward passes
+            model_begin_time = time.time()
+            model.train()
+
             # Accumulate gradients
             is_accumulating = (
                 accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
@@ -55,9 +113,25 @@ def train(
                 metrics = getattr(outputs, "metrics", {})
                 fabric.backward(loss)
 
+                # Update trackers
+                trackers["loss"].append(float(loss))
+                trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
+                trackers["grad_norm"].append(
+                    trackers.get("grad_norm", 0) + float(grad_norm)
+                )
+                for k, v in metrics.items():
+                    trackers[f"metrics/{k}"].append(float(v))
+
+            trackers["model_time"].append(time.time() - model_begin_time)
+
             if is_accumulating:
                 continue
 
+            # Check that all trackers have the same length
+            assert (
+                len(set(len(v) for v in trackers.values())) == 1
+            ), "Trackers has ambiguous length"
+
             # Perform gradient clipping
             grad_norm = fabric.clip_gradients(
                 model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
@@ -70,10 +144,9 @@ def train(
 
             fabric.log_dict(
                 {
-                    "train/loss": loss,
-                    "train/lr": optimizer.param_groups[0]["lr"],
-                    "train/grad_norm": grad_norm,
-                    **{f"train/{k}": v for k, v in metrics.items()},
+                    f"train/{k}": sum(v[-accumulate_steps:])
+                    / len(v[-accumulate_steps:])
+                    for k, v in trackers.items()
                 },
                 step=global_step,
             )
@@ -84,12 +157,18 @@ def train(
             if global_step % cfg.schedule.log_interval == 0:
                 step_time = (time.time() - start_time) / cfg.schedule.log_interval
                 eta = step_time * (cfg.schedule.max_steps - global_step)
+                additional_info = [
+                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
+                    for k, v in trackers.items()
+                    if k != "lr"  # lr is logged separately using the .2e format
+                ]
+
                 log.info(
-                    f"[{global_step}/{cfg.schedule.max_steps}] loss: {loss:.4f} "
+                    f"[{global_step}/{cfg.schedule.max_steps}] "
                     + f"step time: {step_time:.2f}s "
+                    + f"ETA: {timedelta(round(eta))}s "
                     f"lr: {optimizer.param_groups[0]['lr']:.2e} "
-                    + f"grad_norm: {grad_norm:.2f} "
-                    + f"ETA: {timedelta(round(eta))}s"
+                    + " ".join(additional_info)
                 )
 
                 start_time = time.time()
@@ -105,6 +184,12 @@ def train(
                     },
                 )
 
+            if (
+                global_step % cfg.schedule.eval_interval == 0
+                and valid_dataloader is not None
+            ):
+                valid(model, valid_dataloader, fabric, global_step, cfg)
+
             if global_step >= cfg.schedule.max_steps:
                 break
 
@@ -161,17 +246,26 @@ def main(cfg: DictConfig):
         global_step = remainder["global_step"]
         log.info(f"Restored global step: {global_step}")
 
-    train_dataloader = hydra.utils.instantiate(cfg.dataloader)
-    log.info(f"Dataloader: {train_dataloader}")
+    train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
+    log.info(f"Train Dataloader: {train_dataloader}")
+
+    valid_dataloader = None
+    if getattr(train_dataloader, "valid_dataloader", None) is not None:
+        valid_dataloader = hydra.utils.instantiate(train_dataloader.valid_dataloader)
+        log.info(f"Valid Dataloader: {valid_dataloader}")
 
     train_dataloader = fabric.setup_dataloaders(train_dataloader)
+    if valid_dataloader is not None:
+        valid_dataloader = fabric.setup_dataloaders(valid_dataloader)
+
     log.info(f"Begin training")
 
     train(
         model=model,
         optimizer=optimizer,
         scheduler=scheduler,
-        dataloader=train_dataloader,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
         global_step=global_step,
         fabric=fabric,
         cfg=cfg,