From 70180495db3574153f466ca61e3fa8a738e5da72 Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Sat, 19 Feb 2022 07:24:04 +0530
Subject: [PATCH] Support gradient accumulation using Horovod's
 `backward_passes_per_step` (#11911)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta
Co-authored-by: ananthsub
Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                            |  5 +-
 pytorch_lightning/strategies/horovod.py | 24 ++++++-
 tests/models/test_horovod.py            | 86 ++++++++++++++++++++-----
 3 files changed, 94 insertions(+), 21 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 267895c407f25..f8bb34adeef05 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,7 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs
+- Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911))
+
+
+- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs ([#11008](https://github.com/PyTorchLightning/pytorch-lightning/pull/11008))
 
 
 - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/pull/10601))
diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py
index f4a733909651e..101715f39952f 100644
--- a/pytorch_lightning/strategies/horovod.py
+++ b/pytorch_lightning/strategies/horovod.py
@@ -26,6 +26,7 @@
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
@@ -76,6 +77,11 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
+    @property
+    def handles_gradient_accumulation(self) -> bool:
+        """Whether the plugin handles gradient accumulation internally."""
+        return True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.model_to_device()
 
@@ -111,7 +117,13 @@ def _unpack_lightning_optimizer(opt):
         for optimizer in optimizers:
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
-        self.optimizers = self._wrap_optimizers(optimizers)
+        accumulation_scheduler = trainer.accumulation_scheduler
+        if accumulation_scheduler.epochs != [0]:
+            raise MisconfigurationException(
+                "Horovod currently does not support different `accumulate_grad_batches` at different epochs."
+            )
+
+        self.optimizers = self._wrap_optimizers(optimizers, trainer.accumulate_grad_batches)
         for optimizer in self.optimizers:
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())
@@ -181,10 +193,16 @@ def post_backward(self, closure_loss: torch.Tensor) -> None:
         for optimizer in self.optimizers:
             optimizer.synchronize()
 
-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
+    def _wrap_optimizers(
+        self, optimizers: List[Optimizer], accumulate_grad_batches: int
+    ) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
         return [
-            hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt))
+            hvd.DistributedOptimizer(
+                opt,
+                backward_passes_per_step=accumulate_grad_batches,
+                named_parameters=self._filter_named_parameters(self.lightning_module, opt),
+            )
             if "horovod" not in str(opt.__class__)
             else opt
             for opt in optimizers
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index c3dc03b1a7fde..c4d364ad1fa88 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -30,6 +30,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import BasicGAN
 from tests.helpers.runif import RunIf
@@ -42,11 +43,9 @@
 TEST_SCRIPT = os.path.join(os.path.dirname(__file__), "data", "horovod", "train_default_model.py")
 
 
-def _run_horovod(trainer_options, on_gpu=False):
+def _run_horovod(trainer_options):
     """Execute the training script across multiple workers in parallel."""
-    num_processes = trainer_options.get("gpus", 2)
-    # for Horovod, we interpret `gpus` to be set per worker
-    trainer_options.update(gpus=1 if on_gpu else None)
+    devices = trainer_options.get("devices", 1)
     tutils.reset_seed()
     # TODO: Find out why coverage breaks CI.
     # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else ''
@@ -54,13 +53,13 @@
     cmdline = [
         "horovodrun",
         "-np",
-        str(num_processes),
+        str(devices),
         sys.executable,
         TEST_SCRIPT,
         "--trainer-options",
         shlex.quote(json.dumps(trainer_options)),
     ]
-    if on_gpu:
+    if trainer_options.get("accelerator", "cpu") == "gpu":
         cmdline += ["--on-gpu"]
     exit_code = subprocess.call(" ".join(cmdline), shell=True, env=os.environ.copy())
     assert exit_code == 0
@@ -82,6 +81,20 @@ def test_horovod_cpu(tmpdir):
     _run_horovod(trainer_options)
 
 
+@RunIf(skip_windows=True, horovod=True, skip_49370=True)
+def test_horovod_cpu_accumulate_grad_batches(tmpdir):
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=1,
+        limit_train_batches=4,
+        limit_val_batches=0,
+        accumulate_grad_batches=2,
+        strategy="horovod",
+    )
+    _run_horovod(trainer_options)
+
+
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_clip_grad_by_value(tmpdir):
     """Test Horovod running multi-process on CPU."""
@@ -125,10 +138,44 @@ def test_horovod_multi_gpu(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
+        strategy="horovod",
+    )
+    _run_horovod(trainer_options)
+
+
+@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
+def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=1,
+        limit_train_batches=4,
+        limit_val_batches=0,
+        accumulate_grad_batches=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
+
+
+@RunIf(horovod=True, skip_windows=True)
+def test_horovod_raises_unsupported_accumulate_grad_batches(tmpdir):
+    """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
+    Strategy on multi-gpus."""
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        accumulate_grad_batches={0: 4, 2: 2},
+        accelerator="auto",
+        devices=1,
+        strategy="horovod",
+    )
+    with pytest.raises(MisconfigurationException, match="Horovod.*does not support.*accumulate_grad_batches"):
+        trainer.fit(model)
 
 
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
@@ -143,10 +190,11 @@ def test_horovod_multi_gpu_grad_by_value(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 # todo: need to be fixed :]
@@ -164,12 +212,13 @@ def test_horovod_apex(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
         amp_backend="apex",
         precision=16,
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
@@ -183,12 +232,13 @@ def test_horovod_amp(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
         amp_backend="native",
         precision=16,
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
@@ -202,10 +252,11 @@ def test_horovod_gather(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
 
 
 @RunIf(min_gpus=1, skip_windows=True, horovod_nccl=True)
@@ -227,7 +278,8 @@ def validation_step(self, batch, *args, **kwargs):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=1,
+        accelerator="gpu",
+        devices=1,
         strategy="horovod",
     )
     tpipes.run_model_test_without_loggers(trainer_options, model)