@@ -35,25 +35,47 @@ def train(
 ):
     bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
     bar.update(global_step)
+    accumulate_steps = 0
+    optimizer.zero_grad()
 
     while global_step < cfg.schedule.max_steps:
         for batch in dataloader:
-            # Train loop
-            optimizer.zero_grad()
-            loss = model(**batch).loss
-            fabric.backward(loss)
+            accumulate_steps += 1
+            is_accumulating = (
+                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
+            )
+
+            # Train one step
+            with fabric.no_backward_sync(model, enabled=is_accumulating):
+                loss = model(**batch).loss
+                fabric.backward(loss)
+
+            if is_accumulating:
+                continue
+
+            # Perform gradient clipping
+            grad_norm = fabric.clip_gradients(
+                model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
+            )
+
+            # Update
             optimizer.step()
+            optimizer.zero_grad()
             scheduler.step()
 
-            fabric.log_dict({
-                "train/loss": loss,
-                "train/lr": optimizer.param_groups[0]["lr"],
-            }, step=global_step)
+            fabric.log_dict(
+                {
+                    "train/loss": loss,
+                    "train/lr": optimizer.param_groups[0]["lr"],
+                    "train/grad_norm": grad_norm,
+                },
+                step=global_step,
+            )
 
             global_step += 1
             bar.update(1)
 
-            if global_step % cfg.schedule.save_steps == 0:
+            if global_step % cfg.schedule.save_interval == 0:
                 fabric.save(
                     Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                     {