Add function to explicitly mark forward methods in Fabric #19690

Merged
merged 26 commits on May 8, 2024
Changes from 22 commits
Commits (26)
7425e1e
mark forward method
awaelchli Mar 24, 2024
f40b290
Update src/lightning/fabric/wrappers.py
awaelchli Mar 24, 2024
d01babd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2024
e31210c
delete
awaelchli Mar 26, 2024
aed73aa
Merge branch 'master' into feature/mark-forward-method
awaelchli Mar 27, 2024
8ce9068
Merge branch 'master' into feature/mark-forward-method
awaelchli Apr 3, 2024
358d725
docs
awaelchli Apr 3, 2024
6a0823b
the long explanation
awaelchli Apr 3, 2024
92cdef5
update
awaelchli Apr 3, 2024
84b11c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2024
e87916f
update
awaelchli Apr 3, 2024
f8c3084
Merge branch 'feature/mark-forward-method' of github.com:Lightning-AI…
awaelchli Apr 3, 2024
0a8ff61
update
awaelchli Apr 3, 2024
8c3c468
test corner cases
awaelchli Apr 4, 2024
44cfdc1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
e0cabd2
chlog
awaelchli Apr 4, 2024
8ebe728
Merge branch 'feature/mark-forward-method' of github.com:Lightning-AI…
awaelchli Apr 4, 2024
285a161
update test
awaelchli Apr 4, 2024
70cf896
update
awaelchli Apr 4, 2024
a4154d5
Merge branch 'master' into feature/mark-forward-method
awaelchli Apr 5, 2024
98967d6
edge case
awaelchli Apr 5, 2024
a5b12b4
Merge branch 'master' into feature/mark-forward-method
awaelchli Apr 29, 2024
54ce4d6
Merge branch 'master' into feature/mark-forward-method
awaelchli May 8, 2024
219c56a
Update docs/source-fabric/api/wrappers.rst
awaelchli May 8, 2024
361fb2c
Update docs/source-fabric/api/wrappers.rst
awaelchli May 8, 2024
bb21cc8
fix docs render backticks
awaelchli May 8, 2024
2 changes: 1 addition & 1 deletion docs/source-fabric/api/fabric_methods.rst
@@ -49,7 +49,7 @@ Moves the model and optimizer to the correct device automatically.


The setup method also prepares the model for the selected precision choice so that operations during ``forward()`` get
cast automatically.
cast automatically. Advanced users should read :doc:`the notes on models wrapped by Fabric <../api/wrappers>`.

setup_dataloaders
=================
147 changes: 147 additions & 0 deletions docs/source-fabric/api/wrappers.rst
@@ -0,0 +1,147 @@
########################
Models wrapped by Fabric
########################

When you :doc:`set up <../api/fabric_methods>` a model in Fabric, it gets automatically wrapped by a new module, the ``FabricModule``:

.. code-block:: python

import torch
import lightning as L

fabric = L.Fabric()
model = torch.nn.Linear(10, 2)
model = fabric.setup(model)

print(type(model)) # <class 'lightning.fabric.wrappers._FabricModule'>

This wrapper module takes care of a few things for you, notably:

- Strategy: Handles strategy-specific logic for the forward method (DDP, FSDP, etc.).
- Precision: Inputs and outputs passed through ``forward`` get automatically converted to the right precision depending on the ``Fabric(precision=...)`` setting.
- Device: The wrapper remembers which device the model is on. You can access it with ``model.device`` (see the example below).
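
For example, here is a minimal sketch of the device and precision handling, assuming CPU execution and the ``bf16-true`` precision setting (the exact dtypes depend on your precision choice):

.. code-block:: python

    import torch
    import lightning as L

    fabric = L.Fabric(accelerator="cpu", precision="bf16-true")
    model = torch.nn.Linear(10, 2)
    model = fabric.setup(model)

    # The wrapper tracks the device the parameters live on
    print(model.device)  # cpu

    # With "bf16-true", the parameters were converted to bfloat16 during setup ...
    print(model.weight.dtype)  # torch.bfloat16

    # ... and a float32 input passed through forward is converted automatically
    output = model(torch.randn(4, 10))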

.. note::
The FabricModule wrapper is completely transparent and most users will never need to interact with it directly.

Below we describe a few functions and properties of the wrapper for advanced use cases.
This might be useful if you are building a custom Trainer using Fabric as the core.


----


********************************
Accessing methods and attributes
********************************

Access to methods and attributes gets redirected to the original model automatically:

.. code-block:: python

import torch
import lightning as L

fabric = L.Fabric()
model = torch.nn.Linear(10, 2)
fabric_model = fabric.setup(model)

# You can access attributes and methods normally
print(fabric_model.weight is model.weight) # True


----


********************
Unwrapping the model
********************

You can check whether a model is wrapped in a ``FabricModule`` with the ``is_wrapped`` utility function:

.. code-block:: python

import torch
import lightning as L
from lightning.fabric import is_wrapped

fabric = L.Fabric()
model = torch.nn.Linear(10, 2)
fabric_model = fabric.setup(model)

print(is_wrapped(model)) # False
print(is_wrapped(fabric_model)) # True


If you ever need to, you can access the original model explicitly via ``.module``:

.. code-block:: python

# Access the original model explicitly
original_model = fabric_model.module

print(original_model is model) # True

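If you are building a custom Trainer on top of Fabric, a small helper like the one below (hypothetical, shown only for illustration) can normalize both cases so the rest of your code always works with the raw ``nn.Module``:

.. code-block:: python

    import torch
    from lightning.fabric import is_wrapped


    def unwrap(model: torch.nn.Module) -> torch.nn.Module:
        # Return the underlying nn.Module, whether or not it was set up with Fabric
        return model.module if is_wrapped(model) else model


    print(unwrap(fabric_model) is model)  # True
    print(unwrap(model) is model)  # True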

----


************************************************
Using methods other than forward for computation
************************************************

PyTorch's ``nn.Module`` has a special contract you need to follow when using it for training: your forward computation has to be defined in the **forward** method, and you should call this method directly.
But sometimes your model may need to define different flavors of forward, like in the example below where the regular forward is used for training, but the ``generate`` method does something slightly different for inference:

.. code-block:: python

import torch
import lightning as L


class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(10, 2)

def forward(self, x):
return self.layer(x)

def generate(self):
sample = torch.randn(10)
return self(sample)


If you were to run this model in Fabric with multiple devices (DDP or FSDP), you would get an error:

.. code-block:: python

fabric = L.Fabric(accelerator="cpu", devices=2)
fabric.launch()
model = MyModel()
model = fabric.setup(model)

# OK: Calling the model directly
output = model(torch.randn(10))

# OK: Calling the model's forward (equivalent to the above)
output = model.forward(torch.randn(10))

# ERROR: Calling another method that calls forward indirectly
output = model.generate()

Fabric raises an error here to inform you about the incorrect usage, because such an indirect call bypasses the strategy's wrapper and could potentially lead to silent correctness bugs.
If you want to use such methods, you need to mark them explicitly with ``.mark_forward_method()`` so that Fabric can reroute the call behind the scenes and do the right thing:

.. code-block:: python

# You must mark special forward methods explicitly:
model.mark_forward_method(model.generate)

# Passing just the name is also sufficient
model.mark_forward_method("generate")

# OK: Fabric will do some rerouting behind the scenes now
output = model.generate()

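Under the hood, marking a method makes Fabric reroute the call so that it still enters the strategy's wrapper (e.g., DDP) through ``forward``. The snippet below is a simplified, illustrative sketch of that idea, not Fabric's actual implementation:

.. code-block:: python

    def redirect_through_forward(fabric_module, method_name):
        # Illustration only: route a call to `method_name` through the strategy wrapper's forward
        original_module = fabric_module.module
        original_forward = original_module.forward

        def call(*args, **kwargs):
            def patched_forward(*fargs, **fkwargs):
                # Restore the real forward, then run the marked method in its place
                original_module.forward = original_forward
                return getattr(original_module, method_name)(*fargs, **fkwargs)

            # Swap forward temporarily so the strategy wrapper ends up calling the marked method
            original_module.forward = patched_forward
            return fabric_module(*args, **kwargs)  # goes through DDP/FSDP and precision handling

        return call
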
|
6 changes: 6 additions & 0 deletions docs/source-fabric/glossary/index.rst
@@ -8,6 +8,7 @@ Glossary

Checkpoint <../guide/checkpoint/index>
Weights and Biases <../guide/loggers/wandb>
Wrappers <../api/wrappers>


.. raw:: html
@@ -80,6 +81,11 @@ Glossary
:button_link: ../fundamentals/accelerators.html
:col_css: col-md-4

.. displayitem::
:header: FabricModule
:button_link: ../api/wrappers.html
:col_css: col-md-4

.. displayitem::
:header: FSDP
:button_link: ../advanced/model_parallel/fsdp.html
4 changes: 3 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -9,7 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI [#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))
- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))

- Added the ability to explicitly mark forward methods in Fabric via `_FabricModule.mark_forward_method()` ([#19690](https://github.com/Lightning-AI/pytorch-lightning/pull/19690))

- Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))

28 changes: 24 additions & 4 deletions src/lightning/fabric/wrappers.py
@@ -14,6 +14,7 @@
import inspect
from copy import deepcopy
from functools import partial, wraps
from types import MethodType
from typing import (
Any,
Callable,
@@ -123,6 +124,7 @@ def __init__(
self._forward_module = forward_module
self._original_module = original_module or forward_module
self._strategy = strategy
self._forward_methods = set(_LIGHTNING_MODULE_STEP_METHODS)
self._fabric_module_initialized = True

@property
@@ -165,6 +167,20 @@ def load_state_dict( # type: ignore[override]
) -> _IncompatibleKeys:
return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)

def mark_forward_method(self, method: Union[MethodType, str]) -> None:
"""Mark a method as a 'forward' method to prevent it bypassing the strategy wrapper (e.g., DDP)."""
if not isinstance(method, (MethodType, str)):
raise TypeError(f"Expected a method or a string, but got: {type(method).__name__}")
name = method if isinstance(method, str) else method.__name__
if name == "forward":
raise ValueError("You cannot mark the forward method itself as a forward method.")
if not isinstance(getattr(self._original_module, name, None), MethodType):
raise AttributeError(
f"You marked '{name}' as a forward method, but `{type(self._original_module).__name__}.{name}` does not"
f" exist or is not a method."
)
self._forward_methods.add(name)

def _redirection_through_forward(self, method_name: str) -> Callable:
assert method_name != "forward"
original_forward = self._original_module.forward
@@ -207,8 +223,8 @@ def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
if module_called:
raise RuntimeError(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
" model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
" `.backward()`. You should pass your inputs through `forward()`.",
" model. To avoid issues with the currently selected strategy, explicitly mark it as a"
f" forward method with `fabric_model.mark_forward_method({name!r})` after `fabric.setup()`."
)
for handle in handles:
handle.remove()
@@ -231,8 +247,12 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor:

@override
def __getattr__(self, item: Any) -> Any:
if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
# Special support for `LightningModule`, to prevent bypassing DDP's forward
if (
item != "_forward_methods"
and item in self._forward_methods
and self._forward_module != self._original_module
):
# Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
return self._redirection_through_forward(item)

try:
80 changes: 69 additions & 11 deletions tests/tests_fabric/test_wrappers.py
@@ -102,15 +102,20 @@ def __init__(self, module):
super().__init__()
self.wrapped = module

def forward(self, *args, **kwargs):
return self.wrapped(*args, **kwargs)

# Regular case: forward_module == original_module -> no warnings
original_module = OriginalModule()
fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
assert fabric_module.method_without_module_invocation() == 100

# Special case: original module wrapped by forward module: -> warn if method accepts args
# Special case: original module wrapped by forward module: -> error if method requires rerouting
original_module = OriginalModule()
wrapped_module = ModuleWrapper(original_module)
fabric_module = _FabricModule(forward_module=wrapped_module, strategy=Mock(), original_module=original_module)
fabric_module = _FabricModule(
forward_module=wrapped_module, strategy=Mock(precision=Precision()), original_module=original_module
)
assert fabric_module.method_without_module_invocation() == 100
with pytest.raises(
RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
@@ -121,6 +126,51 @@ def __init__(self, module):
):
assert fabric_module.method_with_self_invocation() == 102

# No error if explicitly marked as forward method
fabric_module.mark_forward_method("method_with_self_invocation")
assert fabric_module.method_with_self_invocation() == 102


def test_fabric_module_mark_forward_method():
class OriginalModule(torch.nn.Module):
attribute = 1

def forward(self, x):
return x

def special(self):
pass

original_module = OriginalModule()
fabric_module = _FabricModule(original_module, Mock(), original_module=original_module)

with pytest.raises(ValueError, match="You cannot mark the forward method itself"):
fabric_module.mark_forward_method("forward")

with pytest.raises(AttributeError, match="`OriginalModule.not_exist` does not exist or is not a method."):
fabric_module.mark_forward_method("not_exist")

with pytest.raises(AttributeError, match="`OriginalModule.attribute` does not exist or is not a method."):
fabric_module.mark_forward_method("attribute")

def special(x):
return x

with pytest.raises(TypeError, match="Expected a method or a string"):
fabric_module.mark_forward_method(special)

lightning_module_methods = {"training_step", "validation_step", "test_step", "predict_step"}
assert fabric_module._forward_methods == lightning_module_methods

# Mark via name
fabric_module.mark_forward_method("special")
assert fabric_module._forward_methods == {"special"} | lightning_module_methods

# Mark by passing in the method itself
fabric_module = _FabricModule(original_module, Mock(), original_module=original_module)
fabric_module.mark_forward_method(original_module.special)
assert fabric_module._forward_methods == {"special"} | lightning_module_methods


def test_fabric_module_setattr():
"""Test that setattr sets attributes on the original module."""
@@ -549,8 +599,8 @@ def test_unwrap_objects(compile):


def test_step_method_redirection():
"""Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
module."""
"""Test that the FabricModule redirects methods marked as 'forward methods' through forward to avoid bypassing the
DDP/FSDP wrappers."""

class DDP(torch.nn.Module):
def __init__(self, module):
@@ -570,11 +620,11 @@ def training_step(self, arg, kwarg=None):
assert kwarg == "train_kwarg"
return "training_step_return"

def validation_step(self, arg, kwarg=None):
def marked_method(self, arg, kwarg=None):
assert self() == "forward_return"
assert arg == "val_arg"
assert kwarg == "val_kwarg"
return "validation_step_return"
assert arg == "marked_arg"
assert kwarg == "marked_kwarg"
return "marked_method_return"

def normal_method(self):
pass
@@ -602,18 +652,26 @@ def normal_method(self):
assert original_module.forward.__name__ == "forward"

# The special methods get redirected correctly to produce the expected output
strategy.precision.forward_context.reset_mock()
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return" # call 2nd time
assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
strategy.precision.forward_context.assert_called()
assert strategy.precision.forward_context.call_count == 2

# Other methods must be marked explicitly to be redirected
strategy.precision.forward_context.reset_mock()
with pytest.raises(RuntimeError, match="You are calling the method .* from outside the model"):
fabric_module.marked_method("marked_arg", kwarg="marked_kwarg")
fabric_module.mark_forward_method("marked_method")
assert fabric_module.marked_method("marked_arg", kwarg="marked_kwarg") == "marked_method_return"
strategy.precision.forward_context.assert_called_once()

# The forward method remains untouched/unpatched after the special methods have been called
assert original_module.forward.__name__ == "forward"

# Special case: forward_module == original_module -> no special treatment applied
fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
assert fabric_module.training_step == original_module.training_step
assert fabric_module.validation_step == original_module.validation_step
assert fabric_module.marked_method == original_module.marked_method


@RunIf(dynamo=True)