@@ -1,4 +1,5 @@
 from pathlib import Path
+import time
 
 import hydra
 import torch
@@ -36,6 +37,7 @@ def train(
     bar.update(global_step)
     accumulate_steps = 0
     optimizer.zero_grad()
+    start_time = time.time()
 
     while global_step < cfg.schedule.max_steps:
         for batch in dataloader:
@@ -78,6 +80,18 @@ def train(
             global_step += 1
             bar.update(1)
 
+            if global_step % cfg.schedule.log_interval == 0:
+                step_time = (time.time() - start_time) / cfg.schedule.log_interval
+                log.info(
+                    f"[{global_step}/{cfg.schedule.max_steps}] loss: {loss:.4f} "
+                    f"step time: {step_time:.2f}s "
+                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
+                    f"grad_norm: {grad_norm:.2f} "
+                    f"ETA: {step_time * (cfg.schedule.max_steps - global_step):.2f}s"
+                )
+
+                start_time = time.time()
+
             if global_step % cfg.schedule.save_interval == 0:
                 fabric.save(
                     Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
@@ -118,7 +132,8 @@ def main(cfg: DictConfig):
     log.info(f"Scheduler: {scheduler}")
 
     log.info(f"Setup fabric model & dataset")
-    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler)
+    model = fabric.setup_module(model)
+    optimizer = fabric.setup_optimizers(optimizer)
 
     # Build state
     global_step = 0
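
Note on the logging math: step_time is averaged over the last cfg.schedule.log_interval steps only, because the timer is reset after every log line, so the printed ETA tracks recent throughput rather than a whole-run average. A minimal standalone sketch of the same windowed-timing pattern (the eta_seconds helper and the constants are illustrative, not part of the patch):

import time

def eta_seconds(step_time: float, global_step: int, max_steps: int) -> float:
    # Remaining-time estimate from the latest window's average step time.
    return step_time * (max_steps - global_step)

log_interval, max_steps = 100, 1_000
start_time = time.time()
for global_step in range(1, max_steps + 1):
    ...  # one optimization step would run here
    if global_step % log_interval == 0:
        # Seconds per step, averaged over the last `log_interval` steps only.
        step_time = (time.time() - start_time) / log_interval
        print(
            f"[{global_step}/{max_steps}] "
            f"step time: {step_time:.2f}s "
            f"ETA: {eta_seconds(step_time, global_step, max_steps):.2f}s"
        )
        start_time = time.time()  # reset the window, as the patch does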
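
Note on the fabric.setup change: Lightning Fabric's combined setup() takes a module plus optimizers but no LR scheduler, which is presumably why the patch switches to the separate setup_module / setup_optimizers calls and leaves the scheduler as a plain torch object. A minimal sketch of the split, with a placeholder model, optimizer, and scheduler standing in for the real ones:

import lightning as L
import torch

fabric = L.Fabric(accelerator="auto", devices=1)
fabric.launch()

model = torch.nn.Linear(8, 2)  # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

model = fabric.setup_module(model)              # wraps the module, moves it to the device
optimizer = fabric.setup_optimizers(optimizer)  # wraps the optimizer for the strategy

# The scheduler still references the underlying optimizer's param groups,
# so scheduler.step() should keep adjusting the learning rate as before.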