|
|
@@ -111,7 +111,9 @@ def train(
|
|
|
outputs = model(**batch)
|
|
|
loss = outputs.loss
|
|
|
metrics = getattr(outputs, "metrics", {})
|
|
|
- fabric.backward(loss)
|
|
|
+
|
|
|
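+ # Scale the loss by 1/gradient_accumulation_steps before backward so gradients summed over the accumulation window are averaged rather than multiplied by the step count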
|
|
|
+ fabric.backward(loss / cfg.schedule.gradient_accumulation_steps)
|
|
|
|
|
|
# Update trackers
|
|
|
trackers["loss"].append(float(loss))
|