|
|
@@ -1,5 +1,5 @@
|
|
|
-from collections import defaultdict
|
|
|
import time
|
|
|
+from collections import defaultdict
|
|
|
from datetime import timedelta
|
|
|
from pathlib import Path
|
|
|
from typing import Optional
|
|
|
@@ -134,7 +134,11 @@ def train(
|
|
|
|
|
|
# Perform gradient clipping
|
|
|
grad_norm = fabric.clip_gradients(
|
|
|
- model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
|
|
|
+ model,
|
|
|
+ optimizer,
|
|
|
+ max_norm=cfg.schedule.clip_grad_norm,
|
|
|
+ norm_type=2.0,
|
|
|
+ error_if_nonfinite=False,
|
|
|
)
|
|
|
|
|
|
# We can't average gradients across multiple steps
|