[v1.12.0] Fix non-reentrant hooks based checkpointing (#79490)
* merge fix

* Test fix

* Lint
rohan-varma committed Jun 17, 2022
1 parent 92437c6 commit 681a6e3
Showing 4 changed files with 251 additions and 15 deletions.
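Background for the diff below: "non-reentrant" checkpointing is the saved-tensor-hooks implementation of activation checkpointing selected by passing use_reentrant=False to torch.utils.checkpoint.checkpoint, as opposed to the original autograd.Function ("reentrant") implementation. A minimal usage sketch, not part of this commit (module and tensor shapes are illustrative only):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Illustrative module; any nn.Module whose activations should be recomputed works.
layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
x = torch.randn(32, 256, requires_grad=True)

# Original, reentrant implementation.
out_reentrant = checkpoint(layer, x, use_reentrant=True)

# Hooks-based, non-reentrant implementation that this commit fixes.
out_no_reentrant = checkpoint(layer, x, use_reentrant=False)
out_no_reentrant.sum().backward()
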
111 changes: 111 additions & 0 deletions test/distributed/fsdp/test_checkpoint_wrapper.py
@@ -0,0 +1,111 @@
# Owner(s): ["oncall: distributed"]

from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl
)

from torch.utils.checkpoint import checkpoint

from torch.testing._internal.common_utils import (
run_tests,
TestCase,
)

import unittest

class CheckpointWrapperTest(TestCase):
def setUp(self):
super().setUp()

def test_load_activation_checkpointed_module(self):
lin = nn.Linear(10, 10, bias=False)
lin = checkpoint_wrapper(lin)
state_dict = deepcopy(lin.state_dict())
# Load into non-checkpoint wrapped linear module
lin_new = nn.Linear(10, 10, bias=False)
lin_new.load_state_dict(state_dict)
for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
self.assertEqual(p1, p2)
self.assertTrue(torch.allclose(p1, p2))

# Load non-checkpoint wrapped module into checkpoint wrapped one
# Make params different
for p in lin_new.parameters():
with torch.no_grad():
p.add_(0.5)

state_dict = deepcopy(lin_new.state_dict())
# Verify checkpoint wrapped linear can load unwrapped linear
lin.load_state_dict(state_dict)
for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
self.assertEqual(p1, p2)

@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
def test_checkpoint_wrapper_parity(self):
"""
Tests that using checkpoint_wrapper or the functional
torch.utils.checkpoint (with the same reentrant config)
        results in the same maximum memory usage, i.e. the two are
        equivalent in terms of memory usage.
"""
class Model(nn.Module):
def __init__(
self,
n: int,
use_cp: bool,
use_wrapper: bool = False,
use_reentrant: bool = True
):
super().__init__()
self.layers = nn.ModuleList()
self.n = n
self.use_cp = use_cp
self.use_wrapper = use_wrapper
self.use_reentrant = use_reentrant
wrp = partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.REENTRANT if use_reentrant else CheckpointImpl.NO_REENTRANT
)
for i in range(self.n):
l = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256))
use_checkpoint_wrapper = self.use_wrapper
if use_checkpoint_wrapper:
l = wrp(l)
self.layers.append(l)

def forward(self, x):
for i in range(self.n):
if (
self.use_wrapper or
not self.use_cp
):
x = self.layers[i](x)
else:
x = checkpoint(self.layers[i], x, use_reentrant=self.use_reentrant)
return x

def test(use_checkpointing, use_wrapper, use_reentrant):
a = Model(8, use_checkpointing, use_wrapper=use_wrapper, use_reentrant=use_reentrant).cuda()
x = torch.randn(10000, 256, requires_grad=True).cuda()
torch.cuda.reset_peak_memory_stats()
loss = a(x).sum()
loss.backward()
return torch.cuda.max_memory_allocated()

functional_no_reentrant = test(use_checkpointing=True, use_wrapper=False, use_reentrant=False)
wrapper_no_reentrant = test(use_checkpointing=False, use_wrapper=True, use_reentrant=False)
self.assertEqual(functional_no_reentrant, wrapper_no_reentrant)

functional_reentrant = test(use_checkpointing=True, use_wrapper=False, use_reentrant=True)
wrapper_reentrant = test(use_checkpointing=False, use_wrapper=True, use_reentrant=True)
        self.assertEqual(functional_reentrant, wrapper_reentrant)


if __name__ == "__main__":
run_tests()
126 changes: 126 additions & 0 deletions test/test_autograd.py
@@ -4434,6 +4434,132 @@ def test_checkpointing(self):
mean_combined = torch.stack(feat_combined).mean()
mean_combined.backward()

@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
@slowTest
def test_checkpointing_without_reentrant_memory_savings(self):
class MyModel(nn.Module):
def __init__(self, n, use_checkpoint, use_reentrant):
super().__init__()
self.n = n
self.use_checkpoint = use_checkpoint
self.use_reentrant = use_reentrant
self.layers = nn.ModuleList()
for i in range(self.n):
layer = nn.Sequential(
nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
)
self.layers.append(layer)
# pre-allocate the grad so that increased memory usage is mainly
# due to activations.
for layer in self.layers:
for lin in layer:
lin.weight.grad = torch.ones_like(lin.weight)
lin.bias.grad = torch.ones_like(lin.bias)

def forward(self, x):
for i in range(self.n):
if not self.use_checkpoint:
x = self.layers[i](x)
else:
x = checkpoint(self.layers[i], x, use_reentrant=self.use_reentrant)

return x

model_no_checkpoint = MyModel(8, use_checkpoint=False, use_reentrant=False).cuda()
model_reentrant_checkpoint = MyModel(8, use_checkpoint=True, use_reentrant=True).cuda()
model_no_reentrant_checkpoint = MyModel(8, use_checkpoint=True, use_reentrant=False).cuda()

x = torch.randn(100, 256, requires_grad=True, device='cuda')

torch.cuda.reset_peak_memory_stats()
loss = model_no_checkpoint(x.clone()).sum()
loss.backward()
mem_no_checkpoint = torch.cuda.max_memory_allocated()

torch.cuda.reset_peak_memory_stats()
loss = model_reentrant_checkpoint(x.clone()).sum()
loss.backward()
mem_reentrant_checkpoint = torch.cuda.max_memory_allocated()

torch.cuda.reset_peak_memory_stats()
loss = model_no_reentrant_checkpoint(x.clone()).sum()
loss.backward()
mem_no_reentrant_checkpoint = torch.cuda.max_memory_allocated()

self.assertTrue(mem_reentrant_checkpoint < mem_no_checkpoint)
self.assertTrue(mem_no_reentrant_checkpoint < mem_no_checkpoint)

def test_checkpointing_without_reentrant_custom_function_raises(self):
"""
Accessing ctx.saved_tensors multiple times in a custom function
backward pass with non-reentrant checkpoint currently throws due to
saved tensors not being recomputed in between the accesses.
"""
# For verifying first access to ctx.saved_tensors succeeded.

_first_saved_tensor_access_succeeded = False

class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y, z):
w = x * y * z
out = w + w
ctx.save_for_backward(x, y, z, w, out)
return out

@staticmethod
def backward(ctx, grad_out):
x, y, z, w, out = ctx.saved_tensors
nonlocal _first_saved_tensor_access_succeeded
_first_saved_tensor_access_succeeded = True
                # Second access to saved tensors raises under non-reentrant
                # checkpointing because they were not recomputed in between.
x_2, y_2, z_2, w_2, out_2 = ctx.saved_tensors

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(2., requires_grad=True)
z = torch.tensor(3., requires_grad=True)

def foo(x, y, z):
x = x * y * z
y = y * y * z
z = z * z
out = MyFunc.apply(x, y, z)
return out

out = checkpoint(foo, x, y, z, use_reentrant=False)
with self.assertRaisesRegex(
RuntimeError,
"Attempt to retrieve a tensor saved by autograd multiple times"
):
out.sum().backward()

self.assertTrue(_first_saved_tensor_access_succeeded)

def test_access_saved_tensor_twice_without_recomputation_raises(self):
"""
        If saved-tensor-hooks-based checkpointing is used and a saved tensor
        is accessed multiple times without recomputation being triggered in
        between, an error is raised indicating this.
"""
def foo(a):
b = a * a
c = a * b
d = torch.exp(a)
return d

a = torch.randn(5, requires_grad=True)
d = checkpoint(foo, a, use_reentrant=False)
# First access
d.grad_fn._saved_result
# Second access raises error
with self.assertRaisesRegex(
RuntimeError,
"Attempt to retrieve a tensor saved by autograd multiple times"
):
d.grad_fn._saved_result

@slowTest
@parametrize("input_requires_grad", [True, False])
def test_checkpointing_without_reentrant(self, input_requires_grad):
11 changes: 0 additions & 11 deletions torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
@@ -110,16 +110,5 @@ def checkpoint_wrapper(
         (nn.Module):
             Wrapped module
     """
-    # saved tensor hooks based-checkpoint wrapper is not yet supported.
-    if checkpoint_impl == CheckpointImpl.NO_REENTRANT:
-        raise ValueError(
-            "No support for non-reentrant based checkpoint implementation."
-        )
-
-    if offload_to_cpu and checkpoint_impl != CheckpointImpl.REENTRANT:
-        raise ValueError(
-            "No support for CPU offload activations and non-reentrant based "
-            "checkpoint implementation."
-        )
-
     return CheckpointWrapper(module, checkpoint_impl, offload_to_cpu)
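With the two ValueError guards above removed, checkpoint_wrapper should now accept the non-reentrant implementation (this is what the new test file above exercises). A minimal sketch, not part of the commit, with illustrative layer sizes:

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)

# Before this commit, CheckpointImpl.NO_REENTRANT was rejected with ValueError;
# now the wrapped module's forward runs under hooks-based checkpointing.
block = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
wrapped = checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
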
18 changes: 14 additions & 4 deletions torch/utils/checkpoint.py
@@ -1,6 +1,6 @@
 import torch
 import warnings
-from typing import Any, Iterable, List, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple


 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -332,7 +332,7 @@ def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args):
             had_cuda_in_fwd = True
             fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

-    storage: List[Union[torch.Tensor, None]] = []
+    storage: Dict[int, Optional[torch.Tensor]] = {}
     counter = 0

     def pack(x):
@@ -343,10 +343,13 @@ def pack(x):
         return counter - 1

     def unpack(x):
+        unpack_counter = 0
         if len(storage) == 0:

             def inner_pack(inner):
-                storage.append(inner)
+                nonlocal unpack_counter
+                storage[unpack_counter] = inner
+                unpack_counter += 1
                 return None

             def inner_unpack(packed):
@@ -367,7 +370,14 @@ def inner_unpack(packed):
             with torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                 _unused = function(*args)

-        return storage[x]
+        if x not in storage:
+            raise RuntimeError(
+                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
+                " recomputation being triggered in between, this is not currently supported. Please"
+                " open an issue with details on your use case so that we can prioritize adding this."
+            )
+
+        return storage.pop(x)

     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
         output = function(*args)
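The unpack change above keys saved tensors by a counter and pops each entry when backward retrieves it; the new RuntimeError fires when a key is requested a second time without a recomputation having refilled the storage dict. The following standalone sketch, not part of the commit, illustrates the underlying torch.autograd.graph.saved_tensors_hooks mechanism with the same counter-keyed, pop-on-retrieval bookkeeping:

import torch

storage = {}
counter = 0

def pack(tensor):
    # Called once for every tensor autograd saves during forward; keep the
    # tensor under an integer key and hand autograd only the key.
    global counter
    key = counter
    counter += 1
    storage[key] = tensor
    return key

def unpack(key):
    # Called during backward; pop so each saved tensor is retrieved exactly
    # once, mirroring the storage.pop(x) bookkeeping in the fix above.
    return storage.pop(key)

x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()
y.backward()
assert len(storage) == 0  # every packed tensor was consumed by backward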
