|
|
@@ -1,5 +1,6 @@
|
|
|
-from pathlib import Path
|
|
|
import time
|
|
|
+from datetime import timedelta
|
|
|
+from pathlib import Path
|
|
|
|
|
|
import hydra
|
|
|
import torch
|
|
|
@@ -82,12 +83,13 @@ def train(
|
|
|
|
|
|
if global_step % cfg.schedule.log_interval == 0:
|
|
|
step_time = (time.time() - start_time) / cfg.schedule.log_interval
|
|
|
+ eta = step_time * (cfg.schedule.max_steps - global_step)
|
|
|
log.info(
|
|
|
f"[{global_step}/{cfg.schedule.max_steps}] loss: {loss:.4f} "
|
|
|
+ f"step time: {step_time:.2f}s "
|
|
|
f"lr: {optimizer.param_groups[0]['lr']:.2e} "
|
|
|
+ f"grad_norm: {grad_norm:.2f} "
|
|
|
- + f"ETA: {step_time * (cfg.schedule.max_steps - global_step):.2f}s"
|
|
|
+ + f"ETA: {timedelta(seconds=round(eta))}"  # eta is in seconds; timedelta's first positional arg is days
|
|
|
)
|
|
|
|
|
|
start_time = time.time()
|