From 108fcfa6175a5bf5c0140cf52bfa57bc15adfc11 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 19 Jan 2022 16:28:16 +0530 Subject: [PATCH] improvements --- pytorch_lightning/trainer/connectors/data_connector.py | 3 ++- tests/trainer/connectors/test_data_connector.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 96d329261ef85e..8d0c7565897f49 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -397,7 +397,8 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional f"Using `DistributedSampler` with the dataloaders. During trainer.{trainer_fn.value}()," " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated" " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates" - " some samples to make sure all devices have same batch size in case of uneven inputs." + " some samples to make sure all devices have same batch size in case of uneven inputs.", + category=PossibleUserWarning, ) return sampler diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py index 88a2d90d8154e3..a02a6119a2a18c 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -19,6 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.warnings import PossibleUserWarning from tests.helpers import BoringDataModule, BoringModel @@ -77,9 +78,9 @@ def test_eval_distributed_sampler_warning(tmpdir): trainer._data_connector.attach_data(model) trainer.state.fn = TrainerFn.VALIDATING - with pytest.warns(UserWarning, match="multi-device settings use `DistributedSampler`"): + with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"): trainer.reset_val_dataloader(model) trainer.state.fn = TrainerFn.TESTING - with pytest.warns(UserWarning, match="multi-device settings use `DistributedSampler`"): + with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"): trainer.reset_test_dataloader(model)