Lengyue vor 2 Jahren
Ursprung
Commit
d211a7eac7
1 geänderte Dateien mit 4 neuen und 2 gelöschten Zeilen
  1. 4 2
      speech_lm/train.py

+ 4 - 2
speech_lm/train.py

@@ -1,5 +1,6 @@
-from pathlib import Path
 import time
+from datetime import timedelta
+from pathlib import Path
 
 import hydra
 import torch
@@ -82,12 +83,13 @@ def train(
 
             if global_step % cfg.schedule.log_interval == 0:
                 step_time = (time.time() - start_time) / cfg.schedule.log_interval
+                eta = step_time * (cfg.schedule.max_steps - global_step)
                 log.info(
                     f"[{global_step}/{cfg.schedule.max_steps}] loss: {loss:.4f} "
                     + f"step time: {step_time:.2f}s "
                     f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                     + f"grad_norm: {grad_norm:.2f} "
-                    + f"ETA: {step_time * (cfg.schedule.max_steps - global_step):.2f}s"
+                    + f"ETA: {timedelta(round(eta))}s"
                 )
 
                 start_time = time.time()