Lengyue 2 лет назад
Родитель
Commit
5f07060c6a
1 измененных файлов с 16 добавлено и 8 удалено
  1. 16 8
      speech_lm/train.py

+ 16 - 8
speech_lm/train.py

@@ -116,7 +116,6 @@ def train(
                 # Update trackers
                 trackers["loss"].append(float(loss))
                 trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
-                trackers["grad_norm"].append(float(grad_norm))
                 for k, v in metrics.items():
                     trackers[f"metrics/{k}"].append(float(v))
 
@@ -127,7 +126,7 @@ def train(
 
             # Check all trackers has the same length
             assert (
-                len(set(len(v) for v in trackers.values())) == 1
+                len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
             ), "Trackers has ambiguous length"
 
             # Perform gradient clipping
@@ -135,6 +134,9 @@ def train(
                 model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
             )
 
+            # We can't average gradients across multiple steps
+            trackers["grad_norm"].append(float(grad_norm))
+
             # Update
             optimizer.step()
             optimizer.zero_grad()
@@ -163,12 +165,15 @@ def train(
 
                 log.info(
                     f"[{global_step}/{cfg.schedule.max_steps}] "
-                    + f"step time: {step_time:.2f}s "
-                    + f"ETA: {timedelta(round(eta))}s "
+                    + f"step_time: {step_time:.2f}s "
+                    + f"ETA: {timedelta(seconds=round(eta))}s "
                     f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                     + " ".join(additional_info)
                 )
 
+                # Reset trackers
+                trackers = defaultdict(list)
+
                 start_time = time.time()
 
             if global_step % cfg.schedule.save_interval == 0:
@@ -183,14 +188,17 @@ def train(
                 )
 
             if (
-                global_step % cfg.schedule.eval_interval == 0
+                getattr(cfg.schedule, "eval_interval", None) is not None
+                and global_step % cfg.schedule.eval_interval == 0
                 and valid_dataloader is not None
             ):
-                valid(model, valid_dataloader, fabric, global_step, cfg)
+                valid(model, valid_dataloader, global_step, fabric, cfg)
 
             if global_step >= cfg.schedule.max_steps:
                 break
 
+            last_batch_time = time.time()
+
 
 
 @hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
 def main(cfg: DictConfig):
@@ -248,8 +256,8 @@ def main(cfg: DictConfig):
     log.info(f"Train Dataloader: {train_dataloader}")
 
     valid_dataloader = None
-    if getattr(train_dataloader, "valid_dataloader", None) is not None:
-        valid_dataloader = hydra.utils.instantiate(train_dataloader.valid_dataloader)
+    if getattr(cfg, "valid_dataloader", None) is not None:
+        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
         log.info(f"Valid Dataloader: {valid_dataloader}")
 
     train_dataloader = fabric.setup_dataloaders(train_dataloader)