fix: correct progress bar total when using gradient accumulation with max_steps (#1227)

When training with accumulate_grad_batches > 1 and max_steps, the
default TQDMProgressBar overflows past 100% because its total is
computed in optimizer steps while on_train_batch_end fires on every
forward pass.

GradAccumProgressBar multiplies total_train_batches by
accumulate_grad_batches so the total matches the actual number of
forward passes, keeping the bar accurate throughout training.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Matteo 2 weeks ago
Parent
Commit
48b19f66b1

+ 2 - 1
fish_speech/callbacks/__init__.py

@@ -1,3 +1,4 @@
 from .grad_norm import GradNormMonitor
+from .progress_bar import GradAccumProgressBar
 
-__all__ = ["GradNormMonitor"]
+__all__ = ["GradNormMonitor", "GradAccumProgressBar"]

+ 16 - 0
fish_speech/callbacks/progress_bar.py

@@ -0,0 +1,16 @@
+from lightning.pytorch.callbacks import TQDMProgressBar
+
+
+class GradAccumProgressBar(TQDMProgressBar):
+    """
+    Progress bar that accounts for gradient accumulation so the total
+    reflects actual forward passes rather than optimizer steps.
+    """
+
+    @property
+    def total_train_batches(self):
+        total = super().total_train_batches
+        accumulate = self.trainer.accumulate_grad_batches
+        if isinstance(total, int) and accumulate > 1:
+            return total * accumulate
+        return total
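A minimal, self-contained sketch of the property's scaling logic, using stand-ins for `TQDMProgressBar` and the trainer so it runs without Lightning installed (`_raw_total` and the stub classes are illustrative, not part of the commit):

```python
class _BaseBar:
    """Stand-in for TQDMProgressBar: reports the total in optimizer steps."""
    _raw_total = 100  # hypothetical total

    @property
    def total_train_batches(self):
        return self._raw_total


class _Trainer:
    """Stand-in for the Lightning trainer."""
    accumulate_grad_batches = 4  # hypothetical config


class GradAccumProgressBar(_BaseBar):
    """Same scaling as the real callback, applied to the stubs above."""
    trainer = _Trainer()

    @property
    def total_train_batches(self):
        total = super().total_train_batches
        accumulate = self.trainer.accumulate_grad_batches
        if isinstance(total, int) and accumulate > 1:
            return total * accumulate
        return total


print(GradAccumProgressBar().total_train_batches)  # 400
```

With `accumulate_grad_batches == 1` (or a non-integer total such as `float("inf")` for iterable datasets), the base total is returned unchanged, which is why the guard checks both conditions.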

+ 3 - 0
fish_speech/configs/base.yaml

@@ -57,6 +57,9 @@ callbacks:
     norm_type: 2
     logging_interval: step
 
+  progress_bar:
+    _target_: fish_speech.callbacks.GradAccumProgressBar
+
 # Logger
 logger:
   tensorboard: