Lengyue 2 лет назад
Родитель
Сommit
5f01798ab4
1 измененных файлов с 6 добавлено и 2 удалено
  1. 6 2
      speech_lm/train.py

+ 6 - 2
speech_lm/train.py

@@ -1,5 +1,5 @@
-from collections import defaultdict
 import time
+from collections import defaultdict
 from datetime import timedelta
 from pathlib import Path
 from typing import Optional
@@ -134,7 +134,11 @@ def train(
 
             # Perform gradient clipping
             grad_norm = fabric.clip_gradients(
-                model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
+                model,
+                optimizer,
+                max_norm=cfg.schedule.clip_grad_norm,
+                norm_type=2.0,
+                error_if_nonfinite=False,
             )
 
             # We can't average gradients across multiple steps