Skip to content

Commit

Permalink
Move Colab setup to ProgressBar (#10542)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 23, 2021
1 parent 2036dfb commit 48cf1ad
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
11 changes: 10 additions & 1 deletion pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tqdm import tqdm as _tqdm

from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.distributed import rank_zero_debug

_PAD_SIZE = 5

Expand Down Expand Up @@ -100,7 +101,7 @@ class TQDMProgressBar(ProgressBarBase):

def __init__(self, refresh_rate: int = 1, process_position: int = 0):
super().__init__()
self._refresh_rate = refresh_rate
self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
self._process_position = process_position
self._enabled = True
self.main_progress_bar = None
Expand Down Expand Up @@ -324,6 +325,14 @@ def _update_bar(self, bar: Optional[Tqdm]) -> None:
if delta > 0:
bar.update(delta)

@staticmethod
def _resolve_refresh_rate(refresh_rate: int) -> int:
if os.getenv("COLAB_GPU") and refresh_rate == 1:
# smaller refresh rate on colab causes crashes, choose a higher value
rank_zero_debug("Using a higher refresh rate on Colab. Setting it to `20`")
refresh_rate = 20
return refresh_rate


def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
"""The tqdm doesn't support inf/nan values.
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,7 @@ def _configure_progress_bar(
if refresh_rate == 0 or not enable_progress_bar:
return
if refresh_rate is None:
# smaller refresh rate on colab causes crashes, choose a higher value
refresh_rate = 20 if os.getenv("COLAB_GPU") else 1
refresh_rate = 1

progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
self.trainer.callbacks.append(progress_bar_callback)
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
assert trainer.progress_bar_callback.refresh_rate == 20

trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar())
assert trainer.progress_bar_callback.refresh_rate == 1 # FIXME: should be 20
assert trainer.progress_bar_callback.refresh_rate == 20

trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar(refresh_rate=19))
assert trainer.progress_bar_callback.refresh_rate == 19
Expand Down

0 comments on commit 48cf1ad

Please sign in to comment.