progress_bar.py 518 B

12345678910111213141516
  1. from lightning.pytorch.callbacks import TQDMProgressBar
  2. class GradAccumProgressBar(TQDMProgressBar):
  3. """
  4. Progress bar that accounts for gradient accumulation so the total
  5. reflects actual forward passes rather than optimizer steps.
  6. """
  7. @property
  8. def total_train_batches(self):
  9. total = super().total_train_batches
  10. accumulate = self.trainer.accumulate_grad_batches
  11. if isinstance(total, int) and accumulate > 1:
  12. return total * accumulate
  13. return total