@@ -32,6 +32,8 @@ def train(
fabric: Fabric,
cfg: DictConfig,
):
+ model.train()
+
bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
bar.update(global_step)
accumulate_steps = 0