From 3dea0fa586fee33edf885857bad33c288924701b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 16 Feb 2022 22:35:36 +0100 Subject: [PATCH 1/5] Describe behavior with --- .../loops/epoch/prediction_epoch_loop.py | 6 ++- pytorch_lightning/loops/fit_loop.py | 1 + pytorch_lightning/trainer/trainer.py | 42 ++++++++++++------- tests/trainer/flags/test_limit_batches.py | 21 ++++++++++ 4 files changed, 53 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 14ec50c489a86..277f5a6ad9f00 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -164,7 +164,11 @@ def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]: """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`.""" # the batch_sampler is not be defined in case of CombinedDataLoaders - batch_sampler = getattr(self.trainer.predict_dataloaders[dataloader_idx], "batch_sampler", None) + batch_sampler = getattr( + self.trainer.predict_dataloaders[dataloader_idx], # type: ignore[has-type] + "batch_sampler", + None, + ) if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions: return batch_sampler.seen_batch_indices diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 8cbe4c167a29d..86c5d9ab83f0d 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -263,6 +263,7 @@ def advance(self) -> None: # type: ignore[override] log.detail(f"{self.__class__.__name__}: advancing loop") assert self.trainer.train_dataloader is not None dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader) + assert self._data_fetcher is not None self._data_fetcher.setup( dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0) ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4b68adbdf31b7..6a65e5b1ae6ea 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -147,7 +147,7 @@ def __init__( log_gpu_memory: Optional[str] = None, # TODO: Remove in 1.7 progress_bar_refresh_rate: Optional[int] = None, # TODO: remove in v1.7 enable_progress_bar: bool = True, - overfit_batches: Union[int, float] = 0.0, + overfit_batches: Optional[Union[int, float]] = None, track_grad_norm: Union[int, float, str] = -1, check_val_every_n_epoch: int = 1, fast_dev_run: Union[int, bool] = False, @@ -157,11 +157,11 @@ def __init__( max_steps: int = -1, min_steps: Optional[int] = None, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, - limit_train_batches: Union[int, float] = 1.0, - limit_val_batches: Union[int, float] = 1.0, - limit_test_batches: Union[int, float] = 1.0, - limit_predict_batches: Union[int, float] = 1.0, - val_check_interval: Union[int, float] = 1.0, + limit_train_batches: Optional[Union[int, float]] = None, + limit_val_batches: Optional[Union[int, float]] = None, + limit_test_batches: Optional[Union[int, float]] = None, + limit_predict_batches: Optional[Union[int, float]] = None, + val_check_interval: Optional[Union[int, float]] = None, flush_logs_every_n_steps: Optional[int] = None, log_every_n_steps: int = 50, accelerator: Optional[Union[str, Accelerator]] = None, @@ -512,6 
+512,7 @@ def __init__( self._call_callback_hooks("on_init_start") # init data flags + self.check_val_every_n_epoch: int self._data_connector.on_trainer_init( check_val_every_n_epoch, reload_dataloaders_every_n_epochs, @@ -569,6 +570,7 @@ def __init__( self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu) # init debugging flags + self.val_check_interval: Union[int, float] self._init_debugging_flags( limit_train_batches, limit_val_batches, @@ -584,14 +586,14 @@ def __init__( def _init_debugging_flags( self, - limit_train_batches, - limit_val_batches, - limit_test_batches, - limit_predict_batches, - val_check_interval, - overfit_batches, - fast_dev_run, - ): + limit_train_batches: Optional[Union[int, float]], + limit_val_batches: Optional[Union[int, float]], + limit_test_batches: Optional[Union[int, float]], + limit_predict_batches: Optional[Union[int, float]], + val_check_interval: Optional[Union[int, float]], + overfit_batches: Optional[Union[int, float]], + fast_dev_run: Union[int, bool], + ) -> None: if isinstance(fast_dev_run, int) and (fast_dev_run < 0): raise MisconfigurationException( f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0." @@ -2640,10 +2642,18 @@ def terminate_on_nan(self, val: bool) -> None: self._terminate_on_nan = val # : 212 -def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: +def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]: + if isinstance(batches, int) and batches == 1: + log.info(f"`Trainer({name}=1)` was configured so 1 batch per epoch will be used.") + elif isinstance(batches, float) and batches == 1.0: + log.info(f"`Trainer({name}=1.0)` was configured so 100% of the batches per epoch will be used.") + elif batches is None: + # batches is optional to know if the user passed a value so that we can show the above info messages only to the + # users that set a value explicitly + return 1.0 if 0 <= batches <= 1: return batches - if batches > 1 and batches % 1.0 == 0: + if batches > 1: return int(batches) raise MisconfigurationException( f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int." diff --git a/tests/trainer/flags/test_limit_batches.py b/tests/trainer/flags/test_limit_batches.py index 92febaeb7f83b..3b11dc3ee7ada 100644 --- a/tests/trainer/flags/test_limit_batches.py +++ b/tests/trainer/flags/test_limit_batches.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging import pytest @@ -66,3 +67,23 @@ def test_eval_limit_batches(stage, mode, limit_batches): expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches assert loader_num_batches[0] == expected_batches assert len(dataloaders[0]) == len(eval_loader) + + +@pytest.mark.parametrize( + "flag", + ("limit_train_batches", "limit_test_batches", "limit_predict_batches", "overfit_batches", "val_check_interval"), +) +@pytest.mark.parametrize("value", (1, 1.0)) +def test_limit_batches_info_message(caplog, flag, value): + with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + Trainer(**{flag: value}) + assert f"`Trainer({flag}={value})` was configured" in caplog.text + message = f"configured so {'1' if isinstance(value, int) else '100%'}" + assert message in caplog.text + + caplog.clear() + + # the message should not appear by default + with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + Trainer() + assert message not in caplog.text From 0b4ef025ac46ac1ecf0ad71a7317b87af91b6efe Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 16 Feb 2022 22:41:58 +0100 Subject: [PATCH 2/5] CHANGELOG --- CHANGELOG.md | 2 ++ pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d183b0918e3f..d05e7431a1991 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `LOGGER_REGISTRY` instance to register custom loggers to the `LightningCLI` ([#11533](https://github.com/PyTorchLightning/pytorch-lightning/pull/11533)) +- Added info message when the `Trainer` arguments `limit_*_batches`, `overfit_batches`, or `val_check_interval` are set to `1` or `1.0` ([#11950](https://github.com/PyTorchLightning/pytorch-lightning/pull/11950)) + - Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6a65e5b1ae6ea..04eb02006db44 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2653,7 +2653,7 @@ def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> return 1.0 if 0 <= batches <= 1: return batches - if batches > 1: + if batches > 1 and batches % 1.0 == 0: return int(batches) raise MisconfigurationException( f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int." 
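
Illustration (not part of the patch series): with patches 1-2 applied, the new info messages surface as ordinary `logging` records from the `pytorch_lightning.trainer.trainer` logger — the same logger the test in patch 1 captures via `caplog`. A minimal sketch, assuming a default logging setup where INFO-level records from that logger reach the console:

    import logging

    from pytorch_lightning import Trainer

    # make sure INFO-level records are shown on the console
    logging.basicConfig(level=logging.INFO)

    Trainer(limit_train_batches=1)
    # logs: `Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.

    Trainer(limit_test_batches=1.0)
    # logs: `Trainer(limit_test_batches=1.0)` was configured so 100% of the batches per epoch will be used.

    Trainer()
    # logs nothing for these flags: their defaults are now None, so the message only appears
    # when the user passes a value explicitly
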
From 179a2237642d4724d4d93058b3ae47d44914c6ea Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 17 Feb 2022 01:26:34 +0100 Subject: [PATCH 3/5] Undo overfit_batches change --- pytorch_lightning/trainer/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 04eb02006db44..303435a96b0c0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -147,7 +147,7 @@ def __init__( log_gpu_memory: Optional[str] = None, # TODO: Remove in 1.7 progress_bar_refresh_rate: Optional[int] = None, # TODO: remove in v1.7 enable_progress_bar: bool = True, - overfit_batches: Optional[Union[int, float]] = None, + overfit_batches: Union[int, float] = 0.0, track_grad_norm: Union[int, float, str] = -1, check_val_every_n_epoch: int = 1, fast_dev_run: Union[int, bool] = False, @@ -591,7 +591,7 @@ def _init_debugging_flags( limit_test_batches: Optional[Union[int, float]], limit_predict_batches: Optional[Union[int, float]], val_check_interval: Optional[Union[int, float]], - overfit_batches: Optional[Union[int, float]], + overfit_batches: Union[int, float], fast_dev_run: Union[int, bool], ) -> None: if isinstance(fast_dev_run, int) and (fast_dev_run < 0): @@ -2643,14 +2643,14 @@ def terminate_on_nan(self, val: bool) -> None: def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]: + if batches is None: + # batches is optional to know if the user passed a value so that we can show the above info messages only to the + # users that set a value explicitly + return 1.0 if isinstance(batches, int) and batches == 1: log.info(f"`Trainer({name}=1)` was configured so 1 batch per epoch will be used.") elif isinstance(batches, float) and batches == 1.0: log.info(f"`Trainer({name}=1.0)` was configured so 100% of the batches per epoch will be used.") - elif batches is None: - # batches is optional to know if the user passed a value so that we can show the above info messages only to the - # users that set a value explicitly - return 1.0 if 0 <= batches <= 1: return batches if batches > 1 and batches % 1.0 == 0: From e4f47cf10db9e5f2f19b503a23beae8d51ce44d7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 17 Feb 2022 17:29:21 +0100 Subject: [PATCH 4/5] Rohit's review --- pytorch_lightning/trainer/trainer.py | 19 +++++++++++++++++-- tests/trainer/flags/test_limit_batches.py | 10 +++++----- .../trainer/flags/test_val_check_interval.py | 18 ++++++++++++++++++ 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 303435a96b0c0..7f7e10f9385eb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2647,10 +2647,25 @@ def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> # batches is optional to know if the user passed a value so that we can show the above info messages only to the # users that set a value explicitly return 1.0 + + # differentiating based on the type can be error-prone for users. show a message describing the chosen behaviour if isinstance(batches, int) and batches == 1: - log.info(f"`Trainer({name}=1)` was configured so 1 batch per epoch will be used.") + if name == "limit_train_batches": + message = "1 batch per epoch will be used." + elif name == "val_check_interval": + message = "validation will run after every batch." 
+ else: + message = "1 batch will be used." + log.info(f"`Trainer({name}=1)` was configured so {message}") elif isinstance(batches, float) and batches == 1.0: - log.info(f"`Trainer({name}=1.0)` was configured so 100% of the batches per epoch will be used.") + if name == "limit_train_batches": + message = "100% of the batches per epoch will be used." + elif name == "val_check_interval": + message = "validation will run at the end of the training epoch." + else: + message = "100% of the batches will be used." + log.info(f"`Trainer({name}=1.0)` was configured so {message}.") + if 0 <= batches <= 1: return batches if batches > 1 and batches % 1.0 == 0: diff --git a/tests/trainer/flags/test_limit_batches.py b/tests/trainer/flags/test_limit_batches.py index 3b11dc3ee7ada..2c9eb1814a280 100644 --- a/tests/trainer/flags/test_limit_batches.py +++ b/tests/trainer/flags/test_limit_batches.py @@ -70,14 +70,14 @@ def test_eval_limit_batches(stage, mode, limit_batches): @pytest.mark.parametrize( - "flag", - ("limit_train_batches", "limit_test_batches", "limit_predict_batches", "overfit_batches", "val_check_interval"), + "argument", + ("limit_train_batches", "limit_val_batches", "limit_test_batches", "limit_predict_batches", "overfit_batches"), ) @pytest.mark.parametrize("value", (1, 1.0)) -def test_limit_batches_info_message(caplog, flag, value): +def test_limit_batches_info_message(caplog, argument, value): with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): - Trainer(**{flag: value}) - assert f"`Trainer({flag}={value})` was configured" in caplog.text + Trainer(**{argument: value}) + assert f"`Trainer({argument}={value})` was configured" in caplog.text message = f"configured so {'1' if isinstance(value, int) else '100%'}" assert message in caplog.text diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index 8bc6ef774f4a1..219a6cf65ab1e 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging + import pytest from pytorch_lightning.trainer import Trainer @@ -39,3 +41,19 @@ def on_validation_epoch_start(self) -> None: assert model.train_epoch_calls == max_epochs assert model.val_epoch_calls == max_epochs * denominator + + +@pytest.mark.parametrize("value", (1, 1.0)) +def test_val_check_interval_info_message(caplog, value): + with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + Trainer(val_check_interval=value) + assert f"`Trainer(val_check_interval={value})` was configured" in caplog.text + message = "configured so validation will run" + assert message in caplog.text + + caplog.clear() + + # the message should not appear by default + with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + Trainer() + assert message not in caplog.text From 0a7ce3c94bffd4ad4d0f550c4b013a70e49afe2d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 17 Feb 2022 23:09:12 +0100 Subject: [PATCH 5/5] Rank zero info --- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/trainer/flags/test_limit_batches.py | 4 ++-- tests/trainer/flags/test_val_check_interval.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7f7e10f9385eb..87dc8b2875f20 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2656,7 +2656,7 @@ def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> message = "validation will run after every batch." else: message = "1 batch will be used." - log.info(f"`Trainer({name}=1)` was configured so {message}") + rank_zero_info(f"`Trainer({name}=1)` was configured so {message}") elif isinstance(batches, float) and batches == 1.0: if name == "limit_train_batches": message = "100% of the batches per epoch will be used." @@ -2664,7 +2664,7 @@ def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> message = "validation will run at the end of the training epoch." else: message = "100% of the batches will be used." 
- log.info(f"`Trainer({name}=1.0)` was configured so {message}.") + rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}.") if 0 <= batches <= 1: return batches diff --git a/tests/trainer/flags/test_limit_batches.py b/tests/trainer/flags/test_limit_batches.py index 2c9eb1814a280..f46eb3acb10ff 100644 --- a/tests/trainer/flags/test_limit_batches.py +++ b/tests/trainer/flags/test_limit_batches.py @@ -75,7 +75,7 @@ def test_eval_limit_batches(stage, mode, limit_batches): ) @pytest.mark.parametrize("value", (1, 1.0)) def test_limit_batches_info_message(caplog, argument, value): - with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + with caplog.at_level(logging.INFO): Trainer(**{argument: value}) assert f"`Trainer({argument}={value})` was configured" in caplog.text message = f"configured so {'1' if isinstance(value, int) else '100%'}" @@ -84,6 +84,6 @@ def test_limit_batches_info_message(caplog, argument, value): caplog.clear() # the message should not appear by default - with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + with caplog.at_level(logging.INFO): Trainer() assert message not in caplog.text diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index 219a6cf65ab1e..685e104805daa 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -45,7 +45,7 @@ def on_validation_epoch_start(self) -> None: @pytest.mark.parametrize("value", (1, 1.0)) def test_val_check_interval_info_message(caplog, value): - with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + with caplog.at_level(logging.INFO): Trainer(val_check_interval=value) assert f"`Trainer(val_check_interval={value})` was configured" in caplog.text message = "configured so validation will run" @@ -54,6 +54,6 @@ def test_val_check_interval_info_message(caplog, value): caplog.clear() # the message should not appear by default - with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): + with caplog.at_level(logging.INFO): Trainer() assert message not in caplog.text
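
For reference, the net effect of the five patches on the `_determine_batch_limits` helper, reconstructed from the diffs above (a sketch, not copied from the repository: the imports are the ones trainer.py is assumed to already have, and the inline comments are paraphrased):

    from typing import Optional, Union

    from pytorch_lightning.utilities import rank_zero_info
    from pytorch_lightning.utilities.exceptions import MisconfigurationException


    def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
        if batches is None:
            # the argument now defaults to None so that the messages below are only shown to
            # users who passed a value explicitly
            return 1.0

        # differentiating based on the type alone can be error-prone for users,
        # so describe the chosen behaviour explicitly
        if isinstance(batches, int) and batches == 1:
            if name == "limit_train_batches":
                message = "1 batch per epoch will be used."
            elif name == "val_check_interval":
                message = "validation will run after every batch."
            else:
                message = "1 batch will be used."
            rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
        elif isinstance(batches, float) and batches == 1.0:
            if name == "limit_train_batches":
                message = "100% of the batches per epoch will be used."
            elif name == "val_check_interval":
                message = "validation will run at the end of the training epoch."
            else:
                message = "100% of the batches will be used."
            # note: `message` already ends with a period, so the logged text ends with ".."
            rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}.")

        if 0 <= batches <= 1:
            return batches
        if batches > 1 and batches % 1.0 == 0:
            return int(batches)
        raise MisconfigurationException(
            f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
        )

Patch 5 also drops the `logger="pytorch_lightning.trainer.trainer"` restriction from `caplog.at_level` in both tests, since after the switch to `rank_zero_info` the record is no longer emitted through the trainer module's logger.
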