|
@@ -1,4 +1,3 @@
|
|
|
-import logging
|
|
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import hydra
|
|
import hydra
|
|
@@ -40,9 +39,11 @@ def train(
|
|
|
|
|
|
|
|
while global_step < cfg.schedule.max_steps:
|
|
while global_step < cfg.schedule.max_steps:
|
|
|
for batch in dataloader:
|
|
for batch in dataloader:
|
|
|
|
|
+ # Accumulate gradients
|
|
|
is_accumulating = (
|
|
is_accumulating = (
|
|
|
accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
|
|
accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
|
|
|
)
|
|
)
|
|
|
|
|
+ accumulate_steps += 1
|
|
|
|
|
|
|
|
# Train one step
|
|
# Train one step
|
|
|
with fabric.no_backward_sync(model, enabled=is_accumulating):
|
|
with fabric.no_backward_sync(model, enabled=is_accumulating):
|
|
@@ -50,7 +51,6 @@ def train(
|
|
|
fabric.backward(loss)
|
|
fabric.backward(loss)
|
|
|
|
|
|
|
|
if is_accumulating:
|
|
if is_accumulating:
|
|
|
- accumulate_steps += 1
|
|
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
# Perform gradient clipping
|
|
# Perform gradient clipping
|