Просмотр исходного кода

Fix gradient accumulate logging

Lengyue 2 года назад
Родитель
Commit
4cfb62a08a
1 изменённый файл с 4 добавлениями и 4 удалениями
  1. 4 4
      speech_lm/train.py

+ 4 - 4
speech_lm/train.py

@@ -99,10 +99,8 @@ def train(
             model.train()
 
             # Accumulate gradients
-            is_accumulating = (
-                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
-            )
             accumulate_steps += 1
+            is_accumulating = accumulate_steps < cfg.schedule.gradient_accumulation_steps
 
             # Train one step
             with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -153,12 +151,14 @@ def train(
 
             fabric.log_dict(
                 {
-                    f"train/{k}": float(v[-1])
+                    f"train/{k}": sum(v[-accumulate_steps:])
+                    / len(v[-accumulate_steps:])
                     for k, v in trackers.items()
                 },
                 step=global_step,
             )
 
+            accumulate_steps = 0
             global_step += 1
 
             if global_step % cfg.schedule.log_interval == 0: