Add ability for TQDMProgressBar to retain prior epoch training bars (… #19578

Open · wants to merge 2 commits into base: master
8 changes: 8 additions & 0 deletions docs/source-pytorch/common/progress_bar.rst
@@ -36,6 +36,14 @@ You can update ``refresh_rate`` (rate (number of batches) at which the progress

trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])

By default, the training progress bar is reset (overwritten) at each new epoch.
If you wish for a new progress bar to be displayed at the end of every epoch, set
:paramref:`TQDMProgressBar.leave <lightning.pytorch.callbacks.TQDMProgressBar.leave>` to ``True``.

.. code-block:: python

trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)])
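
Under the hood this maps onto tqdm's own ``leave`` flag. As a standalone sketch
(plain ``tqdm``, no Lightning, shown for illustration only):

.. code-block:: python

    import sys
    from time import sleep

    from tqdm import tqdm

    for epoch in range(3):
        # leave=True keeps each finished bar in the terminal, so three bars
        # remain after the outer loop; leave=False would clear each one.
        for _ in tqdm(range(50), desc=f"Epoch {epoch}", leave=True, file=sys.stdout):
            sleep(0.01)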

If you want to customize the default :class:`~lightning.pytorch.callbacks.TQDMProgressBar` used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
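
For instance, a minimal sketch (the ``LitProgressBar`` name and the custom
description are illustrative, not part of this diff):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import TQDMProgressBar


    class LitProgressBar(TQDMProgressBar):
        def init_validation_tqdm(self):
            # reuse the default validation bar, just relabel it
            bar = super().init_validation_tqdm()
            bar.set_description("running validation...")
            return bar


    trainer = Trainer(callbacks=[LitProgressBar()])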

2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513))

-
- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))

### Changed

8 changes: 7 additions & 1 deletion src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -99,12 +99,13 @@ class TQDMProgressBar(ProgressBar):
            together. This corresponds to
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
            :class:`~lightning.pytorch.trainer.trainer.Trainer`.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False

    """

    BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

    def __init__(self, refresh_rate: int = 1, process_position: int = 0):
    def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False):
Reviewer (Member): Sounds like a good idea!
        super().__init__()
        self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
        self._process_position = process_position
@@ -113,6 +114,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0):
        self._val_progress_bar: Optional[_tqdm] = None
        self._test_progress_bar: Optional[_tqdm] = None
        self._predict_progress_bar: Optional[_tqdm] = None
        self._leave = leave

    def __getstate__(self) -> Dict:
        # can't pickle the tqdm objects
@@ -262,6 +264,8 @@ def on_train_start(self, *_: Any) -> None:

    @override
    def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
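        # when leave=True, a fresh bar is created each epoch, so the finished
        # bar from the previous epoch stays visible in the terminal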
        if self._leave:
            self.train_progress_bar = self.init_train_tqdm()
Comment on lines +267 to +268

Reviewer (Member): Instead of doing this, I believe it would be better to pass leave=self.leave to the tqdm bar directly (see init_train_tqdm(), init_validation_tqdm(), etc. above).

jojje (Author), Mar 8, 2024: You mean creating yet another constructor wrapper, like in the rich progress bar, and what is already present in this file?

E.g.

    def reinit_train_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for training."""
        return Tqdm(
            desc=self.train_description,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=self._leave,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
            bar_format=self.BAR_FORMAT,
        )

It'd be identical to init_validation_tqdm, so I don't really see the point of that, unless the intended introduction of init_validation_tqdm had some undocumented purpose that you're planning to start taking advantage of (and perhaps need to change the implementation of the constructor function).

Or did you perhaps mean changing the hard-coded "leave=True" to "leave=self._leave" in the existing "init_train_tqdm" function? Or perhaps a third variant, where the init_train_tqdm function is parameterized to take the leave value?

PS. Thanks for the review.

awaelchli (Member), Mar 8, 2024: No, I'm proposing to simply pass leave=self.leave in init_train_tqdm(), where we hard-coded it to True so far (line 205).
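
Concretely, a sketch of that proposal (the body mirrors the constructor quoted above; only the leave argument changes from the hard-coded True):

    def init_train_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for training."""
        return Tqdm(
            desc=self.train_description,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=self._leave,  # was: leave=True
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
            bar_format=self.BAR_FORMAT,
        )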

        self.train_progress_bar.reset(convert_inf(self.total_train_batches))
        self.train_progress_bar.initial = 0
        self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
@@ -279,6 +283,8 @@ def on_train_batch_end(
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.train_progress_bar.disable:
            self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
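        # with leave=True, close the finished bar so it remains on screen;
        # on_train_epoch_start then creates a new bar for the next epoch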
        if self._leave:
            self.train_progress_bar.close()

    @override
    def on_train_end(self, *_: Any) -> None: