Fix
Carlos Mocholí authored and carmocca committed Mar 29, 2022
1 parent 5e612aa commit c671041
Showing 6 changed files with 31 additions and 37 deletions.
10 changes: 6 additions & 4 deletions pytorch_lightning/overrides/base.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Union

import torch
import torch.nn as nn
@@ -57,7 +57,7 @@ def on_post_move_to_device(self) -> None:


class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
def __init__(self, pl_module: "pl.LightningModule"):
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
@@ -74,15 +74,17 @@ def __init__(self, pl_module: "pl.LightningModule"):
self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore]

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
trainer = self.module.trainer
pl_module = unwrap_lightning_module(self.module)
trainer = pl_module.trainer

if trainer is not None:
if trainer.training:
output = self.module.training_step(*inputs, **kwargs)
# In manual optimization, we need to prevent DDP's gradient reduction here,
# since it is triggered manually in `LightningModule.manual_backward`.
# `require_backward_grad_sync` will be reset in the
# ddp_strategy `post_training_step` hook.
if not self.module.automatic_optimization:
if not pl_module.automatic_optimization:
trainer.model.require_backward_grad_sync = False # type: ignore[assignment]
return output
if trainer.testing:
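The intent of this change is that `self.module` may now be a precision wrapper around the LightningModule rather than the LightningModule itself, so `forward` first unwraps it before reading `trainer` and `automatic_optimization`. Below is a minimal, self-contained sketch of that unwrapping idea; the classes are stand-ins for this example, not Lightning's own.

from typing import Any, Union


class StubLightningModule:
    """Stand-in for pl.LightningModule with the attributes used above."""

    automatic_optimization = False

    def training_step(self, *args: Any, **kwargs: Any) -> str:
        return "training output"


class StubPrecisionWrapper:
    """Stand-in for _LightningPrecisionModuleWrapperBase: holds the real module."""

    def __init__(self, module: StubLightningModule) -> None:
        self.module = module


def unwrap(module: Union[StubLightningModule, StubPrecisionWrapper]) -> StubLightningModule:
    # Walk through any precision wrapper until the underlying module is reached.
    while isinstance(module, StubPrecisionWrapper):
        module = module.module
    return module


wrapped = StubPrecisionWrapper(StubLightningModule())
pl_module = unwrap(wrapped)
# The flag is read from the unwrapped module; reading it from `wrapped`
# directly would fail, which is what the change above avoids.
print(pl_module.automatic_optimization)  # False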
18 changes: 10 additions & 8 deletions pytorch_lightning/overrides/data_parallel.py
@@ -13,12 +13,12 @@
# limitations under the License.
import numbers
import warnings
from typing import Any, Union
from typing import Any, cast, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

@@ -36,10 +36,11 @@ def _ignore_scalar_return_in_dp() -> None:

class LightningParallelModule(_LightningModuleWrapperBase):
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination with
:class:`~torch.nn.parallel.DataParallel` as shown in the example. It also takes care of converting Python
scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required by
:class:`~torch.nn.parallel.DataParallel`.
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as shown in the example.
It also takes care of converting Python scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required
by :class:`~torch.nn.parallel.DataParallel`.
Example:
@@ -53,7 +54,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
pl_module: the model to wrap
"""

def __init__(self, pl_module: "pl.LightningModule") -> None:
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
_ignore_scalar_return_in_dp()

@@ -63,7 +64,8 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
output = super().forward(*inputs, **kwargs)

def output_transform(data: Any) -> Any:
data = python_scalar_to_tensor(data, self.module.device)
device = cast(torch.device, self.module.device)
data = python_scalar_to_tensor(data, device)
data = unsqueeze_scalar_tensor(data)
return data

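The docstring above mentions converting Python scalars to tensors and un-squeezing 0-dimensional tensors, which DataParallel needs in order to gather per-replica outputs. Here is a hedged sketch of that behavior; the helper names carry a `_sketch` suffix because they are illustrative, not the exact Lightning implementations.

import numbers
from typing import Any

import torch


def python_scalar_to_tensor_sketch(data: Any, device: torch.device) -> Any:
    # Wrap plain Python numbers in a 1-element tensor on the given device.
    if isinstance(data, numbers.Number):
        return torch.tensor([data], device=device)
    return data


def unsqueeze_scalar_tensor_sketch(data: Any) -> Any:
    # Give 0-dimensional tensors a leading dimension so they can be gathered.
    if isinstance(data, torch.Tensor) and data.dim() == 0:
        return data.unsqueeze(0)
    return data


device = torch.device("cpu")
print(python_scalar_to_tensor_sketch(3.5, device))        # tensor([3.5000])
print(unsqueeze_scalar_tensor_sketch(torch.tensor(1.0)))  # tensor([1.])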
20 changes: 1 addition & 19 deletions pytorch_lightning/overrides/distributed.py
@@ -19,30 +19,12 @@
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.utilities import rank_zero_deprecation


class LightningDistributedModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: "pl.LightningModule") -> None:
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step`` or ``predict``.
This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel`.
Example:
ddp_model = torch.nn.parallel.DistributedDataParallel(
module=LightningDistributedModule(lightning_module),
device_ids=[local_rank],
...
)
Args:
pl_module: the model to wrap
"""
super().__init__(pl_module)
...


def _find_tensors(
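The docstring removed in this hunk showed how the wrapper is handed to DistributedDataParallel. For reference, a small helper based on that example; the function name is hypothetical, and it only runs inside an initialized process group.

import torch

from pytorch_lightning.overrides.distributed import LightningDistributedModule


def wrap_for_ddp(lightning_module, local_rank: int) -> torch.nn.parallel.DistributedDataParallel:
    # DDP receives the wrapper; the wrapper redirects forward() to the
    # LightningModule's training/validation/test/predict steps.
    return torch.nn.parallel.DistributedDataParallel(
        module=LightningDistributedModule(lightning_module),
        device_ids=[local_rank],
    )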
8 changes: 6 additions & 2 deletions pytorch_lightning/strategies/bagua.py
@@ -6,7 +6,11 @@
from torch.nn import Module

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.overrides.base import (
_LightningModuleWrapperBase,
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -32,7 +36,7 @@


class LightningBaguaModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: "pl.LightningModule") -> None:
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
# Bagua use `bagua_module_name` to distinguish different modules
self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}"
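A quick sketch of the naming scheme used above: combining the class name with id() gives each wrapped instance a distinct `bagua_module_name`, so two instances of the same LightningModule class do not collide. The demo class below is illustrative.

class DemoModule:
    pass


a, b = DemoModule(), DemoModule()
name_a = f"{a.__class__.__name__}{id(a)}"
name_b = f"{b.__class__.__name__}{id(b)}"
# Same class name, different object ids, therefore different names.
assert name_a != name_b
print(name_a, name_b)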
6 changes: 4 additions & 2 deletions pytorch_lightning/strategies/deepspeed.py
@@ -27,7 +27,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
@@ -63,7 +63,9 @@ def remove_module_hooks(model: torch.nn.Module) -> None:


class LightningDeepSpeedModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: "pl.LightningModule", precision: int) -> None:
def __init__(
self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: int
) -> None:
super().__init__(pl_module)
self.precision = precision

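The wrapper above stores a `precision` value alongside the module. The exact casting logic of LightningDeepSpeedModule is not shown in this diff; the following is only a sketch of what a precision-aware wrapper could do with a stored precision of 16, using a hypothetical helper name.

from typing import Any

import torch


def maybe_cast_inputs(precision: int, *inputs: Any) -> tuple:
    # Cast floating-point tensors to half precision when precision == 16,
    # leaving integer tensors and non-tensor inputs untouched.
    if precision != 16:
        return inputs
    return tuple(
        t.half() if isinstance(t, torch.Tensor) and t.is_floating_point() else t
        for t in inputs
    )


batch = (torch.randn(2, 3), torch.tensor([0, 1]))
print([t.dtype for t in maybe_cast_inputs(16, *batch)])  # [torch.float16, torch.int64]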
6 changes: 4 additions & 2 deletions pytorch_lightning/strategies/ipu.py
@@ -19,7 +19,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -40,7 +40,9 @@


class LightningIPUModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: "pl.LightningModule", precision: Union[str, int]):
def __init__(
self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int]
) -> None:
super().__init__(pl_module)
self.precision = precision

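Here `precision` may arrive as either a string or an int. A tiny illustrative normalization, assuming only numeric values such as 16 or "16"; this is not Lightning's actual handling.

from typing import Union


def normalize_precision(precision: Union[str, int]) -> int:
    # Assumes numeric values only, e.g. 16, "16", 32, "32".
    return int(precision)


assert normalize_precision("16") == normalize_precision(16) == 16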
