Explorar o código

Fix grad acc loss

Lengyue %!s(int64=2) %!d(string=hai) anos
pai
achega
2a19adb045
Modificáronse 1 ficheiros con 3 adicións e 1 borrados
  1. 3 1
      speech_lm/train.py

+ 3 - 1
speech_lm/train.py

@@ -111,7 +111,9 @@ def train(
                 outputs = model(**batch)
                 loss = outputs.loss
                 metrics = getattr(outputs, "metrics", {})
-                fabric.backward(loss)
+
+                # Need to divide loss by accumulation steps
+                fabric.backward(loss / cfg.schedule.gradient_accumulation_steps)
 
                 # Update trackers
                 trackers["loss"].append(float(loss))