diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3d183b0918e3f..d05e7431a1991 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -66,6 +66,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a `LOGGER_REGISTRY` instance to register custom loggers to the `LightningCLI` ([#11533](https://github.com/PyTorchLightning/pytorch-lightning/pull/11533))
 
+- Added info message when the `Trainer` arguments `limit_*_batches`, `overfit_batches`, or `val_check_interval` are set to `1` or `1.0` ([#11950](https://github.com/PyTorchLightning/pytorch-lightning/pull/11950))
+
 - Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
 
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index 14ec50c489a86..277f5a6ad9f00 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -164,7 +164,11 @@ def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
         """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
         :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
         # the batch_sampler may not be defined in case of CombinedDataLoaders
-        batch_sampler = getattr(self.trainer.predict_dataloaders[dataloader_idx], "batch_sampler", None)
+        batch_sampler = getattr(
+            self.trainer.predict_dataloaders[dataloader_idx],  # type: ignore[has-type]
+            "batch_sampler",
+            None,
+        )
         if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
             return batch_sampler.seen_batch_indices
 
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 8cbe4c167a29d..86c5d9ab83f0d 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -263,6 +263,7 @@ def advance(self) -> None:  # type: ignore[override]
         log.detail(f"{self.__class__.__name__}: advancing loop")
         assert self.trainer.train_dataloader is not None
         dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
+        assert self._data_fetcher is not None
         self._data_fetcher.setup(
             dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
         )
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4b68adbdf31b7..87dc8b2875f20 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -157,11 +157,11 @@ def __init__(
         max_steps: int = -1,
         min_steps: Optional[int] = None,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
-        limit_train_batches: Union[int, float] = 1.0,
-        limit_val_batches: Union[int, float] = 1.0,
-        limit_test_batches: Union[int, float] = 1.0,
-        limit_predict_batches: Union[int, float] = 1.0,
-        val_check_interval: Union[int, float] = 1.0,
+        limit_train_batches: Optional[Union[int, float]] = None,
+        limit_val_batches: Optional[Union[int, float]] = None,
+        limit_test_batches: Optional[Union[int, float]] = None,
+        limit_predict_batches: Optional[Union[int, float]] = None,
+        val_check_interval: Optional[Union[int, float]] = None,
         flush_logs_every_n_steps: Optional[int] = None,
         log_every_n_steps: int = 50,
         accelerator: Optional[Union[str, Accelerator]] = None,
@@ -512,6 +512,7 @@ def __init__(
         self._call_callback_hooks("on_init_start")
 
         # init data flags
+        self.check_val_every_n_epoch: int
         self._data_connector.on_trainer_init(
             check_val_every_n_epoch,
             reload_dataloaders_every_n_epochs,
@@ -569,6 +570,7 @@ def __init__(
         self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
+        self.val_check_interval: Union[int, float]
         self._init_debugging_flags(
             limit_train_batches,
             limit_val_batches,
@@ -584,14 +586,14 @@ def __init__(
 
     def _init_debugging_flags(
         self,
-        limit_train_batches,
-        limit_val_batches,
-        limit_test_batches,
-        limit_predict_batches,
-        val_check_interval,
-        overfit_batches,
-        fast_dev_run,
-    ):
+        limit_train_batches: Optional[Union[int, float]],
+        limit_val_batches: Optional[Union[int, float]],
+        limit_test_batches: Optional[Union[int, float]],
+        limit_predict_batches: Optional[Union[int, float]],
+        val_check_interval: Optional[Union[int, float]],
+        overfit_batches: Union[int, float],
+        fast_dev_run: Union[int, bool],
+    ) -> None:
         if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
             raise MisconfigurationException(
                 f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
@@ -2640,7 +2642,30 @@ def terminate_on_nan(self, val: bool) -> None:
         self._terminate_on_nan = val  # : 212
 
 
-def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
+def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
+    if batches is None:
+        # batches is optional to know if the user passed a value so that we can show the info messages below only to
+        # the users that set a value explicitly
+        return 1.0
+
+    # differentiating based on the type can be error-prone for users. show a message describing the chosen behaviour
+    if isinstance(batches, int) and batches == 1:
+        if name == "limit_train_batches":
+            message = "1 batch per epoch will be used."
+        elif name == "val_check_interval":
+            message = "validation will run after every batch."
+        else:
+            message = "1 batch will be used."
+        rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
+    elif isinstance(batches, float) and batches == 1.0:
+        if name == "limit_train_batches":
+            message = "100% of the batches per epoch will be used."
+        elif name == "val_check_interval":
+            message = "validation will run at the end of the training epoch."
+        else:
+            message = "100% of the batches will be used."
+        rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}")
+
     if 0 <= batches <= 1:
         return batches
     if batches > 1 and batches % 1.0 == 0:
diff --git a/tests/trainer/flags/test_limit_batches.py b/tests/trainer/flags/test_limit_batches.py
index 92febaeb7f83b..f46eb3acb10ff 100644
--- a/tests/trainer/flags/test_limit_batches.py
+++ b/tests/trainer/flags/test_limit_batches.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
 import pytest
 
@@ -66,3 +67,23 @@ def test_eval_limit_batches(stage, mode, limit_batches):
     expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches
     assert loader_num_batches[0] == expected_batches
     assert len(dataloaders[0]) == len(eval_loader)
+
+
+@pytest.mark.parametrize(
+    "argument",
+    ("limit_train_batches", "limit_val_batches", "limit_test_batches", "limit_predict_batches", "overfit_batches"),
+)
+@pytest.mark.parametrize("value", (1, 1.0))
+def test_limit_batches_info_message(caplog, argument, value):
+    with caplog.at_level(logging.INFO):
+        Trainer(**{argument: value})
+    assert f"`Trainer({argument}={value})` was configured" in caplog.text
+    message = f"configured so {'1' if isinstance(value, int) else '100%'}"
+    assert message in caplog.text
+
+    caplog.clear()
+
+    # the message should not appear by default
+    with caplog.at_level(logging.INFO):
+        Trainer()
+    assert message not in caplog.text
diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py
index 8bc6ef774f4a1..685e104805daa 100644
--- a/tests/trainer/flags/test_val_check_interval.py
+++ b/tests/trainer/flags/test_val_check_interval.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
 import pytest
 
 from pytorch_lightning.trainer import Trainer
@@ -39,3 +41,19 @@ def on_validation_epoch_start(self) -> None:
 
     assert model.train_epoch_calls == max_epochs
     assert model.val_epoch_calls == max_epochs * denominator
+
+
+@pytest.mark.parametrize("value", (1, 1.0))
+def test_val_check_interval_info_message(caplog, value):
+    with caplog.at_level(logging.INFO):
+        Trainer(val_check_interval=value)
+    assert f"`Trainer(val_check_interval={value})` was configured" in caplog.text
+    message = "configured so validation will run"
+    assert message in caplog.text
+
+    caplog.clear()
+
+    # the message should not appear by default
+    with caplog.at_level(logging.INFO):
+        Trainer()
+    assert message not in caplog.text
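
Reviewer note: the sketch below (not part of the patch) shows how the new int-vs-float behaviour surfaces to users, assuming this change is installed and that Lightning's `pytorch_lightning` logger emits INFO records (it typically does by default). The quoted log lines come from `_determine_batch_limits` above.

import logging

from pytorch_lightning import Trainer

# make sure INFO-level records from Lightning are visible in this session
logging.getLogger("pytorch_lightning").setLevel(logging.INFO)

# the float 1.0 keeps the fraction-based meaning and logs:
#   `Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used.
Trainer(limit_train_batches=1.0)

# the int 1 is interpreted as a batch count and logs:
#   `Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
Trainer(limit_train_batches=1)

# the new default (None) behaves like 1.0 but logs nothing, since the user did not set a value explicitly
Trainer()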