Skip to content

Commit

Permalink
Add ability for TQDMProgressBar to retain prior epoch training bars
Browse files Browse the repository at this point in the history
  • Loading branch information
jojje committed Mar 5, 2024
1 parent b871f7a commit 1fa6195
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
8 changes: 8 additions & 0 deletions docs/source-pytorch/common/progress_bar.rst
Expand Up @@ -36,6 +36,14 @@ You can update ``refresh_rate`` (rate (number of batches) at which the progress
trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])
By default, the training progress bar is reset (overwritten) at each new epoch.
If you wish for a new progress bar to be displayed at the end of every epoch, set
:paramref:`TQDMProgressBar.leave <lightning.pytorch.callbacks.TQDMProgressBar.leave>` to ``True``.

.. code-block:: python

    trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)])
If you want to customize the default :class:`~lightning.pytorch.callbacks.TQDMProgressBar` used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the :class:`~lightning.pytorch.trainer.trainer.Trainer`.

Expand Down
9 changes: 7 additions & 2 deletions src/lightning/pytorch/callbacks/progress/tqdm_progress.py
Expand Up @@ -99,12 +99,13 @@ class TQDMProgressBar(ProgressBar):
together. This corresponds to
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
"""

BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

def __init__(self, refresh_rate: int = 1, process_position: int = 0):
def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False):
super().__init__()
self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
self._process_position = process_position
Expand All @@ -113,6 +114,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0):
self._val_progress_bar: Optional[_tqdm] = None
self._test_progress_bar: Optional[_tqdm] = None
self._predict_progress_bar: Optional[_tqdm] = None
self._leave = leave

def __getstate__(self) -> Dict:
# can't pickle the tqdm objects
Expand Down Expand Up @@ -262,6 +264,8 @@ def on_train_start(self, *_: Any) -> None:

@override
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._leave:
self.train_progress_bar = self.init_train_tqdm()
self.train_progress_bar.reset(convert_inf(self.total_train_batches))
self.train_progress_bar.initial = 0
self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
Expand All @@ -282,7 +286,8 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu

@override
def on_train_end(self, *_: Any) -> None:
self.train_progress_bar.close()
if not self._leave:
self.train_progress_bar.close()

@override
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand Down

0 comments on commit 1fa6195

Please sign in to comment.