Add function to explicitly mark forward methods in Fabric #19690

Merged
merged 26 commits into master from feature/mark-forward-method on May 8, 2024
Commits
7425e1e  mark forward method (awaelchli, Mar 24, 2024)
f40b290  Update src/lightning/fabric/wrappers.py (awaelchli, Mar 24, 2024)
d01babd  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 24, 2024)
e31210c  delete (awaelchli, Mar 26, 2024)
aed73aa  Merge branch 'master' into feature/mark-forward-method (awaelchli, Mar 27, 2024)
8ce9068  Merge branch 'master' into feature/mark-forward-method (awaelchli, Apr 3, 2024)
358d725  docs (awaelchli, Apr 3, 2024)
6a0823b  the long explanation (awaelchli, Apr 3, 2024)
92cdef5  update (awaelchli, Apr 3, 2024)
84b11c8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 3, 2024)
e87916f  update (awaelchli, Apr 3, 2024)
f8c3084  Merge branch 'feature/mark-forward-method' of github.com:Lightning-AI… (awaelchli, Apr 3, 2024)
0a8ff61  update (awaelchli, Apr 3, 2024)
8c3c468  test corner cases (awaelchli, Apr 4, 2024)
44cfdc1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 4, 2024)
e0cabd2  chlog (awaelchli, Apr 4, 2024)
8ebe728  Merge branch 'feature/mark-forward-method' of github.com:Lightning-AI… (awaelchli, Apr 4, 2024)
285a161  update test (awaelchli, Apr 4, 2024)
70cf896  update (awaelchli, Apr 4, 2024)
a4154d5  Merge branch 'master' into feature/mark-forward-method (awaelchli, Apr 5, 2024)
98967d6  edge case (awaelchli, Apr 5, 2024)
a5b12b4  Merge branch 'master' into feature/mark-forward-method (awaelchli, Apr 29, 2024)
54ce4d6  Merge branch 'master' into feature/mark-forward-method (awaelchli, May 8, 2024)
219c56a  Update docs/source-fabric/api/wrappers.rst (awaelchli, May 8, 2024)
361fb2c  Update docs/source-fabric/api/wrappers.rst (awaelchli, May 8, 2024)
bb21cc8  fix docs render backticks (awaelchli, May 8, 2024)
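
Below, for orientation, is a minimal usage sketch of the new API. It is not part of the PR itself: `MyModel` and `generate` are made-up names, it assumes a Lightning release that includes this change, and the marking only has an effect when the strategy wraps the module (e.g. DDP/FSDP):

    import torch
    from lightning.fabric import Fabric


    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.layer(x)

        def generate(self, x):
            # a custom method that routes through forward
            return self(x) * 2


    fabric = Fabric(accelerator="cpu", devices=1)
    model = fabric.setup(MyModel())

    # Mark the custom method so calls to it are redirected through the
    # strategy's forward; accepts the bound method or its name.
    model.mark_forward_method("generate")
    out = model.generate(torch.randn(2, 4))

Without `mark_forward_method`, calling such a method on a strategy-wrapped model raises the RuntimeError whose message is updated in the diff below.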
18 changes: 13 additions & 5 deletions src/lightning/fabric/wrappers.py
@@ -14,6 +14,7 @@
 import inspect
 from copy import deepcopy
 from functools import partial, wraps
+from types import MethodType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -27,7 +28,7 @@
     Tuple,
     TypeVar,
     Union,
-    overload,
+    overload, Set,
 )

 import torch
@@ -127,6 +128,7 @@ def __init__(
         self._forward_module = forward_module
         self._original_module = original_module or forward_module
         self._strategy = strategy
+        self._forward_methods: Set[str] = set(_LIGHTNING_MODULE_STEP_METHODS)
         self._fabric_module_initialized = True

     @property
@@ -137,6 +139,7 @@ def module(self) -> nn.Module:
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
         precision = self._strategy.precision
+
         args, kwargs = precision.convert_input((args, kwargs))

         with precision.forward_context():
@@ -169,6 +172,11 @@ def load_state_dict(  # type: ignore[override]
     ) -> _IncompatibleKeys:
         return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)

+    def mark_forward_method(self, method: Union[MethodType, str]) -> None:
+        name = method if isinstance(method, str) else method.__name__
+        assert hasattr(self._original_module, name)
+        self._forward_methods.add(name)
+
     def _redirection_through_forward(self, method_name: str) -> Callable:
         assert method_name != "forward"
         original_forward = self._original_module.forward
@@ -211,8 +219,8 @@ def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
             if module_called:
                 raise RuntimeError(
                     f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
-                    " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
-                    " `.backward()`. You should pass your inputs through `forward()`.",
+                    " model. To avoid issues with the currently selected strategy, explicitly mark it as a"
+                    f" forward method with `fabric_model.mark_forward_method({name!r})` after `fabric.setup()`."
                 )
             for handle in handles:
                 handle.remove()
@@ -235,8 +243,8 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor:

     @override
     def __getattr__(self, item: Any) -> Any:
-        if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
-            # Special support for `LightningModule`, to prevent bypassing DDP's forward
+        if item != "_forward_methods" and item in self._forward_methods and self._forward_module != self._original_module:
+            # Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
             return self._redirection_through_forward(item)

         try:
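
For orientation: `__getattr__` intercepts marked names and returns `_redirection_through_forward(item)`, which temporarily patches the original module's `forward` so the call still travels through the strategy wrapper (and its hooks and precision context). A simplified, self-contained sketch of that idea follows; `StrategyWrapper`, `TinyModel`, and `redirect_through_forward` are illustrative names, not the verbatim Fabric implementation:

    import torch


    class StrategyWrapper(torch.nn.Module):
        """Stand-in for a DDP/FSDP forward module."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            # strategy-specific work (hooks, precision) would happen here
            return self.module(*args, **kwargs)


    def redirect_through_forward(wrapper, original_module, method_name):
        original_forward = original_module.forward

        def patched_forward(*args, **kwargs):
            # Unpatch first: the target method may itself call the real forward()
            original_module.forward = original_forward
            return getattr(original_module, method_name)(*args, **kwargs)

        def call_through_wrapper(*args, **kwargs):
            # Patch `forward` so the wrapper ends up executing the target
            # method when it eventually calls original_module.forward
            original_module.forward = patched_forward
            return wrapper(*args, **kwargs)

        return call_through_wrapper


    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return x + 1

        def generate(self, x):
            return self(x) * 2


    model = TinyModel()
    generate = redirect_through_forward(StrategyWrapper(model), model, "generate")
    print(generate(torch.tensor(1.0)))  # tensor(4.) = (1.0 + 1) * 2, via the wrapper

The unpatch inside `patched_forward` matters because the redirected method may call `self()` and expect the real `forward`, which is exactly what `marked_method` does in the test below.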
30 changes: 21 additions & 9 deletions tests/tests_fabric/test_wrappers.py
@@ -120,6 +120,10 @@ def __init__(self, module):
     ):
         assert fabric_module.method_with_self_invocation() == 102

+    # No error if explicitly marked as forward method
+    fabric_module.mark_forward_method("method_with_self_invocation")
+    assert fabric_module.method_with_self_invocation() == 102
+

 def test_fabric_module_setattr():
     """Test that setattr sets attributes on the original module."""
@@ -530,8 +534,8 @@ def test_unwrap_objects(compile):


 def test_step_method_redirection():
-    """Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
-    module."""
+    """Test that the FabricModule redirects methods marked as 'forward methods' through forward to avoid bypassing
+    the DDP/FSDP wrappers."""

     class DDP(torch.nn.Module):
         def __init__(self, module):
@@ -551,11 +555,11 @@ def training_step(self, arg, kwarg=None):
             assert kwarg == "train_kwarg"
             return "training_step_return"

-        def validation_step(self, arg, kwarg=None):
+        def marked_method(self, arg, kwarg=None):
             assert self() == "forward_return"
-            assert arg == "val_arg"
-            assert kwarg == "val_kwarg"
-            return "validation_step_return"
+            assert arg == "marked_arg"
+            assert kwarg == "marked_kwarg"
+            return "marked_method_return"

         def normal_method(self):
             pass
@@ -583,18 +587,26 @@ def normal_method(self):
     assert original_module.forward.__name__ == "forward"

     # The special methods get redirected correctly to produce the expected output
+    strategy.precision.forward_context.reset_mock()
     assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
     assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"  # call 2nd time
-    assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
-    strategy.precision.forward_context.assert_called()
+    assert strategy.precision.forward_context.call_count == 2
+
+    # Other methods must be marked explicitly to be redirected
+    strategy.precision.forward_context.reset_mock()
+    with pytest.raises(RuntimeError, match="You are calling the method .* from outside the model"):
+        fabric_module.marked_method("marked_arg", kwarg="marked_kwarg")
+    fabric_module.mark_forward_method("marked_method")
+    assert fabric_module.marked_method("marked_arg", kwarg="marked_kwarg") == "marked_method_return"
+    strategy.precision.forward_context.assert_called_once()

     # The forward method remains untouched/unpatched after the special methods have been called
     assert original_module.forward.__name__ == "forward"

     # Special case: forward_module == original_module -> no special treatment applied
     fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
     assert fabric_module.training_step == original_module.training_step
-    assert fabric_module.validation_step == original_module.validation_step
+    assert fabric_module.marked_method == original_module.marked_method


 @RunIf(dynamo=True)