From c671041095fbf2f91d4934e899cacc02fe635098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 28 Mar 2022 18:03:04 +0200 Subject: [PATCH] Fix --- pytorch_lightning/overrides/base.py | 10 ++++++---- pytorch_lightning/overrides/data_parallel.py | 18 ++++++++++-------- pytorch_lightning/overrides/distributed.py | 20 +------------------- pytorch_lightning/strategies/bagua.py | 8 ++++++-- pytorch_lightning/strategies/deepspeed.py | 6 ++++-- pytorch_lightning/strategies/ipu.py | 6 ++++-- 6 files changed, 31 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index ae567e1abc02e..727da4737107a 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Union import torch import torch.nn as nn @@ -57,7 +57,7 @@ def on_post_move_to_device(self) -> None: class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, pl_module: "pl.LightningModule"): + def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``. @@ -74,7 +74,9 @@ def __init__(self, pl_module: "pl.LightningModule"): self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore] def forward(self, *inputs: Any, **kwargs: Any) -> Any: - trainer = self.module.trainer + pl_module = unwrap_lightning_module(self.module) + trainer = pl_module.trainer + if trainer is not None: if trainer.training: output = self.module.training_step(*inputs, **kwargs) @@ -82,7 +84,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: # it is done manually in `LightningModule.manual_backward` # `require_backward_grad_sync` will be reset in the # ddp_strategy `post_training_step` hook - if not self.module.automatic_optimization: + if not pl_module.automatic_optimization: trainer.model.require_backward_grad_sync = False # type: ignore[assignment] return output if trainer.testing: diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index ea44bc0683648..2d9a7d9a3acca 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -13,12 +13,12 @@ # limitations under the License. import numbers import warnings -from typing import Any, Union +from typing import Any, cast, Union import torch import pytorch_lightning as pl -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.rank_zero import rank_zero_warn @@ -36,10 +36,11 @@ def _ignore_scalar_return_in_dp() -> None: class LightningParallelModule(_LightningModuleWrapperBase): """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either - ``training_step``, ``validation_step``, ``test_step`` or ``predict``. 
This class is used in combination with - :class:`~torch.nn.parallel.DataParallel` as shown in the example. It also takes care of converting Python - scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required by - :class:`~torch.nn.parallel.DataParallel`. + ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``. + + This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as shown in the example. + It also takes care of converting Python scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required + by :class:`~torch.nn.parallel.DataParallel`. Example: @@ -53,7 +54,7 @@ class LightningParallelModule(_LightningModuleWrapperBase): pl_module: the model to wrap """ - def __init__(self, pl_module: "pl.LightningModule") -> None: + def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: super().__init__(pl_module) _ignore_scalar_return_in_dp() @@ -63,7 +64,8 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: output = super().forward(*inputs, **kwargs) def output_transform(data: Any) -> Any: - data = python_scalar_to_tensor(data, self.module.device) + device = cast(torch.device, self.module.device) + data = python_scalar_to_tensor(data, device) data = unsqueeze_scalar_tensor(data) return data diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 35f7993bc072d..be8f972132808 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -19,30 +19,12 @@ from torch.nn.parallel import DistributedDataParallel from torch.utils.data import BatchSampler, DistributedSampler, Sampler -import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.utilities import rank_zero_deprecation class LightningDistributedModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: "pl.LightningModule") -> None: - """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either - ``training_step``, ``validation_step``, ``test_step`` or ``predict``. - - This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel`. - - Example: - - ddp_model = torch.nn.parallel.DistributedDataParallel( - module=LightningDistributedModule(lightning_module), - device_ids=[local_rank], - ... - ) - - Args: - pl_module: the model to wrap - """ - super().__init__(pl_module) + ... 
def _find_tensors( diff --git a/pytorch_lightning/strategies/bagua.py b/pytorch_lightning/strategies/bagua.py index ad6a142917b40..61485395f0aea 100644 --- a/pytorch_lightning/strategies/bagua.py +++ b/pytorch_lightning/strategies/bagua.py @@ -6,7 +6,11 @@ from torch.nn import Module import pytorch_lightning as pl -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module +from pytorch_lightning.overrides.base import ( + _LightningModuleWrapperBase, + _LightningPrecisionModuleWrapperBase, + unwrap_lightning_module, +) from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -32,7 +36,7 @@ class LightningBaguaModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: "pl.LightningModule") -> None: + def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: super().__init__(pl_module) # Bagua use `bagua_module_name` to distinguish different modules self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}" diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index 35292c50ccd80..1f284aabc2ea1 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -27,7 +27,7 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy @@ -63,7 +63,9 @@ def remove_module_hooks(model: torch.nn.Module) -> None: class LightningDeepSpeedModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: "pl.LightningModule", precision: int) -> None: + def __init__( + self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: int + ) -> None: super().__init__(pl_module) self.precision = precision diff --git a/pytorch_lightning/strategies/ipu.py b/pytorch_lightning/strategies/ipu.py index cc72313a86e39..4603110c01536 100644 --- a/pytorch_lightning/strategies/ipu.py +++ b/pytorch_lightning/strategies/ipu.py @@ -19,7 +19,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -40,7 +40,9 @@ class LightningIPUModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: "pl.LightningModule", precision: Union[str, int]): + def __init__( + self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] + ) -> None: super().__init__(pl_module) self.precision = precision
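
Note on the change above: the wrapper classes now accept either a ``pl.LightningModule`` or a ``_LightningPrecisionModuleWrapperBase``, so ``forward`` first calls ``unwrap_lightning_module(self.module)`` and reads ``trainer`` and ``automatic_optimization`` from the unwrapped module instead of from ``self.module`` directly. Below is a minimal sketch of that unwrapping pattern, using hypothetical toy classes rather than the library's real implementation:

    from typing import Any


    class _ToyPrecisionWrapper:
        """Hypothetical stand-in for a precision wrapper such as _LightningPrecisionModuleWrapperBase."""

        def __init__(self, module: Any) -> None:
            self.module = module


    class _ToyLightningModule:
        """Hypothetical stand-in for pl.LightningModule."""

        automatic_optimization = True
        trainer = None


    def _toy_unwrap(wrapped: Any) -> Any:
        # Walk nested ``.module`` attributes until the innermost object is
        # reached, mirroring what the patched ``forward`` does through
        # ``unwrap_lightning_module(self.module)``.
        module = wrapped
        while hasattr(module, "module"):
            module = module.module
        return module


    inner = _ToyLightningModule()
    assert _toy_unwrap(_ToyPrecisionWrapper(inner)) is inner
    assert _toy_unwrap(inner) is inner  # a bare module is returned unchanged

The docstring examples already present in the diff (e.g. wrapping with ``torch.nn.DataParallel`` or ``torch.nn.parallel.DistributedDataParallel``) are unaffected: an object without a nested ``module`` attribute is returned as-is by the unwrapping step.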