
added warning for distributedsampler in case of evaluation #11479

Merged
merged 15 commits on Feb 3, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -72,6 +72,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))


- Added a warning when using `DistributedSampler` during validation/testing ([#11479](https://github.com/PyTorchLightning/pytorch-lightning/pull/11479))


### Changed

- Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
11 changes: 9 additions & 2 deletions docs/source/common/test_set.rst
@@ -50,8 +50,8 @@ To run the test set after training completes, use this method.

.. warning::

-  It is recommended to test on single device since Distributed Training such as DDP internally
-  uses :class:`~torch.utils.data.distributed.DistributedSampler` which replicates some samples to
+  It is recommended to test with ``Trainer(devices=1)`` since distributed strategies such as DDP
+  use :class:`~torch.utils.data.distributed.DistributedSampler` internally, which replicates some samples to
make sure all devices have same batch size in case of uneven inputs. This is helpful to make sure
benchmarking for research papers is done the right way.

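As an editorial illustration of the recommendation above (not part of the PR; the ``Trainer`` arguments and the use of the repository's ``BoringModel`` test helper are assumptions for the sketch), the pattern looks roughly like this::

    from pytorch_lightning import Trainer
    from tests.helpers import BoringModel

    model = BoringModel()

    # Train with a distributed strategy across two CPU processes.
    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(model)

    # Re-create a single-device Trainer for testing so that DistributedSampler
    # does not replicate samples and every test sample is evaluated exactly once.
    test_trainer = Trainer(accelerator="cpu", devices=1)
    test_trainer.test(model)
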
@@ -144,5 +144,12 @@ Apart from this ``.validate`` has same API as ``.test``, but would rely respecti
The ``.validate`` method uses the same validation logic that runs during validation inside a
:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` call.

.. warning::

When using ``trainer.validate()``, it is recommended to use ``Trainer(devices=1)`` since distributed strategies such as DDP
use :class:`~torch.utils.data.distributed.DistributedSampler` internally, which replicates some samples to
make sure all devices have the same batch size in case of uneven inputs. This is helpful to make sure
benchmarking for research papers is done the right way.

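Analogously, a minimal sketch for validation (again an editorial illustration using ``BoringModel``; the arguments are assumptions)::

    from pytorch_lightning import Trainer
    from tests.helpers import BoringModel

    model = BoringModel()

    # A fresh single-device Trainer avoids sample replication by DistributedSampler.
    trainer = Trainer(accelerator="cpu", devices=1)
    trainer.validate(model)
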
.. automethod:: pytorch_lightning.trainer.Trainer.validate
:noindex:
17 changes: 15 additions & 2 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -24,7 +24,7 @@

import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
- from pytorch_lightning.trainer.states import RunningStage
+ from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -382,14 +382,27 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional
" distributed training. Either remove the sampler from your DataLoader or set"
" `replace_sampler_ddp=False` if you want to use your custom sampler."
)
- return self._get_distributed_sampler(
+ sampler = self._get_distributed_sampler(
dataloader,
shuffle,
mode=mode,
overfit_batches=self.trainer.overfit_batches,
**self.trainer.distributed_sampler_kwargs,
)

# update docs too once this is resolved
trainer_fn = self.trainer.state.fn
if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
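# Warn only for `trainer.validate()` / `trainer.test()`: replicated samples would be evaluated more than once and skew the reported metrics.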
rank_zero_warn(
f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
" it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
" exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
" some samples to make sure all devices have same batch size in case of uneven inputs.",
category=PossibleUserWarning,
)

return sampler

return dataloader.sampler

@staticmethod
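As editorial background for the warning added above (not part of the PR diff), a small self-contained sketch in plain PyTorch shows how `DistributedSampler` pads uneven inputs by replicating samples across replicas:

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # A dataset of 5 samples split across 2 replicas: each replica draws
    # ceil(5 / 2) = 3 indices, so 6 indices are produced in total and one
    # sample is evaluated twice.
    dataset = TensorDataset(torch.arange(5))

    indices = []
    for rank in range(2):
        sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
        indices.extend(list(sampler))

    print(sorted(indices))  # [0, 0, 1, 2, 3, 4] -- index 0 appears twice
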
18 changes: 18 additions & 0 deletions tests/trainer/connectors/test_data_connector.py
@@ -18,6 +18,8 @@

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests.helpers import BoringDataModule, BoringModel


@@ -66,3 +68,19 @@ def test_dataloader_source_request_from_module():
module.foo.assert_not_called()
assert isinstance(source.dataloader(), DataLoader)
module.foo.assert_called_once()


def test_eval_distributed_sampler_warning(tmpdir):
"""Test that a warning is raised when `DistributedSampler` is used with evaluation."""

model = BoringModel()
trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
trainer._data_connector.attach_data(model)

trainer.state.fn = TrainerFn.VALIDATING
with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
trainer.reset_val_dataloader(model)

trainer.state.fn = TrainerFn.TESTING
with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
trainer.reset_test_dataloader(model)