Ver Fonte

Better ETA

Lengyue há 2 anos atrás
pai
commit
d211a7eac7
1 ficheiros alterados com 4 adições e 2 exclusões
  1. 4 2
      speech_lm/train.py

+ 4 - 2
speech_lm/train.py

@@ -1,5 +1,6 @@
-from pathlib import Path
 import time
+from datetime import timedelta
+from pathlib import Path
 
 import hydra
 import torch
@@ -82,12 +83,13 @@ def train(
 
             if global_step % cfg.schedule.log_interval == 0:
                 step_time = (time.time() - start_time) / cfg.schedule.log_interval
+                eta = step_time * (cfg.schedule.max_steps - global_step)
                 log.info(
                     f"[{global_step}/{cfg.schedule.max_steps}] loss: {loss:.4f} "
                     + f"step time: {step_time:.2f}s "
                     f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                     + f"grad_norm: {grad_norm:.2f} "
-                    + f"ETA: {step_time * (cfg.schedule.max_steps - global_step):.2f}s"
+                    + f"ETA: {timedelta(round(eta))}s"
                 )
 
                 start_time = time.time()