Refactor Strategy._move_optimizer_state as utility functions (#11758)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
3 people committed Feb 18, 2022
1 parent d613719 commit cf64f34
Showing 9 changed files with 78 additions and 25 deletions.
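For orientation, here is a minimal usage sketch of the two helpers this commit introduces, optimizers_to_device and optimizer_to_device (added in pytorch_lightning/utilities/optimizer.py below). It mirrors the call sites changed in the strategy diffs that follow; the toy model, optimizer, and backward/step calls are illustrative scaffolding only, not part of the commit.

# Sketch only: the model and optimizer below are made up for illustration.
import torch
from torch.optim import SGD

from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device

model = torch.nn.Linear(8, 1)
optimizers = [SGD(model.parameters(), lr=0.1, momentum=0.9)]

# One optimizer step on CPU populates per-parameter state (momentum buffers).
model(torch.randn(4, 8)).sum().backward()
optimizers[0].step()

# What Strategy.setup() now does: move every optimizer's state to the root device.
if torch.cuda.is_available():
    optimizers_to_device(optimizers, torch.device("cuda", 0))

# What Strategy.teardown() now does: move the state back to CPU to release device memory.
optimizers_to_device(optimizers, torch.device("cpu"))

# The per-optimizer variant used by Strategy.load_optimizer_state_dict() after restoring
# a checkpoint (there, the target device is self.root_device).
optimizer_to_device(optimizers[0], torch.device("cpu"))
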
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -117,6 +117,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))


- Added utility functions for moving optimizers to devices ([#11758](https://github.com/PyTorchLightning/pytorch-lightning/pull/11758))


### Changed

- Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
1 change: 1 addition & 0 deletions docs/source/api_references.rst
@@ -288,6 +288,7 @@ Utilities API
finite_checks
memory
model_summary
optimizer
parsing
rank_zero
seed
3 changes: 2 additions & 1 deletion pytorch_lightning/strategies/deepspeed.py
@@ -39,6 +39,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
@@ -349,7 +350,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
self._move_optimizer_state()
optimizers_to_device(self.optimizers, self.root_device)
self.init_deepspeed()
self.barrier()

3 changes: 2 additions & 1 deletion pytorch_lightning/strategies/fully_sharded.py
@@ -25,6 +25,7 @@
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@@ -136,7 +137,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
self._move_optimizer_state()
optimizers_to_device(self.optimizers, self.root_device)

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
16 changes: 5 additions & 11 deletions pytorch_lightning/strategies/strategy.py
@@ -30,9 +30,10 @@
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT

TBroadcast = TypeVar("TBroadcast")
@@ -138,7 +139,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
self._move_optimizer_state()
optimizers_to_device(self.optimizers, self.root_device)

def setup_precision_plugin(self) -> None:
"""Attaches the precision plugin to the accelerator."""
@@ -149,14 +150,6 @@ def setup_precision_plugin(self) -> None:
self.optimizers = optimizers
self.lr_scheduler_configs = lr_scheduler_configs

def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
    """Moves the state of the optimizers to the appropriate device if needed."""
    for opt in self.optimizers:
        for p, v in opt.state.items():
            # `self.root_device` would raise an error if called outside the spawn process
            # while training on 8 or more cores.
            opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""Returns state of an optimizer.
@@ -330,6 +323,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
optimizer_states = checkpoint["optimizer_states"]
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
optimizer_to_device(optimizer, self.root_device)

def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
"""The actual training step.
@@ -445,7 +439,7 @@ def teardown(self) -> None:
It is the right place to release memory and free other resources.
"""
self._move_optimizer_state(torch.device("cpu"))
optimizers_to_device(self.optimizers, torch.device("cpu"))
self.precision_plugin.teardown()

@classmethod
3 changes: 2 additions & 1 deletion pytorch_lightning/strategies/tpu_spawn.py
@@ -31,6 +31,7 @@
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
@@ -126,7 +127,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
self._move_optimizer_state()
optimizers_to_device(self.optimizers, self.root_device)

if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)
11 changes: 0 additions & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -296,17 +296,6 @@ def restore_optimizers(self) -> None:

# restore the optimizers
self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
for optimizer in self.trainer.optimizers:
    # move optimizer to GPU 1 weight at a time
    # avoids OOM
    if self.trainer.root_gpu is not None:
        for param, state in optimizer.state.items():
            if isinstance(state, dict):
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda(self.trainer.root_gpu)
            elif isinstance(state, torch.Tensor):
                optimizer.state[param] = state.cuda(self.trainer.root_gpu)

def restore_lr_schedulers(self) -> None:
"""Restores the learning rate scheduler states from the pre-loaded checkpoint."""
33 changes: 33 additions & 0 deletions pytorch_lightning/utilities/optimizer.py
@@ -0,0 +1,33 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable

import torch
from torch.optim import Optimizer

from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.types import _DEVICE


def optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
    """Moves optimizer states for a sequence of optimizers to the device."""
    for opt in optimizers:
        optimizer_to_device(opt, device)


def optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    for p, v in optimizer.state.items():
        optimizer.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
30 changes: 30 additions & 0 deletions tests/utilities/test_optimizer.py
@@ -0,0 +1,30 @@
import collections.abc

import torch

from pytorch_lightning.utilities.optimizer import optimizer_to_device


def test_optimizer_to_device():
    class TestOptimizer(torch.optim.SGD):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.state["dummy"] = torch.tensor(0)

    layer = torch.nn.Linear(32, 2)
    opt = TestOptimizer(layer.parameters(), lr=0.1)
    optimizer_to_device(opt, "cpu")
    if torch.cuda.is_available():
        optimizer_to_device(opt, "cuda")
        assert_opt_parameters_on_device(opt, "cuda")


def assert_opt_parameters_on_device(opt, device: str):
    for param in opt.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            assert param.data.device.type == device
        elif isinstance(param, collections.abc.Mapping):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    assert subparam.data.device.type == device
