[FSDP] Fix use_orig_params=True + AC (pytorch#87413)
Without this change, the post-backward hooks do not run when using `use_orig_params=True` together with reentrant activation checkpointing.

**Explanation**
FSDP registers the original parameters as plain `Tensor`s in the forward pass so that their ops are tracked by autograd to ensure proper gradient propagation into the `FlatParameter`s. FSDP registers the post-backward hooks in its pre-forward.
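
As a toy illustration of why this plain-`Tensor` registration matters (my own sketch, not FSDP's actual bookkeeping; `flat_param` below is just a stand-in for a `FlatParameter`), replacing a module's registered parameter with an autograd-visible view lets the module's ops propagate gradients back into the single flat tensor:

```python
import torch
import torch.nn as nn

# Stand-in for a `FlatParameter`: one flat leaf tensor backing a module's weight.
flat_param = nn.Parameter(torch.randn(12))

lin = nn.Linear(3, 3, bias=False)
view = flat_param[:9].view(3, 3)  # autograd-visible view into the flat tensor
del lin.weight                    # drop the registered `nn.Parameter`
setattr(lin, "weight", view)      # re-register as a plain `Tensor` attribute

lin(torch.randn(2, 3)).sum().backward()
print(flat_param.grad.shape)      # torch.Size([12]): grads flowed into the flat tensor
```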

For `use_orig_params=True`, FSDP replaces the plain `Tensor`s with the sharded `nn.Parameter`s in the post-forward when resharding. This differs from `use_orig_params=False`, which keeps the plain `Tensor`s registered as attributes, except that their data is freed, meaning that accessing them between forward and backward raises an error. Before this PR, for `use_orig_params=True`, FSDP simply restores the unsharded original parameter data in the pre-backward to enable correct gradient computation. However, this does not suffice for reentrant activation checkpointing (AC), where the recomputed forward happens after FSDP's pre-backward and the ops in that recomputed forward must be tracked by autograd.
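
For concreteness, here is a minimal sketch of the composition that exercises this path (my own example, not the PR's test; it assumes a single-GPU process group launched via `torchrun --nproc_per_node=1`):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

module = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).cuda()
# FSDP(checkpoint(module)): the recomputed forward runs after FSDP's pre-backward.
model = FSDP(
    checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.REENTRANT),
    use_orig_params=True,
)
# Reentrant checkpointing needs an input with requires_grad=True.
inp = torch.randn(4, 8, device="cuda", requires_grad=True)
model(inp).sum().backward()  # the post-backward hooks must fire here to reduce grads
dist.destroy_process_group()
```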

My initial solution was to simply have FSDP restore the original parameters as plain `Tensor`s again in the pre-backward so that they would be tracked by autograd exactly like in the normal forward. However, this does not suffice in general: the `FlatParameter`'s `AccumulateGrad` object may change after the original pre-forward when a recomputed forward runs, and a hook registered on the old `AccumulateGrad` object then never fires.
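
For context on why the `AccumulateGrad` identity matters: the post-backward hook is registered on the `FlatParameter`'s `AccumulateGrad` node, which is reached through the autograd graph. A simplified sketch of that hook-on-`AccumulateGrad` pattern (my own illustration, not FSDP's code):

```python
import torch

flat_param = torch.randn(4, requires_grad=True)  # stand-in for a FlatParameter

# Reach the leaf's AccumulateGrad node through a temporary non-leaf expression,
# then attach a hook that runs after the gradient has been accumulated.
acc_grad = flat_param.expand_as(flat_param).grad_fn.next_functions[0][0]
acc_grad.register_hook(lambda *unused: print("post-backward hook ran"))

(flat_param * 2).sum().backward()  # prints "post-backward hook ran"
```

If that node is later replaced by a fresh one, any hook registered on the old node is lost, which is the failure mode this PR guards against.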

The new approach in this PR follows the `use_orig_params=False` way -- namely, it preserves the plain `Tensor` variables across forward and backward. I achieved this by saving the variables explicitly in the forward and restoring them in the pre-backward. I clear them in the post-backward to avoid dangling references (though I do not think this is strictly necessary).
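
A condensed sketch of this save/restore pattern (a toy helper I am calling `ViewCache`, mirroring the role of the new `FlatParameter._tensors` list; not the actual FSDP code):

```python
from typing import List, Optional

import torch


class ViewCache:
    """Toy stand-in for the role of ``FlatParameter._tensors``."""

    def __init__(self, num_params: int) -> None:
        self._tensors: List[Optional[torch.Tensor]] = [None] * num_params

    def save_in_forward(self, i: int, view: torch.Tensor) -> None:
        # Forward: remember the autograd-tracked `Tensor` view.
        self._tensors[i] = view

    def restore_in_pre_backward(self, i: int, unsharded_data: torch.Tensor) -> torch.Tensor:
        # Pre-backward: reuse the *same* variable, refreshing only its data,
        # so the autograd identity from the forward is preserved.
        tensor = self._tensors[i]
        assert tensor is not None, "Expects `Tensor` to have been saved in forward"
        tensor.data = unsharded_data
        return tensor

    def clear_in_post_backward(self) -> None:
        # Post-backward: drop the references (not strictly necessary).
        for i in range(len(self._tensors)):
            self._tensors[i] = None
```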

An alternative approach I considered was using forward hooks. However, this does not change the order of operations across FSDP, checkpoint, and the wrapped module, so it does not work: as long as the order is FSDP(checkpoint(module)), registered hooks still run either before or after the checkpoint recomputation -- we cannot insert logic that runs inside the checkpoint recomputation.

**Test Plan**
I augmented the existing reentrant checkpointing unit tests to also test `use_orig_params=True`. I also verified that the pycls model does not error (even with the new approach).
Pull Request resolved: pytorch#87413
Approved by: https://github.com/rohan-varma
awgu authored and kulinseth committed Nov 5, 2022
1 parent 893f8a1 commit 59cb5ce
Showing 2 changed files with 71 additions and 16 deletions.
test/distributed/fsdp/test_fsdp_checkpoint.py (23 additions & 9 deletions)
@@ -111,28 +111,35 @@ def _verify_parity(self, losses, outputs, models):
         [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
     )
     @parametrize("offload_activations", [True, False])
-    def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
+    @parametrize("use_orig_params", [False, True])
+    def test_checkpoint_fsdp_wrapping(
+        self,
+        cpu_offload: CPUOffload,
+        offload_activations: bool,
+        use_orig_params: bool,
+    ):
         # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
         if offload_activations:
             wrapper_to_use = offload_wrapper
         else:
             wrapper_to_use = checkpoint_wrapper
 
+        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
         ckpt_sequential_wrapped_fsdp = wrapper_to_use(
             TestFSDPCheckpoint.SequentialModule(
-                wrap_fsdp=True, cpu_offload=cpu_offload
+                wrap_fsdp=True, **fsdp_kwargs,
             ),
         )
         # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
         inner_ckpt = TestFSDPCheckpoint.SequentialModule(
             checkpoint_layer=True,
             offload_activations=offload_activations,
             wrap_fsdp=True,
-            cpu_offload=cpu_offload,
+            **fsdp_kwargs,
         )
 
         baseline = TestFSDPCheckpoint.SequentialModule(
-            wrap_fsdp=True, cpu_offload=cpu_offload
+            wrap_fsdp=True, **fsdp_kwargs,
         )
 
         # note that reentrant-based checkpointing requires inputs to have grad
@@ -168,28 +175,35 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
         [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
     )
     @parametrize("offload_activations", [True, False])
-    def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
+    @parametrize("use_orig_params", [False, True])
+    def test_basic_checkpoint_end_to_end(
+        self,
+        cpu_offload: CPUOffload,
+        offload_activations: bool,
+        use_orig_params: bool,
+    ):
+        fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
         global _save_on_cpu_called
         with patch_save_on_cpu(get_patched_save_on_cpu()):
             seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
             # Runs FSDP with no checkpointing
-            fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
+            fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs)
             # Runs checkpoint-wrapped FSDP
             if offload_activations:
                 wrapper_to_use = offload_wrapper
             else:
                 wrapper_to_use = checkpoint_wrapper
 
             checkpointed_fsdp = wrapper_to_use(
-                FSDP(deepcopy(seq), cpu_offload=cpu_offload),
+                FSDP(deepcopy(seq), **fsdp_kwargs),
             )
             # Runs FSDP-wrapped checkpointed module
             fsdp_wrapped_checkpoint = FSDP(
                 wrapper_to_use(deepcopy(seq)),
-                cpu_offload=cpu_offload,
+                **fsdp_kwargs,
             )
             # Runs FSDP with manual calls to checkpoint.
-            fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
+            fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs)
             # note that reentrant-based checkpointing requires inputs to have grad
             # flag set.
 
torch/distributed/fsdp/flat_param.py (48 additions & 7 deletions)
@@ -212,6 +212,13 @@ class FlatParameter(nn.Parameter):
         _shared_params (Optional[List[nn.Parameter]]): The original shared
             parameter variables if ``use_orig_params=True`` and ``None``
             otherwise.
+        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
+            views created in the forward and tracked by autograd when
+            ``use_orig_params=True`` and is ``None`` otherwise. This is to
+            preserve those ``Tensor`` variables for the backward to ensure
+            that the ``FlatParameter`` 's ``AccumulateGrad`` object does not
+            change; otherwise, the post-backward hook does not run. This is
+            relevant for cases like reentrant activation checkpointing.
         _is_grad_none (Optional[List[bool]]): A mask over the original
             parameters' gradients indicating if it is logically ``None`` or not
             if ``use_orig_params=True`` and ``None`` otherwise. This is needed
@@ -273,10 +280,14 @@ def _init_metadata(
             self._is_grad_none: Optional[List[bool]] = [
                 False for _ in range(len(params))
             ]
+            self._tensors: Optional[List[Optional[Tensor]]] = [
+                None for _ in range(len(self._params))
+            ]
         else:
             self._params = None
             self._shared_params = None
             self._is_grad_none = None
+            self._tensors = None
         self._unpadded_unsharded_size = self.size()
         _set_fsdp_flattened(self)
 
@@ -835,11 +846,15 @@ def _use_unsharded_flat_param(
             unsharded_size
         )  # this `.view()` is not autograd visible
         in_forward = self._training_state == HandleTrainingState.FORWARD
+        in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
         if self._use_orig_params:
-            # NOTE: When not in the forward, `as_params=True` suffices since we
-            # only need to restore the tensor *values* for backward computation
-            # and do not need fresh `Tensor` views.
-            self._use_unsharded_views(as_params=(not in_forward))
+            # We use `Tensor` views in the forward so that they are tracked by
+            # autograd. We use them in the pre-backward as well to support
+            # reentrant activation checkpointing, which needs the views to be
+            # tracked by autograd in the backward pass's recomputed forward.
+            self._use_unsharded_views(
+                as_params=(not in_forward and not in_pre_backward)
+            )
         elif in_forward:
             self._use_unsharded_views(as_params=False)
 
@@ -903,7 +918,9 @@ def unshard_grad(self):
             self._check_sharded(flat_param.grad)
             flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
             sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
-        dist.all_gather_into_tensor(padded_unsharded_grad, sharded_grad, self.process_group)
+        dist.all_gather_into_tensor(
+            padded_unsharded_grad, sharded_grad, self.process_group
+        )
         unsharded_size = self.flat_param._unpadded_unsharded_size
         flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
             unsharded_size
@@ -1198,8 +1215,27 @@ def _use_unsharded_views(self, as_params: bool) -> None:
                 param.data = view
             elif as_params:
                 module.register_parameter(param_name, nn.Parameter(view))
-            else:
-                setattr(module, param_name, view)
+            else:  # `as_params=False`
+                param_var: Tensor = view
+                if self._use_orig_params:
+                    if self._training_state == HandleTrainingState.FORWARD:
+                        assert self.flat_param._tensors is not None  # mypy
+                        # Save the `Tensor` view for the pre-backward
+                        self.flat_param._tensors[i] = view
+                    elif self._training_state == HandleTrainingState.BACKWARD_PRE:
+                        # Use the saved `Tensor` variable from the forward to
+                        # preserve the autograd graph so that the post-backward
+                        # hook fires (e.g. for reentrant AC)
+                        assert self.flat_param._tensors is not None  # mypy
+                        tensor = self.flat_param._tensors[i]
+                        p_assert(
+                            tensor is not None,
+                            "Expects `Tensor` to have been saved in forward",
+                        )
+                        tensor.data = view  # type: ignore[union-attr]
+                        assert tensor is not None  # mypy
+                        param_var = tensor
+                setattr(module, param_name, param_var)
         for i, (
             param_name,
             module,
@@ -1341,6 +1377,11 @@ def _use_sharded_views(self) -> None:
             setattr(module, param_name, param)
             prim_param = getattr(prim_module, prim_param_name)
             param.data = prim_param  # could be both empty and non-empty
+        if self._training_state == HandleTrainingState.BACKWARD_POST:
+            assert self.flat_param._tensors is not None  # mypy
+            # Clear the saved `Tensor`s since they are unneeded now
+            for i in range(len(self.flat_param._tensors)):
+                self.flat_param._tensors[i] = None  # type: ignore[index]
 
     @torch.no_grad()
     def _use_sharded_grad_views(self) -> None: