@@ -99,10 +99,9 @@ def train(
     model.train()
 
     # Accumulate gradients
+    gradient_accumulation_steps = cfg.schedule.gradient_accumulation_steps
     accumulate_steps += 1
-    is_accumulating = (
-        accumulate_steps < cfg.schedule.gradient_accumulation_steps
-    )
+    is_accumulating = accumulate_steps % gradient_accumulation_steps != 0
 
     # Train one step
     with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -111,7 +110,7 @@ def train(
     metrics = getattr(outputs, "metrics", {})
 
     # Need to divide loss by accumulation steps
-    fabric.backward(loss / cfg.schedule.gradient_accumulation_steps)
+    fabric.backward(loss / gradient_accumulation_steps)
 
     # Update trackers
     trackers["loss"].append(float(loss))
@@ -153,14 +152,14 @@ def train(
 
     fabric.log_dict(
         {
-            f"train/{k}": sum(v[-accumulate_steps:])
-            / len(v[-accumulate_steps:])
+            f"train/{k}": sum(v[-gradient_accumulation_steps:])
+            / len(v[-gradient_accumulation_steps:])
             for k, v in trackers.items()
        },
         step=global_step,
     )
 
-    accumulate_steps = 0
+    # accumulate_steps is no longer reset; the modulo check above handles it
     global_step += 1
 
     if global_step % cfg.schedule.log_interval == 0:
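
As a sanity check on the new cadence, here is a minimal, self-contained sketch of the same pattern in plain PyTorch (no Fabric or DDP; `demo_model`, `demo_batches`, and `N` are made-up stand-ins for the real model, dataloader, and cfg.schedule.gradient_accumulation_steps). It shows why the loss is divided by the accumulation count and why the modulo check makes the counter reset unnecessary. In the real loop, fabric.no_backward_sync(model, enabled=is_accumulating) additionally skips the distributed gradient all-reduce on accumulating steps.

import torch
import torch.nn.functional as F

N = 4  # stand-in for cfg.schedule.gradient_accumulation_steps
demo_model = torch.nn.Linear(8, 1)  # hypothetical tiny model
optimizer = torch.optim.SGD(demo_model.parameters(), lr=0.1)
demo_batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]

accumulate_steps = 0
for x, y in demo_batches:
    accumulate_steps += 1
    # True on all but every N-th micro-batch; modulo means no reset is needed
    is_accumulating = accumulate_steps % N != 0

    loss = F.mse_loss(demo_model(x), y)
    # Divide by N so the accumulated gradient is the average over N
    # micro-batches, matching what one batch N times larger would give
    (loss / N).backward()

    if not is_accumulating:
        optimizer.step()       # one optimizer update per N micro-batches
        optimizer.zero_grad()  # clear accumulated gradients for the next group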