Commit

improvements
rohitgr7 committed Jan 19, 2022
1 parent 5e2b98c commit 0e0ebd8
Showing 4 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion _notebooks
4 changes: 2 additions & 2 deletions docs/source/common/test_set.rst
@@ -50,7 +50,7 @@ To run the test set after training completes, use this method.
 .. warning::
 
-    It is recommended to test on single device since distributed strategies such as DDP
+    It is recommended to test with ``Trainer(devices=1)`` since distributed strategies such as DDP
     uses :class:`~torch.utils.data.distributed.DistributedSampler` internally, which replicates some samples to
     make sure all devices have same batch size in case of uneven inputs. This is helpful to make sure
     benchmarking for research papers is done the right way.
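
In practice the recommended pattern is to fit with a distributed strategy and then test on a single device. A minimal sketch, assuming a user-defined `model` and `datamodule` (both names are placeholders):

    import pytorch_lightning as pl

    # fit with a distributed strategy for throughput
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
    trainer.fit(model, datamodule=datamodule)

    # test with a single device so DistributedSampler cannot replicate samples
    trainer = pl.Trainer(accelerator="gpu", devices=1)
    trainer.test(model, datamodule=datamodule)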
@@ -146,7 +146,7 @@ Apart from this ``.validate`` has same API as ``.test``, but would rely respecti

 .. warning::
 
-    When using ``trainer.validate()``, it is recommended to use a single device since distributed strategies such as DDP
+    When using ``trainer.validate()``, it is recommended to use ``Trainer(devices=1)`` since distributed strategies such as DDP
     uses :class:`~torch.utils.data.distributed.DistributedSampler` internally, which replicates some samples to
     make sure all devices have same batch size in case of uneven inputs. This is helpful to make sure
     benchmarking for research papers is done the right way.
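
The same advice holds for validation. A minimal sketch, reusing the placeholder `model` and `datamodule` from above:

    # validate on a single device so each sample is evaluated exactly once
    trainer = pl.Trainer(devices=1)
    results = trainer.validate(model, datamodule=datamodule)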
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -390,13 +390,14 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional
                 **self.trainer.distributed_sampler_kwargs,
             )
 
+            # update docs too once this is resolved
             trainer_fn = self.trainer.state.fn
             if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
                 rank_zero_warn(
                     f"Using `DistributedSampler` with the dataloaders. During trainer.{trainer_fn.value}(),"
-                    " it is recommended to use single device strategy to ensure each sample/batch gets evaluated"
-                    " exactly once. Otherwise multi-device setting uses `DistributedSampler` replicates some samples"
-                    " to make sure all devices have same batch size in case of uneven inputs."
+                    " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
+                    " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
+                    " some samples to make sure all devices have same batch size in case of uneven inputs."
                 )
 
             return sampler
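
The replication the warning describes can be reproduced with plain PyTorch; passing `num_replicas` and `rank` explicitly avoids the need for an initialized process group:

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # 10 samples across 4 replicas: each sampler yields ceil(10/4) = 3 indices,
    # so 12 indices are drawn in total and 2 samples are evaluated twice
    dataset = TensorDataset(torch.arange(10))
    per_rank = [list(DistributedSampler(dataset, num_replicas=4, rank=r, shuffle=False)) for r in range(4)]
    print([len(ix) for ix in per_rank])  # [3, 3, 3, 3]
    print(sorted(sum(per_rank, [])))     # samples 0 and 1 appear twice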
8 changes: 4 additions & 4 deletions tests/trainer/connectors/test_data_connector.py
@@ -76,10 +76,10 @@ def test_eval_distributed_sampler_warning(tmpdir):
     trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
     trainer._data_connector.attach_data(model)
 
-    with pytest.warns(UserWarning, match="use single device strategy to ensure each sample"):
-        trainer.state.fn = TrainerFn.VALIDATING
+    trainer.state.fn = TrainerFn.VALIDATING
+    with pytest.warns(UserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_val_dataloader(model)
 
-    with pytest.warns(UserWarning, match="use single device strategy to ensure each sample"):
-        trainer.state.fn = TrainerFn.TESTING
+    trainer.state.fn = TrainerFn.TESTING
+    with pytest.warns(UserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_test_dataloader(model)
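
The test now sets `trainer.state.fn` before entering the `pytest.warns` block, since only the dataloader reset should emit the warning, and the `match` pattern is updated to the new message. `pytest.warns` applies `re.search` to the warning text, so a distinctive substring suffices; a minimal standalone sketch of that mechanism (`reset_dataloader` is a hypothetical stand-in):

    import warnings
    import pytest

    def reset_dataloader():
        # stand-in for trainer.reset_val_dataloader / reset_test_dataloader
        warnings.warn("multi-device settings use `DistributedSampler` that replicates some samples", UserWarning)

    def test_warning_is_matched():
        with pytest.warns(UserWarning, match="multi-device settings use `DistributedSampler`"):
            reset_dataloader()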
