@@ -111,7 +111,9 @@ def train(
outputs = model(**batch)
loss = outputs.loss
metrics = getattr(outputs, "metrics", {})
- fabric.backward(loss)
+
+ # Need to divide loss by accumulation steps
+ fabric.backward(loss / cfg.schedule.gradient_accumulation_steps)
# Update trackers
trackers["loss"].append(float(loss))