|
|
@@ -99,10 +99,8 @@ def train(
|
|
|
model.train()
|
|
|
|
|
|
# Accumulate gradients
|
|
|
- is_accumulating = (
|
|
|
- accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
|
|
|
- )
|
|
|
accumulate_steps += 1
|
|
|
+ is_accumulating = accumulate_steps < cfg.schedule.gradient_accumulation_steps
|
|
|
|
|
|
# Train one step
|
|
|
with fabric.no_backward_sync(model, enabled=is_accumulating):
|
|
|
@@ -153,12 +151,14 @@ def train(
|
|
|
|
|
|
fabric.log_dict(
|
|
|
{
|
|
|
- f"train/{k}": float(v[-1])
|
|
|
+ f"train/{k}": sum(v[-accumulate_steps:])
|
|
|
+ / len(v[-accumulate_steps:])
|
|
|
for k, v in trackers.items()
|
|
|
},
|
|
|
step=global_step,
|
|
|
)
|
|
|
|
|
|
+ accumulate_steps = 0
|
|
|
global_step += 1
|
|
|
|
|
|
if global_step % cfg.schedule.log_interval == 0:
|