Fix grad acc loss

Lengyue 2 years ago
parent
commit
2a19adb045
1 changed file with 3 additions and 1 deletion

+ 3 - 1
speech_lm/train.py

@@ -111,7 +111,9 @@ def train(
                 outputs = model(**batch)
                 loss = outputs.loss
                 metrics = getattr(outputs, "metrics", {})
-                fabric.backward(loss)
+
+                # Need to divide loss by accumulation steps
+                fabric.backward(loss / cfg.schedule.gradient_accumulation_steps)
 
                 # Update trackers
                 trackers["loss"].append(float(loss))
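Note on the change: gradients from successive backward calls add up until the optimizer steps, so without the division the accumulated gradient grows with the number of accumulation steps. The sketch below illustrates this with a hypothetical one-weight model in plain Python; it is not code from this repository. It assumes the loss is a mean over each micro-batch and that all micro-batches have the same size. The names `grad_mean_loss` and `accumulate` are made up for the illustration.

```python
# Toy model (an assumption for illustration): scalar weight w,
# per-sample loss (w * x - y) ** 2, reduced by mean over the batch.

def grad_mean_loss(w, batch):
    """Gradient w.r.t. w of the mean squared error over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulate(w, micro_batches, scale):
    """Sum micro-batch gradients, as repeated backward calls would,
    multiplying each micro-batch loss (hence its gradient) by `scale`."""
    total = 0.0
    for batch in micro_batches:
        total += scale * grad_mean_loss(w, batch)
    return total


w = 0.5
micro_batches = [
    [(1.0, 2.0), (2.0, 3.0)],
    [(3.0, 5.0), (4.0, 9.0)],
]
steps = len(micro_batches)  # plays the role of gradient_accumulation_steps
full_batch = [sample for batch in micro_batches for sample in batch]

reference = grad_mean_loss(w, full_batch)           # one large batch
unscaled = accumulate(w, micro_batches, 1.0)        # behaviour before the fix
scaled = accumulate(w, micro_batches, 1.0 / steps)  # behaviour after the fix

print(f"reference={reference} unscaled={unscaled} scaled={scaled}")
```

In this toy case the scaled sum matches the large-batch gradient, while the unscaled sum is larger by a factor of `steps`. The tracker in the diff still appends the unscaled `loss`, so logged values stay comparable across different accumulation settings.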