Update save_on_cpu and checkpointing to work with functorch wrapped tensors #89166

Closed
wants to merge 4 commits into from
Changes from 3 commits
5 changes: 5 additions & 0 deletions aten/src/ATen/functorch/TensorWrapper.h
@@ -71,6 +71,11 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
bool allow_tensor_metadata_change) const override;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

  // This is pretty unsafe: it replaces the wrapped value in place, with no
  // checks that the new value is compatible with the existing wrapper.
void _set_value(const Tensor& value) {
value_ = value;
}

private:
const char* tensorimpl_type_name() const override;
Tensor value_;
21 changes: 11 additions & 10 deletions functorch/_src/vmap.py
@@ -24,16 +24,17 @@


def doesnt_support_saved_tensors_hooks(f):
message = (
"functorch transforms don't yet support saved tensor hooks. "
"Please open an issue with your use case."
)

@functools.wraps(f)
def fn(*args, **kwargs):
with torch.autograd.graph.disable_saved_tensors_hooks(message):
return f(*args, **kwargs)
return fn
return f
# message = (
# "functorch transforms don't yet support saved tensor hooks. "
# "Please open an issue with your use case."
# )

# @functools.wraps(f)
# def fn(*args, **kwargs):
# with torch.autograd.graph.disable_saved_tensors_hooks(message):
# return f(*args, **kwargs)
# return fn
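
For context, a hedged sketch (not part of the diff) of the guard the old body installed: pushing saved-tensor hooks while disable_saved_tensors_hooks() is active raises with the given message, which is what previously blocked save_on_cpu under functorch transforms.

# Hedged sketch of the old behavior this change removes.
import torch

msg = "functorch transforms don't yet support saved tensor hooks."
with torch.autograd.graph.disable_saved_tensors_hooks(msg):
    try:
        with torch.autograd.graph.save_on_cpu():
            pass
    except RuntimeError as err:
        print(err)   # prints `msg`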


# Checks that all args-to-be-batched have the same batch dim size
81 changes: 81 additions & 0 deletions test/functorch/test_eager_transforms.py
@@ -37,6 +37,7 @@
)
from functorch._src.eager_transforms import enable_fwd_grad, _slice_argnums
from functorch.experimental import functionalize
from torch.utils.checkpoint import checkpoint

# NB: numpy is a testing dependency!
import numpy as np
@@ -3479,6 +3480,83 @@ def forward(self, x_1):
""")


class TestSavedTensorHooks(TestCase):
def test_save_on_cpu(self):
a = torch.ones(4, 2)

def fn(x):
return x.sin().exp().sin().sum()

def sum(fn):
def wrapped(x):
return fn(x).sum()
return wrapped

reference = grad(fn)(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(fn)(a)
self.assertEqual(reference, actual)

reference = grad(sum(grad(fn)))(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(sum(grad(fn)))(a)
self.assertEqual(reference, actual)

reference = grad(sum(vmap(grad(fn))))(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(sum(vmap(grad(fn))))(a)
self.assertEqual(reference, actual)

a = torch.rand(4, 2, requires_grad=True)
grad1 = grad(sum(vmap(grad(fn))))(a)
with torch.autograd.graph.save_on_cpu():
grad2 = grad(sum(vmap(grad(fn))))(a)
        self.assertEqual(grad1, grad2)
grad1.sum().backward()
reference = a.grad.clone()
a.grad = None
grad2.sum().backward()
self.assertEqual(reference, a.grad)

def test_nonreentrant_checkpoint(self):
a = torch.rand(4, 2, requires_grad=True)

def fn(x):
return x.sin().exp().sin().sum()

def sum(fn):
def wrapped(x):
return fn(x).sum()
return wrapped

def apply_checkpoint(fn):
def wrapped(*args, **kwargs):
return checkpoint(fn, *args, use_reentrant=False, **kwargs)
return wrapped

reference = grad(fn)(a)
        actual = grad(apply_checkpoint(fn))(a)
self.assertEqual(reference, actual)

reference = grad(sum(grad(fn)))(a)
        actual = grad(sum(grad(apply_checkpoint(fn))))(a)
self.assertEqual(reference, actual)

reference = grad(sum(vmap(grad(fn))))(a)
actual = grad(sum(vmap(grad(apply_checkpoint(fn)))))(a)
self.assertEqual(reference, actual)

# TODO: not supported yet
# a = torch.rand(4, 2, requires_grad=True)
# grad1 = grad(sum(vmap(grad(fn))))(a)
# grad2 = grad(sum(vmap(grad(apply_checkpoint(fn)))))(a)
# self.assertEqual(reference, actual)
# grad1.sum().backward()
# reference = a.grad.clone()
# a.grad = None
# grad2.sum().backward()
# self.assertEqual(reference, a.grad)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(
@@ -3529,6 +3607,9 @@ def forward(self, x_1):
instantiate_parametrized_tests(
TestMakeFunctional,
)
instantiate_parametrized_tests(
TestSavedTensorHooks,
)

if __name__ == '__main__':
run_tests()
7 changes: 4 additions & 3 deletions test/test_autograd.py
@@ -7160,9 +7160,10 @@ def test(get_input, cuda, pin_memory):
test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory)
# DoubleTensor
test(lambda: torch.randn(5, requires_grad=True, dtype=torch.double), cuda, pin_memory)
# Sparse tensor
x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
test(lambda: x, cuda, pin_memory)
# TODO(soulitzer): Fix _get_tid for sparse tensors
# # Sparse tensor
# x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
# test(lambda: x, cuda, pin_memory)
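
A hedged illustration (not part of the diff) of why the sparse case is skipped above: _get_tid() calls t.data_ptr() and indexes into the tensor, and a sparse COO tensor has no dense storage to back either operation.

# Hedged sketch: data_ptr() is expected to fail for a sparse COO tensor.
import torch

s = torch.sparse_coo_tensor(torch.tensor([[1, 1]]), torch.tensor([1., 1.]))
try:
    s.data_ptr()
except RuntimeError as err:
    print("data_ptr() raised as expected:", err)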

@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_graph_save_on_cpu_cuda(self):
188 changes: 168 additions & 20 deletions torch/autograd/graph.py
@@ -14,6 +14,31 @@
"allow_mutation_on_saved_tensors",
]


def _get_tid(t) -> Tuple[int, int, int, Any]:
    # Returns a unique identifier for a particular version of a tensor. This is
    # currently used in saved_tensor_hook context managers such as save_on_cpu()
    # and allow_mutation_on_saved_tensors().
    #
    # id() and t._version are not sufficient on their own because id() can get
    # reused, so we also compare t.data_ptr() and the first value in the tensor.
    # Technically this does not guarantee correctness either, but it makes
    # collisions considerably less likely.
    #
    # TODO: support tensors that don't have storage
    item = t.item() if t.ndim == 0 else t[(0,) * t.ndim].item()
    return (id(t), t.data_ptr(), t._version, item)

def _get_sid(t) -> Tuple[int, int]:
# Returns a tuple that uniquely identifies a tensor's storage
#
# NB: two tensors that share the same storage may have different
# sid if their storage offset is different
return (t.data_ptr(), t._version)

class _Handle():
pass
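
As a hedged illustration (not part of the diff, and assuming the helpers above are in scope): an in-place update bumps a tensor's _version, so the same Python object yields a new tid, while id() alone would not tell the two versions apart.

# Hedged sketch: _get_tid changes when a tensor is mutated in place.
import torch

t = torch.ones(3)
before = _get_tid(t)   # (id, data_ptr, _version, first item)
t.add_(1)              # in-place op bumps t._version and changes the stored values
after = _get_tid(t)
assert before[0] == after[0]   # same Python object
assert before != after         # but the version (and first item) differ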

class saved_tensors_hooks():
"""Context-manager that sets a pair of pack / unpack hooks for saved tensors.

@@ -84,6 +109,80 @@ def __enter__(self):
def __exit__(self, *args: Any):
torch._C._autograd._pop_saved_tensors_default_hooks()

def _is_grad_wrapper(wrapped):
    # There's probably a better way to check this
    current_level = torch._C._functorch.maybe_get_level(wrapped)
    if current_level == -1:
        return False
    try:
        unwrapped = torch._C._functorch._unwrap_for_grad(wrapped, current_level)
        assert wrapped is not unwrapped
        return True
    except Exception:
        return False

def _unwrap(tensor):
current_level = torch._C._functorch.maybe_get_level(tensor)
assert current_level >= 1

    # Assume the wrapper is either a grad wrapper or a batched wrapper
if _is_grad_wrapper(tensor):
unwrapped = torch._C._functorch._unwrap_for_grad(tensor, current_level)
assert tensor is not unwrapped
inner_level = torch._C._functorch.maybe_get_level(unwrapped)

# Clone and replace the inner tensor of the TensorWrapper with an empty tensor
# This is so that we can capture the autograd metadata without capturing the
# data as well.
# TODO: figure out a way to do this without actually cloning the data.
with torch.autograd.enable_grad():
# Why do we need to reenable grad?
captured_wrapper = tensor.clone()
new_inner = torch.empty_like(unwrapped)
# Undo lifting
new_inner = _functorch_unwrap_to_level_no_rewrap(new_inner, inner_level)
captured_wrapper = _functorch_unwrap_to_level_no_rewrap(captured_wrapper, current_level)
captured_wrapper.data = new_inner

def rewrap_fn(new_value):
nonlocal captured_wrapper
captured_wrapper.data = new_value
return captured_wrapper
else:
bdim = torch._C._functorch.maybe_get_bdim(tensor)
unwrapped = torch._C._functorch.get_unwrapped(tensor)

def rewrap_fn(new_value):
return torch._C._functorch._add_batch_dim(new_value, bdim, current_level)

return unwrapped, rewrap_fn
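
A hedged sketch of the round trip performed by the batched branch above; it only makes sense inside a vmap transform, where `t` is assumed to be a batched tensor, and it uses the same private functorch bindings the function itself relies on.

# Hedged sketch: peel one batching level off `t`, then rebuild it.
level = torch._C._functorch.maybe_get_level(t)
bdim = torch._C._functorch.maybe_get_bdim(t)
inner = torch._C._functorch.get_unwrapped(t)
rewrapped = torch._C._functorch._add_batch_dim(inner, bdim, level)
# `rewrapped` is batched at the same level and batch dim as the original `t`.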

def _functorch_unwrap_to_level_no_rewrap(tensor: torch.Tensor, target_level: int) -> torch.Tensor:
current_level = torch._C._functorch.maybe_get_level(tensor)
while current_level > target_level:
tensor = torch._C._functorch._unwrap_for_grad(tensor, current_level)
current_level = torch._C._functorch.maybe_get_level(tensor)
assert current_level == target_level, (current_level, target_level)
return tensor

# It might be better to do more things in cpp:
# https://github.com/pytorch/pytorch/pull/88976
def _functorch_unwrap_to_level(tensor: torch.Tensor, target_level: int) -> Tuple[torch.Tensor, list]:
assert target_level != 0, "level 0 is not supported, you should pass -1 instead"
current_level = torch._C._functorch.maybe_get_level(tensor)
assert current_level >= target_level, (current_level, target_level)
rewrap_fns = []
for _ in range(max(current_level, 0), max(target_level, 0), -1):
current_level = torch._C._functorch.maybe_get_level(tensor)
if current_level == target_level:
# Sometimes wrappers can skip levels
break
tensor, rewrap_fn = _unwrap(tensor)
rewrap_fns.append(rewrap_fn)

result_level = torch._C._functorch.maybe_get_level(tensor)
assert result_level == target_level, (result_level, target_level)
return tensor, rewrap_fns
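
A hedged sketch of how this pairs with the rewrap functions it returns, mirroring what unpack_from_cpu does further down; it assumes `tensor` is functorch-wrapped (level >= 1).

# Hedged sketch: unwrap all the way down, then rebuild the wrappers in reverse.
plain, rewrap_fns = _functorch_unwrap_to_level(tensor, -1)
restored = plain
for rewrap_fn in reversed(rewrap_fns):
    restored = rewrap_fn(restored)
# `restored` is wrapped back to the level `tensor` had originally.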

class save_on_cpu(saved_tensors_hooks):
"""Context-manager under which tensors saved by the forward pass will be
@@ -127,24 +226,82 @@

"""
def __init__(self, pin_memory=False):
        # We use weak references here to make sure that the only owning references
        # to any offloaded tensors have lifetimes tied to that of the autograd graph.
self.unwrapped_copies: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

def pack_to_cpu(tensor):
if not pin_memory:
return (tensor.device, tensor.cpu())
device = tensor.device

# NB: Always unwrap the outer layer, which is guaranteed to be a TensorWrapper
# or a vanilla tensor. Unwrapping if the level is -1 is a no-op (I think).
level = torch._C._functorch.maybe_get_level(tensor)
inner = torch._C._functorch._unwrap_for_grad(tensor, level)
# NB: level after unwrapping isn't always level - 1, so query again
level = torch._C._functorch.maybe_get_level(inner)

# Unwrap all the way, but remember how to restore the wrappers
# including their autograd metadata
unwrapped, rewrap_fns = _functorch_unwrap_to_level(inner, -1)
tid = _get_tid(unwrapped)

if tid in self.tid_to_weakhandle:
handle = self.tid_to_weakhandle[tid]
else:
if not pin_memory:
unwrapped_copy = unwrapped.cpu()
else:
unwrapped_copy = torch.empty(
inner.size(),
dtype=inner.dtype,
layout=inner.layout,
pin_memory=(torch.cuda.is_available() and not inner.is_sparse))
Review comment (Contributor Author): inner -> unwrapped

unwrapped_copy.copy_(inner)

handle = _Handle()
self.unwrapped_copies[handle] = unwrapped_copy
self.tid_to_weakhandle[tid] = handle

packed = torch.empty(
tensor.size(),
dtype=tensor.dtype,
layout=tensor.layout,
pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
packed.copy_(tensor)
return (tensor.device, packed)
return (device, handle, rewrap_fns, level)

def unpack_from_cpu(packed):
device, tensor = packed
return tensor.to(device, non_blocking=pin_memory)
device, handle, rewrap_fns, original_level = packed
ret = self.unwrapped_copies[handle]
for rewrap_fn in reversed(rewrap_fns):
ret = rewrap_fn(ret)

new_level = torch._C._functorch.maybe_get_level(ret)
assert new_level == original_level, (new_level, original_level)

ret = ret.to(device, non_blocking=pin_memory)
# Ideally we would be completely working with unwrapped tensors during backward,
# but for a grad transform the wrapping level is the same as that of forward.
# I think that should not be the case. (Can we fix this?)
#
# If saved tensor logic properly detached, we shouldn't have to unwrap and rewrap
# here at all. The rewrapping here is implicit due to lifting.
assert torch._C._functorch.maybe_get_level(ret) > original_level or original_level == -1, (
"lifting should happen unless we've left transforms entirely")
Review comment (@soulitzer, Nov 17, 2022): Not sure this is true actually, if we consider the case:

grad(
    forward
    grad(
        forward   (tensor packed here)
        backward  (or here)
    )
    backward      (but unpacked here)
)

Lifting shouldn't happen, but I haven't seen the assert triggered.


return ret

super().__init__(pack_to_cpu, unpack_from_cpu)

def __enter__(self):
torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
return self

def __exit__(self, *args: Any):
# We cannot clear here because that would be bc-breaking: people sometimes run
# forward using save_on_cpu, but exit the ctx before running backward.
# Note that this behavior is inconsistent with that of allow_mutation_on_saved ctx
# which requires that backward also be run inside the ctx.
#
        # self.unwrapped_copies.clear()
# self.tid_to_weakhandle.clear()
torch._C._autograd._pop_saved_tensors_default_hooks()
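
A hedged usage sketch of the behavior the comment above preserves: the pack hook fires at save time, so backward can run after the context has already exited.

# Hedged sketch: forward under save_on_cpu, backward after the ctx has exited.
import torch

a = torch.randn(5, requires_grad=True)
with torch.autograd.graph.save_on_cpu():
    y = (a * a).sum()   # saved tensors are offloaded to CPU here
y.backward()            # unpacking still works outside the context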

@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
@@ -289,15 +446,6 @@ def __setstate__(self, state):
# - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled = False

def _get_tid(t) -> Tuple[int, int, int]:
return (id(t), t.data_ptr(), t._version)

def _get_sid(t) -> Tuple[int, int]:
return (t.data_ptr(), t._version)

class _Handle():
pass

class _swap_with_cloned(saved_tensors_hooks):
def __init__(self, ctx):
def pack_hook(t):