
Commit

Support gradient accumulation using Horovod's `backward_passes_per_step` (#11911)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
5 people committed Feb 19, 2022
1 parent cf64f34 commit 0374fe6
Showing 3 changed files with 94 additions and 21 deletions.
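
For context, a minimal usage sketch of what this commit enables (not part of the diff; the model and data are illustrative placeholders, and the script is assumed to be launched with `horovodrun`): setting `accumulate_grad_batches` together with `strategy="horovod"` now forwards the value to Horovod's `backward_passes_per_step`.

```python
# Illustrative sketch only (assumes Horovod is installed and the script is
# launched with `horovodrun -np 2 python train.py`); model and data are placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)

trainer = pl.Trainer(
    max_epochs=1,
    strategy="horovod",
    # With this commit, the Horovod strategy passes this value to
    # hvd.DistributedOptimizer(backward_passes_per_step=...), so gradients are
    # accumulated locally over 2 batches before the allreduce and optimizer step.
    accumulate_grad_batches=2,
)
trainer.fit(TinyModel(), train_loader)
```

Note that, per the check added in `horovod.py` below, a dict-valued `accumulate_grad_batches` (different values at different epochs) raises a `MisconfigurationException` with this strategy.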
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -9,7 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs
- Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911))


- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs ([#11008](https://github.com/PyTorchLightning/pytorch-lightning/pull/11008))


- Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/pull/10601))
24 changes: 21 additions & 3 deletions pytorch_lightning/strategies/horovod.py
@@ -26,6 +26,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

@@ -76,6 +77,11 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs

@property
def handles_gradient_accumulation(self) -> bool:
"""Whether the plugin handles gradient accumulation internally."""
return True

def setup(self, trainer: "pl.Trainer") -> None:
self.model_to_device()

@@ -111,7 +117,13 @@ def _unpack_lightning_optimizer(opt):
for optimizer in optimizers:
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

self.optimizers = self._wrap_optimizers(optimizers)
accumulation_scheduler = trainer.accumulation_scheduler
if accumulation_scheduler.epochs != [0]:
raise MisconfigurationException(
"Horovod currently does not support different `accumulate_grad_batches` at different epochs."
)

self.optimizers = self._wrap_optimizers(optimizers, trainer.accumulate_grad_batches)
for optimizer in self.optimizers:
# Synchronization will be performed explicitly following backward()
self._exit_stack.enter_context(optimizer.skip_synchronize())
@@ -181,10 +193,16 @@ def post_backward(self, closure_loss: torch.Tensor) -> None:
for optimizer in self.optimizers:
optimizer.synchronize()

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
def _wrap_optimizers(
self, optimizers: List[Optimizer], accumulate_grad_batches: int
) -> List["hvd.DistributedOptimizer"]:
"""Wraps optimizers to perform gradient aggregation via allreduce."""
return [
hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt))
hvd.DistributedOptimizer(
opt,
backward_passes_per_step=accumulate_grad_batches,
named_parameters=self._filter_named_parameters(self.lightning_module, opt),
)
if "horovod" not in str(opt.__class__)
else opt
for opt in optimizers
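
To show what `backward_passes_per_step` does under the hood, here is a minimal raw-Horovod sketch (illustrative only, not code from this repository; it assumes `horovod.torch` is installed and the script is launched with `horovodrun`): the `DistributedOptimizer` accumulates gradients locally for the given number of `backward()` calls and only then performs the allreduce, which is why the strategy can report `handles_gradient_accumulation = True`.

```python
# Minimal raw-Horovod sketch (illustrative only; assumes `horovod.torch` is
# installed and the script is launched with `horovodrun -np 2 ...`).
import torch
import horovod.torch as hvd

hvd.init()
model = torch.nn.Linear(32, 2)
hvd.broadcast_parameters(model.state_dict(), root_rank=0)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    backward_passes_per_step=2,  # accumulate 2 backward passes before each allreduce
)

batches = [(torch.randn(8, 32), torch.randn(8, 2)) for _ in range(4)]
for step, (x, y) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    if (step + 1) % 2 == 0:
        # The allreduce of the accumulated gradients happens here; stepping more
        # often than backward_passes_per_step allows would raise an error.
        optimizer.step()
        optimizer.zero_grad()
```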
86 changes: 69 additions & 17 deletions tests/models/test_horovod.py
@@ -30,6 +30,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.advanced_models import BasicGAN
from tests.helpers.runif import RunIf
@@ -42,25 +43,23 @@
TEST_SCRIPT = os.path.join(os.path.dirname(__file__), "data", "horovod", "train_default_model.py")


def _run_horovod(trainer_options, on_gpu=False):
def _run_horovod(trainer_options):
"""Execute the training script across multiple workers in parallel."""
num_processes = trainer_options.get("gpus", 2)
# for Horovod, we interpret `gpus` to be set per worker
trainer_options.update(gpus=1 if on_gpu else None)
devices = trainer_options.get("devices", 1)
tutils.reset_seed()
# TODO: Find out why coverage breaks CI.
# append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else ''
# str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append,
cmdline = [
"horovodrun",
"-np",
str(num_processes),
str(devices),
sys.executable,
TEST_SCRIPT,
"--trainer-options",
shlex.quote(json.dumps(trainer_options)),
]
if on_gpu:
if trainer_options.get("accelerator", "cpu") == "gpu":
cmdline += ["--on-gpu"]
exit_code = subprocess.call(" ".join(cmdline), shell=True, env=os.environ.copy())
assert exit_code == 0
@@ -82,6 +81,20 @@ def test_horovod_cpu(tmpdir):
_run_horovod(trainer_options)


@RunIf(skip_windows=True, horovod=True, skip_49370=True)
def test_horovod_cpu_accumulate_grad_batches(tmpdir):
trainer_options = dict(
default_root_dir=tmpdir,
enable_progress_bar=False,
max_epochs=1,
limit_train_batches=4,
limit_val_batches=0,
accumulate_grad_batches=2,
strategy="horovod",
)
_run_horovod(trainer_options)


@RunIf(skip_windows=True, horovod=True, skip_49370=True)
def test_horovod_cpu_clip_grad_by_value(tmpdir):
"""Test Horovod running multi-process on CPU."""
@@ -125,10 +138,44 @@ def test_horovod_multi_gpu(tmpdir):
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
accelerator="gpu",
devices=2,
strategy="horovod",
)
_run_horovod(trainer_options)


@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
trainer_options = dict(
default_root_dir=tmpdir,
enable_progress_bar=False,
max_epochs=1,
limit_train_batches=4,
limit_val_batches=0,
accumulate_grad_batches=2,
accelerator="gpu",
devices=2,
strategy="horovod",
)
_run_horovod(trainer_options, on_gpu=True)
_run_horovod(trainer_options)


@RunIf(horovod=True, skip_windows=True)
def test_horovod_raises_unsupported_accumulate_grad_batches(tmpdir):
"""Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
Strategy on multi-gpus."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
accumulate_grad_batches={0: 4, 2: 2},
accelerator="auto",
devices=1,
strategy="horovod",
)
with pytest.raises(MisconfigurationException, match="Horovod.*does not support.*accumulate_grad_batches"):
trainer.fit(model)


@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
Expand All @@ -143,10 +190,11 @@ def test_horovod_multi_gpu_grad_by_value(tmpdir):
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
accelerator="gpu",
devices=2,
strategy="horovod",
)
_run_horovod(trainer_options, on_gpu=True)
_run_horovod(trainer_options)


# todo: need to be fixed :]
@@ -164,12 +212,13 @@ def test_horovod_apex(tmpdir):
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
accelerator="gpu",
devices=2,
strategy="horovod",
amp_backend="apex",
precision=16,
)
_run_horovod(trainer_options, on_gpu=True)
_run_horovod(trainer_options)


@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
Expand All @@ -183,12 +232,13 @@ def test_horovod_amp(tmpdir):
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
accelerator="gpu",
devices=2,
strategy="horovod",
amp_backend="native",
precision=16,
)
_run_horovod(trainer_options, on_gpu=True)
_run_horovod(trainer_options)


@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
Expand All @@ -202,10 +252,11 @@ def test_horovod_gather(tmpdir):
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
accelerator="gpu",
devices=2,
strategy="horovod",
)
_run_horovod(trainer_options, on_gpu=True)
_run_horovod(trainer_options)


@RunIf(min_gpus=1, skip_windows=True, horovod_nccl=True)
Expand All @@ -227,7 +278,8 @@ def validation_step(self, batch, *args, **kwargs):
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=1,
accelerator="gpu",
devices=1,
strategy="horovod",
)
tpipes.run_model_test_without_loggers(trainer_options, model)
