From f4b1fa114140047dd6d7986295fb8509d812b052 Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Mon, 14 Feb 2022 15:41:22 +0530
Subject: [PATCH 01/17] Support gradient accumulation using Horovod's backward_passes_per_step

---
 CHANGELOG.md                            |  3 +++
 pytorch_lightning/strategies/horovod.py | 18 +++++++++++++++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2782a4cb1d9f1..54f0c59277054 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Enable gradient accumulation using Horovod's `backward_passes_per_step`
+
+
 - Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs
 
 
diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py
index 3bd81fd754585..e9d31e2292fb5 100644
--- a/pytorch_lightning/strategies/horovod.py
+++ b/pytorch_lightning/strategies/horovod.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from contextlib import ExitStack
 from typing import Any, List, Optional, Tuple, Union
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 import torch
 import torch.nn as nn
@@ -77,6 +78,11 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
+    @property
+    def handles_gradient_accumulation(self) -> bool:
+        """Whether the plugin handles gradient accumulation internally."""
+        return True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.model_to_device()
 
@@ -112,6 +118,12 @@ def _unpack_lightning_optimizer(opt):
         for optimizer in optimizers:
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
+        if accumulation_scheduler.epochs != [0]:
+            raise MisconfigurationException(
+                "Horovod currently does not support different `accumulate_grad_batches` at different epochs."
+            )
+
         self.optimizers = self._wrap_optimizers(optimizers)
         for optimizer in self.optimizers:
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())
@@ -184,8 +196,12 @@ def post_backward(self, closure_loss: torch.Tensor) -> None:
 
     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
+        accumulate_grad_batches = self.trainer.accumulate_grad_batches
         return [
-            hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt))
+            hvd.DistributedOptimizer(
+                opt,
+                backward_passes_per_step=accumulate_grad_batches,
+                named_parameters=self._filter_named_parameters(self.lightning_module, opt))
             if "horovod" not in str(opt.__class__)
             else opt
             for opt in optimizers

From aeb85ca832ecc119b936b4f04ad13cd843870ac6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 14 Feb 2022 10:18:44 +0000
Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/strategies/horovod.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py
index e9d31e2292fb5..f334fda74e6fb 100644
--- a/pytorch_lightning/strategies/horovod.py
+++ b/pytorch_lightning/strategies/horovod.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 from contextlib import ExitStack
 from typing import Any, List, Optional, Tuple, Union
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 import torch
 import torch.nn as nn
@@ -28,6 +27,7 @@
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
@@ -201,7 +201,8 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.Distributed
             hvd.DistributedOptimizer(
                 opt,
                 backward_passes_per_step=accumulate_grad_batches,
-                named_parameters=self._filter_named_parameters(self.lightning_module, opt))
+                named_parameters=self._filter_named_parameters(self.lightning_module, opt),
+            )
             if "horovod" not in str(opt.__class__)
             else opt
             for opt in optimizers
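
The two patches above carry the functional change: `HorovodStrategy` now reports `handles_gradient_accumulation = True`, rejects an `accumulate_grad_batches` schedule that changes across epochs, and forwards the Trainer's value to Horovod's `backward_passes_per_step` when wrapping each optimizer in `hvd.DistributedOptimizer`. As a rough usage sketch (not part of the diff; the model class and batch counts are placeholders), the feature is driven entirely through existing Trainer arguments:

    # Sketch only: assumes a `horovodrun -np <N> python ...` launch and Horovod built with PyTorch support.
    from pytorch_lightning import Trainer
    from tests.helpers import BoringModel  # the test helper model used later in this series

    model = BoringModel()
    trainer = Trainer(
        strategy="horovod",           # selects HorovodStrategy
        accumulate_grad_batches=4,    # forwarded to hvd.DistributedOptimizer(backward_passes_per_step=4)
        max_epochs=1,
    )
    trainer.fit(model)

A per-epoch schedule such as `accumulate_grad_batches={0: 4, 2: 2}` raises a `MisconfigurationException`, since the Horovod optimizer wrapper takes a single `backward_passes_per_step` value.
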
From 0e1b93a4ea1375e2ec3b7fa3dbb4add6b368b816 Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Mon, 14 Feb 2022 16:20:10 +0530
Subject: [PATCH 03/17] Update CHANGELOG.md

Co-authored-by: Rohit Gupta
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 54f0c59277054..a64aa0a3386dc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Enable gradient accumulation using Horovod's `backward_passes_per_step`
+- Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911))
 
 
 - Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs

From 7db02cae9395cea151fa9f68a263e1ec3c854c9e Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Mon, 14 Feb 2022 16:20:16 +0530
Subject: [PATCH 04/17] Update CHANGELOG.md

Co-authored-by: Rohit Gupta
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a64aa0a3386dc..320de85a53d0c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911))
 
 
-- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs
+- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs ([#11008](https://github.com/PyTorchLightning/pytorch-lightning/pull/11008))
 
 
 - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/pull/10601))

From 0bf0fb9b1cac499060d9e1c705c3ff38e3483d8d Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Mon, 14 Feb 2022 17:07:39 +0530
Subject: [PATCH 05/17] Add test and minor fix

---
 pytorch_lightning/strategies/horovod.py |  2 +-
 tests/models/test_horovod.py            | 35 +++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py
index e9d31e2292fb5..81e5714b1ac39 100644
--- a/pytorch_lightning/strategies/horovod.py
+++ b/pytorch_lightning/strategies/horovod.py
@@ -196,7 +196,7 @@ def post_backward(self, closure_loss: torch.Tensor) -> None:
 
     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
-        accumulate_grad_batches = self.trainer.accumulate_grad_batches
+        accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
         return [
             hvd.DistributedOptimizer(
                 opt,
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index c3dc03b1a7fde..602a2234fc2ae 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -82,6 +82,23 @@ def test_horovod_cpu(tmpdir):
     _run_horovod(trainer_options)
 
 
+@RunIf(skip_windows=True, horovod=True, skip_49370=True)
+def test_horovod_cpu_accumulate_grad_batches(tmpdir):
+    """Test Horovod running multi-process on CPU."""
+    trainer_options = dict(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        enable_progress_bar=False,
+        max_epochs=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        accumulate_grad_batches=4,
+        strategy="horovod",
+    )
+    _run_horovod(trainer_options)
+
+
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_clip_grad_by_value(tmpdir):
     """Test Horovod running multi-process on CPU."""
@@ -131,6 +148,24 @@ def test_horovod_multi_gpu(tmpdir):
     _run_horovod(trainer_options, on_gpu=True)
 
 
+@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
+def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
+    """Test Horovod with multi-GPU support."""
+    trainer_options = dict(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        enable_progress_bar=False,
+        max_epochs=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        accumulate_grad_batches=4,
+        gpus=2,
+        strategy="horovod",
+    )
+    _run_horovod(trainer_options, on_gpu=True)
+
+
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
 def test_horovod_multi_gpu_grad_by_value(tmpdir):
     """Test Horovod with multi-GPU support."""

From 5a17c2b6a7a27fa03433d16800fa2640f8b74aa1 Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Tue, 15 Feb 2022 11:03:13 +0530
Subject: [PATCH 06/17] Update pytorch_lightning/strategies/horovod.py

Co-authored-by: ananthsub
---
 pytorch_lightning/strategies/horovod.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py
index cb15132310654..a920677f1e663 100644
--- a/pytorch_lightning/strategies/horovod.py
+++ b/pytorch_lightning/strategies/horovod.py
@@ -118,7 +118,7 @@ def _unpack_lightning_optimizer(opt):
         for optimizer in optimizers:
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
-        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
+        accumulation_scheduler = trainer.accumulation_scheduler
         if accumulation_scheduler.epochs != [0]:
             raise MisconfigurationException(
                 "Horovod currently does not support different `accumulate_grad_batches` at different epochs."

From 53fd24aa09cc016fad1898ced8d1ca647a83de8b Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Tue, 15 Feb 2022 11:04:54 +0530
Subject: [PATCH 07/17] Changes per review, pass accumulate_grad_batches as arg

---
 pytorch_lightning/strategies/horovod.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py
index cb15132310654..793d4ce7dbff8 100644
--- a/pytorch_lightning/strategies/horovod.py
+++ b/pytorch_lightning/strategies/horovod.py
@@ -118,13 +118,13 @@ def _unpack_lightning_optimizer(opt):
         for optimizer in optimizers:
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
-        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
+        accumulation_scheduler = trainer.accumulation_scheduler
         if accumulation_scheduler.epochs != [0]:
             raise MisconfigurationException(
                 "Horovod currently does not support different `accumulate_grad_batches` at different epochs."
             )
 
-        self.optimizers = self._wrap_optimizers(optimizers)
+        self.optimizers = self._wrap_optimizers(optimizers, trainer.accumulate_grad_batches)
         for optimizer in self.optimizers:
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())
@@ -194,9 +194,8 @@ def post_backward(self, closure_loss: torch.Tensor) -> None:
         for optimizer in self.optimizers:
             optimizer.synchronize()
 
-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
+    def _wrap_optimizers(self, optimizers: List[Optimizer], accumulate_grad_batches: int) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
-        accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
         return [
             hvd.DistributedOptimizer(
                 opt,
From c914e3a4a7c26c8846b83c162b6b5f833d7e654a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 15 Feb 2022 05:36:32 +0000
Subject: [PATCH 08/17] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/strategies/horovod.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py
index 793d4ce7dbff8..ce8374afb3060 100644
--- a/pytorch_lightning/strategies/horovod.py
+++ b/pytorch_lightning/strategies/horovod.py
@@ -194,7 +194,9 @@ def post_backward(self, closure_loss: torch.Tensor) -> None:
         for optimizer in self.optimizers:
             optimizer.synchronize()
 
-    def _wrap_optimizers(self, optimizers: List[Optimizer], accumulate_grad_batches: int) -> List["hvd.DistributedOptimizer"]:
+    def _wrap_optimizers(
+        self, optimizers: List[Optimizer], accumulate_grad_batches: int
+    ) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
         return [
             hvd.DistributedOptimizer(

From 2a4c6fa4b27c350294c5e1a5485ae42a788424ca Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Wed, 16 Feb 2022 14:14:37 +0530
Subject: [PATCH 09/17] Add tests to ensure error raised, replace gpus= with accelerator=gpu and devices=count

---
 tests/models/test_horovod.py | 66 ++++++++++++++++++++++++++++++++----
 1 file changed, 59 insertions(+), 7 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 602a2234fc2ae..04db5c7ce2389 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -24,6 +24,7 @@
 from sklearn.metrics import accuracy_score
 from torch import optim
 from torchmetrics.classification.accuracy import Accuracy
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
@@ -99,6 +100,27 @@ def test_horovod_cpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options)
 
 
+@RunIf(skip_windows=True, horovod=True, skip_49370=True)
+def test_horovod_accumulate_grad_batches_different(tmpdir):
+    """
+    Ensure MisConfigurationException for different `accumulate_grad_batches`
+    at different epochs for Horovod Strategy on multi-cpus.
+    """
+    trainer_options = dict(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        enable_progress_bar=False,
+        max_epochs=4,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        accumulate_grad_batches={0: 4, 2: 2},
+        strategy="horovod",
+    )
+    with pytest.raises(MisconfigurationException):
+        _run_horovod(trainer_options)
+
+
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_clip_grad_by_value(tmpdir):
     """Test Horovod running multi-process on CPU."""
@@ -142,7 +164,8 @@ def test_horovod_multi_gpu(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
     _run_horovod(trainer_options, on_gpu=True)
@@ -160,12 +183,36 @@ def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
         limit_train_batches=0.4,
         limit_val_batches=0.2,
         accumulate_grad_batches=4,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
     _run_horovod(trainer_options, on_gpu=True)
 
 
+@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
+def test_horovod_multi_gpu_accumulate_grad_batches_different(tmpdir):
+    """
+    Ensure MisConfigurationException for different `accumulate_grad_batches`
+    at different epochs for Horovod Strategy on multi-gpus.
+    """
+    trainer_options = dict(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        enable_progress_bar=False,
+        max_epochs=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        accumulate_grad_batches={0: 4, 2: 2},
+        accelerator="gpu",
+        devices=2,
+        strategy="horovod",
+    )
+    with pytest.raises(MisconfigurationException):
+        _run_horovod(trainer_options)
+
+
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
 def test_horovod_multi_gpu_grad_by_value(tmpdir):
     """Test Horovod with multi-GPU support."""
@@ -178,7 +225,8 @@ def test_horovod_multi_gpu_grad_by_value(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
     _run_horovod(trainer_options, on_gpu=True)
@@ -199,7 +247,8 @@ def test_horovod_apex(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
         amp_backend="apex",
         precision=16,
@@ -218,7 +267,8 @@ def test_horovod_amp(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
         amp_backend="native",
         precision=16,
@@ -237,7 +287,8 @@ def test_horovod_gather(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
     _run_horovod(trainer_options, on_gpu=True)
@@ -262,7 +313,8 @@ def validation_step(self, batch, *args, **kwargs):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=1,
+        accelerator="gpu",
+        devices=1,
         strategy="horovod",
     )
     tpipes.run_model_test_without_loggers(trainer_options, model)
From 0850b0acc15637d257b828a58a7879e9f7dbb22a Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Wed, 16 Feb 2022 14:46:15 +0530
Subject: [PATCH 10/17] Don't add tests for now

---
 tests/models/test_horovod.py | 45 ------------------------------------
 1 file changed, 45 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 04db5c7ce2389..3b11821783315 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -24,7 +24,6 @@
 from sklearn.metrics import accuracy_score
 from torch import optim
 from torchmetrics.classification.accuracy import Accuracy
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
@@ -100,27 +99,6 @@ def test_horovod_cpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options)
 
 
-@RunIf(skip_windows=True, horovod=True, skip_49370=True)
-def test_horovod_accumulate_grad_batches_different(tmpdir):
-    """
-    Ensure MisConfigurationException for different `accumulate_grad_batches`
-    at different epochs for Horovod Strategy on multi-cpus.
-    """
-    trainer_options = dict(
-        default_root_dir=str(tmpdir),
-        weights_save_path=str(tmpdir),
-        gradient_clip_val=1.0,
-        enable_progress_bar=False,
-        max_epochs=4,
-        limit_train_batches=0.4,
-        limit_val_batches=0.2,
-        accumulate_grad_batches={0: 4, 2: 2},
-        strategy="horovod",
-    )
-    with pytest.raises(MisconfigurationException):
-        _run_horovod(trainer_options)
-
-
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_clip_grad_by_value(tmpdir):
     """Test Horovod running multi-process on CPU."""
@@ -190,29 +168,6 @@ def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options, on_gpu=True)
 
 
-@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
-def test_horovod_multi_gpu_accumulate_grad_batches_different(tmpdir):
-    """
-    Ensure MisConfigurationException for different `accumulate_grad_batches`
-    at different epochs for Horovod Strategy on multi-gpus.
-    """
-    trainer_options = dict(
-        default_root_dir=str(tmpdir),
-        weights_save_path=str(tmpdir),
-        gradient_clip_val=1.0,
-        enable_progress_bar=False,
-        max_epochs=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.2,
-        accumulate_grad_batches={0: 4, 2: 2},
-        accelerator="gpu",
-        devices=2,
-        strategy="horovod",
-    )
-    with pytest.raises(MisconfigurationException):
-        _run_horovod(trainer_options)
-
-
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
 def test_horovod_multi_gpu_grad_by_value(tmpdir):
     """Test Horovod with multi-GPU support."""

From 9fd23ce827c561ef03a6479c0d36ffb720a7b2a9 Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Wed, 16 Feb 2022 14:53:39 +0530
Subject: [PATCH 11/17] Use basic model, and trainer.fit instead of _run_horovod

---
 tests/models/test_horovod.py | 48 ++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 3b11821783315..ee654d2ed0137 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -30,8 +30,10 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import BasicGAN
+from tests.helpers.simple_models import ClassificationModel
 from tests.helpers.runif import RunIf
 
 if _HOROVOD_AVAILABLE:
@@ -99,6 +101,28 @@ def test_horovod_cpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options)
 
 
+@RunIf(skip_windows=True, horovod=True, skip_49370=True)
+def test_horovod_cpu_accumulate_grad_batches_different(tmpdir):
+    """
+    Ensure MisConfigurationException for different `accumulate_grad_batches`
+    at different epochs for Horovod Strategy on multi-cpus.
+    """
+    model = ClassificationModel()
+    trainer = Trainer(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        enable_progress_bar=False,
+        max_epochs=4,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        accumulate_grad_batches={0: 4, 2: 2},
+        strategy="horovod",
+    )
+    with pytest.raises(MisconfigurationException):
+        trainer.fit(model)
+
+
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_clip_grad_by_value(tmpdir):
     """Test Horovod running multi-process on CPU."""
@@ -168,6 +192,30 @@ def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options, on_gpu=True)
 
 
+@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
+def test_horovod_multi_gpu_accumulate_grad_batches_different(tmpdir):
+    """
+    Ensure MisConfigurationException for different `accumulate_grad_batches`
+    at different epochs for Horovod Strategy on multi-gpus.
+    """
+    model = ClassificationModel()
+    trainer = Trainer(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        enable_progress_bar=False,
+        max_epochs=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        accumulate_grad_batches={0: 4, 2: 2},
+        accelerator="gpu",
+        devices=2,
+        strategy="horovod",
+    )
+    with pytest.raises(MisconfigurationException):
+        trainer.fit(model)
+
+
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
 def test_horovod_multi_gpu_grad_by_value(tmpdir):
     """Test Horovod with multi-GPU support."""

From 04ee625d31c0bb64b2957d8085673dd01793a314 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 16 Feb 2022 09:24:59 +0000
Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/models/test_horovod.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index ee654d2ed0137..455315f0c762e 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -33,8 +33,8 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import BasicGAN
-from tests.helpers.simple_models import ClassificationModel
 from tests.helpers.runif import RunIf
+from tests.helpers.simple_models import ClassificationModel
 
 if _HOROVOD_AVAILABLE:
     import horovod
@@ -103,10 +103,8 @@ def test_horovod_cpu_accumulate_grad_batches(tmpdir):
 
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_accumulate_grad_batches_different(tmpdir):
-    """
-    Ensure MisConfigurationException for different `accumulate_grad_batches`
-    at different epochs for Horovod Strategy on multi-cpus.
-    """
+    """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
+    Strategy on multi-cpus."""
     model = ClassificationModel()
     trainer = Trainer(
         default_root_dir=str(tmpdir),
@@ -194,10 +192,8 @@ def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
 
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
 def test_horovod_multi_gpu_accumulate_grad_batches_different(tmpdir):
-    """
-    Ensure MisConfigurationException for different `accumulate_grad_batches`
-    at different epochs for Horovod Strategy on multi-gpus.
-    """
+    """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
+    Strategy on multi-gpus."""
     model = ClassificationModel()
     trainer = Trainer(
         default_root_dir=str(tmpdir),
From 7c73b0c7f1ecb7f081953dd552c7982162e8c36b Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Wed, 16 Feb 2022 21:40:03 +0530
Subject: [PATCH 13/17] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 tests/models/test_horovod.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 455315f0c762e..aa6179c8b2484 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -196,19 +196,14 @@ def test_horovod_multi_gpu_accumulate_grad_batches_different(tmpdir):
     Strategy on multi-gpus."""
     model = ClassificationModel()
     trainer = Trainer(
-        default_root_dir=str(tmpdir),
-        weights_save_path=str(tmpdir),
-        gradient_clip_val=1.0,
+        default_root_dir=tmpdir,
         enable_progress_bar=False,
-        max_epochs=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.2,
         accumulate_grad_batches={0: 4, 2: 2},
         accelerator="gpu",
         devices=2,
         strategy="horovod",
     )
-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(MisconfigurationException, match="Horovod.*does not support.*accumulate_grad_batches"):
         trainer.fit(model)

From 8c1e3f1e7dae05af0c3cde214c6d0fe15014eb67 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 17 Feb 2022 14:16:46 +0100
Subject: [PATCH 14/17] Self review

---
 tests/models/test_horovod.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index aa6179c8b2484..01f195cb18d93 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -190,7 +190,7 @@ def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options, on_gpu=True)
 
 
-@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
+@RunIf(skip_windows=True)
 def test_horovod_multi_gpu_accumulate_grad_batches_different(tmpdir):
     """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
     Strategy on multi-gpus."""
@@ -199,8 +199,8 @@ def test_horovod_multi_gpu_accumulate_grad_batches_different(tmpdir):
         default_root_dir=tmpdir,
         enable_progress_bar=False,
         accumulate_grad_batches={0: 4, 2: 2},
-        accelerator="gpu",
-        devices=2,
+        accelerator="auto",
+        devices=1,
         strategy="horovod",
     )
     with pytest.raises(MisconfigurationException, match="Horovod.*does not support.*accumulate_grad_batches"):
         trainer.fit(model)

From 60ba198430c6bb89c39f7790bddb3034a4e7d3ef Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 17 Feb 2022 14:20:50 +0100
Subject: [PATCH 15/17] Remove duplicated test

---
 tests/models/test_horovod.py | 22 +---------------------
 1 file changed, 1 insertion(+), 21 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 01f195cb18d93..7cdcff5667509 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -101,26 +101,6 @@ def test_horovod_cpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options)
 
 
-@RunIf(skip_windows=True, horovod=True, skip_49370=True)
-def test_horovod_cpu_accumulate_grad_batches_different(tmpdir):
-    """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
-    Strategy on multi-cpus."""
-    model = ClassificationModel()
-    trainer = Trainer(
-        default_root_dir=str(tmpdir),
-        weights_save_path=str(tmpdir),
-        gradient_clip_val=1.0,
-        enable_progress_bar=False,
-        max_epochs=4,
-        limit_train_batches=0.4,
-        limit_val_batches=0.2,
-        accumulate_grad_batches={0: 4, 2: 2},
-        strategy="horovod",
-    )
-    with pytest.raises(MisconfigurationException):
-        trainer.fit(model)
-
-
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_clip_grad_by_value(tmpdir):
     """Test Horovod running multi-process on CPU."""
@@ -171,7 +151,7 @@ def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
 
 
 @RunIf(skip_windows=True)
-def test_horovod_multi_gpu_accumulate_grad_batches_different(tmpdir):
+def test_horovod_raises_unsupported_accumulate_grad_batches(tmpdir):
     """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
     Strategy on multi-gpus."""
     model = ClassificationModel()

From 5786ce805682fe2b30a9b5f3f6db8ee6e341a653 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 17 Feb 2022 15:55:48 +0100
Subject: [PATCH 16/17] Missed horovod in RunIf

---
 tests/models/test_horovod.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 7cdcff5667509..4e512f814427b 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -170,7 +170,7 @@ def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options, on_gpu=True)
 
 
-@RunIf(skip_windows=True)
+@RunIf(horovod=True, skip_windows=True)
 def test_horovod_raises_unsupported_accumulate_grad_batches(tmpdir):
     """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
     Strategy on multi-gpus."""

From 670fbac82b848a046f515863b77821ad0070866c Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 17 Feb 2022 22:56:13 +0100
Subject: [PATCH 17/17] Further simplification

---
 tests/models/test_horovod.py | 47 +++++++++++++++---------------------
 1 file changed, 19 insertions(+), 28 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 4e512f814427b..c4d364ad1fa88 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -34,7 +34,6 @@
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import BasicGAN
 from tests.helpers.runif import RunIf
-from tests.helpers.simple_models import ClassificationModel
 
 if _HOROVOD_AVAILABLE:
     import horovod
@@ -44,11 +43,9 @@
 TEST_SCRIPT = os.path.join(os.path.dirname(__file__), "data", "horovod", "train_default_model.py")
 
 
-def _run_horovod(trainer_options, on_gpu=False):
+def _run_horovod(trainer_options):
     """Execute the training script across multiple workers in parallel."""
-    num_processes = trainer_options.get("gpus", 2)
-    # for Horovod, we interpret `gpus` to be set per worker
-    trainer_options.update(gpus=1 if on_gpu else None)
+    devices = trainer_options.get("devices", 1)
     tutils.reset_seed()
     # TODO: Find out why coverage breaks CI.
     # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else ''
@@ -56,13 +53,13 @@ def _run_horovod(trainer_options, on_gpu=False):
     cmdline = [
         "horovodrun",
         "-np",
-        str(num_processes),
+        str(devices),
         sys.executable,
         TEST_SCRIPT,
        "--trainer-options",
         shlex.quote(json.dumps(trainer_options)),
     ]
-    if on_gpu:
+    if trainer_options.get("accelerator", "cpu") == "gpu":
         cmdline += ["--on-gpu"]
     exit_code = subprocess.call(" ".join(cmdline), shell=True, env=os.environ.copy())
     assert exit_code == 0
@@ -86,16 +83,13 @@ def test_horovod_cpu(tmpdir):
 
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_accumulate_grad_batches(tmpdir):
-    """Test Horovod running multi-process on CPU."""
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
-        weights_save_path=str(tmpdir),
-        gradient_clip_val=1.0,
+        default_root_dir=tmpdir,
         enable_progress_bar=False,
         max_epochs=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.2,
-        accumulate_grad_batches=4,
+        limit_train_batches=4,
+        limit_val_batches=0,
+        accumulate_grad_batches=2,
         strategy="horovod",
     )
     _run_horovod(trainer_options)
@@ -148,33 +142,30 @@ def test_horovod_multi_gpu(tmpdir):
         devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
 def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
-    """Test Horovod with multi-GPU support."""
     trainer_options = dict(
-        default_root_dir=str(tmpdir),
-        weights_save_path=str(tmpdir),
-        gradient_clip_val=1.0,
+        default_root_dir=tmpdir,
         enable_progress_bar=False,
         max_epochs=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.2,
-        accumulate_grad_batches=4,
+        limit_train_batches=4,
+        limit_val_batches=0,
+        accumulate_grad_batches=2,
         accelerator="gpu",
         devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 @RunIf(horovod=True, skip_windows=True)
 def test_horovod_raises_unsupported_accumulate_grad_batches(tmpdir):
     """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
     Strategy on multi-gpus."""
-    model = ClassificationModel()
+    model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         enable_progress_bar=False,
@@ -203,7 +194,7 @@ def test_horovod_multi_gpu_grad_by_value(tmpdir):
         devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 # todo: need to be fixed :]
@@ -227,7 +218,7 @@ def test_horovod_apex(tmpdir):
         amp_backend="apex",
         precision=16,
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
@@ -247,7 +238,7 @@ def test_horovod_amp(tmpdir):
         amp_backend="native",
         precision=16,
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
@@ -265,7 +256,7 @@ def test_horovod_gather(tmpdir):
         devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 @RunIf(min_gpus=1, skip_windows=True, horovod_nccl=True)