Update save_on_cpu to dedupe tensors and work with functorch wrapped tensors

ghstack-source-id: b6fa3b4d555ac0c2d6e5a285532927403c4a0824
Pull Request resolved: #89166
soulitzer committed Nov 17, 2022
1 parent c9e955c commit 2153ccb
Showing 7 changed files with 319 additions and 46 deletions.
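For context, a minimal usage sketch of what this change enables: save_on_cpu composing with functorch's grad transform, mirroring the new tests added in test_eager_transforms.py below. The function fn is just an illustrative scalar-valued example.

import torch
from functorch import grad

def fn(x):
    # illustrative scalar-valued function
    return x.sin().exp().sin().sum()

a = torch.ones(4, 2)
reference = grad(fn)(a)
# Saved activations are offloaded to CPU during the forward pass; before this
# change, save_on_cpu did not handle functorch-wrapped tensors.
with torch.autograd.graph.save_on_cpu():
    actual = grad(fn)(a)
assert torch.allclose(reference, actual)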
5 changes: 5 additions & 0 deletions aten/src/ATen/functorch/TensorWrapper.h
@@ -71,6 +71,11 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
bool allow_tensor_metadata_change) const override;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

// This is pretty unsafe
void _set_value(const Tensor& value) {
value_ = value;
}

private:
const char* tensorimpl_type_name() const override;
Tensor value_;
21 changes: 11 additions & 10 deletions functorch/_src/vmap.py
@@ -24,16 +24,17 @@


def doesnt_support_saved_tensors_hooks(f):
message = (
"functorch transforms don't yet support saved tensor hooks. "
"Please open an issue with your use case."
)

@functools.wraps(f)
def fn(*args, **kwargs):
with torch.autograd.graph.disable_saved_tensors_hooks(message):
return f(*args, **kwargs)
return fn
return f
# message = (
# "functorch transforms don't yet support saved tensor hooks. "
# "Please open an issue with your use case."
# )

# @functools.wraps(f)
# def fn(*args, **kwargs):
# with torch.autograd.graph.disable_saved_tensors_hooks(message):
# return f(*args, **kwargs)
# return fn


# Checks that all args-to-be-batched have the same batch dim size
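For reference, the guard this decorator previously installed comes from torch.autograd.graph.disable_saved_tensors_hooks, which makes any attempt to register default saved-tensor hooks raise with the given message. A minimal sketch of that behavior (the decorator is now a pass-through because the transforms handle these hooks):

import torch

message = ("functorch transforms don't yet support saved tensor hooks. "
           "Please open an issue with your use case.")
with torch.autograd.graph.disable_saved_tensors_hooks(message):
    try:
        with torch.autograd.graph.save_on_cpu():  # tries to push default hooks
            pass
    except RuntimeError as e:
        print(e)  # prints the message above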
81 changes: 81 additions & 0 deletions test/functorch/test_eager_transforms.py
@@ -37,6 +37,7 @@
)
from functorch._src.eager_transforms import enable_fwd_grad, _slice_argnums
from functorch.experimental import functionalize
from torch.utils.checkpoint import checkpoint

# NB: numpy is a testing dependency!
import numpy as np
@@ -3479,6 +3480,83 @@ def forward(self, x_1):
""")


class TestSavedTensorHooks(TestCase):
def test_save_on_cpu(self):
a = torch.ones(4, 2)

def fn(x):
return x.sin().exp().sin().sum()

def sum(fn):
def wrapped(x):
return fn(x).sum()
return wrapped

reference = grad(fn)(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(fn)(a)
self.assertEqual(reference, actual)

reference = grad(sum(grad(fn)))(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(sum(grad(fn)))(a)
self.assertEqual(reference, actual)

reference = grad(sum(vmap(grad(fn))))(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(sum(vmap(grad(fn))))(a)
self.assertEqual(reference, actual)

a = torch.rand(4, 2, requires_grad=True)
grad1 = grad(sum(vmap(grad(fn))))(a)
with torch.autograd.graph.save_on_cpu():
grad2 = grad(sum(vmap(grad(fn))))(a)
self.assertEqual(grad1, grad2)
grad1.sum().backward()
reference = a.grad.clone()
a.grad = None
grad2.sum().backward()
self.assertEqual(reference, a.grad)

def test_nonreentrant_checkpoint(self):
a = torch.rand(4, 2, requires_grad=True)

def fn(x):
return x.sin().exp().sin().sum()

def sum(fn):
def wrapped(x):
return fn(x).sum()
return wrapped

def apply_checkpoint(fn):
def wrapped(*args, **kwargs):
return checkpoint(fn, *args, use_reentrant=False, **kwargs)
return wrapped

reference = grad(fn)(a)
actual = grad(apply_checkpoint(fn))(a)
self.assertEqual(reference, actual)

reference = grad(sum(grad(fn)))(a)
actual = grad(sum(grad(apply_checkpoint(fn))))(a)
self.assertEqual(reference, actual)

reference = grad(sum(vmap(grad(fn))))(a)
actual = grad(sum(vmap(grad(apply_checkpoint(fn)))))(a)
self.assertEqual(reference, actual)

# TODO: not supported yet
# a = torch.rand(4, 2, requires_grad=True)
# grad1 = grad(sum(vmap(grad(fn))))(a)
# grad2 = grad(sum(vmap(grad(apply_checkpoint(fn)))))(a)
# self.assertEqual(reference, actual)
# grad1.sum().backward()
# reference = a.grad.clone()
# a.grad = None
# grad2.sum().backward()
# self.assertEqual(reference, a.grad)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(
@@ -3529,6 +3607,9 @@ def forward(self, x_1):
instantiate_parametrized_tests(
TestMakeFunctional,
)
instantiate_parametrized_tests(
TestSavedTensorHooks,
)

if __name__ == '__main__':
run_tests()
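The non-reentrant checkpoint test above boils down to the following composition, sketched here outside the test harness (fn is the same toy function used in the test):

import torch
from functorch import grad
from torch.utils.checkpoint import checkpoint

def fn(x):
    return x.sin().exp().sin().sum()

a = torch.rand(4, 2, requires_grad=True)
expected = grad(fn)(a)
# use_reentrant=False routes saving through saved-tensor hooks, which now
# compose with functorch's grad transform.
actual = grad(lambda x: checkpoint(fn, x, use_reentrant=False))(a)
assert torch.allclose(expected, actual)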
7 changes: 4 additions & 3 deletions test/test_autograd.py
@@ -7160,9 +7160,10 @@ def test(get_input, cuda, pin_memory):
test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory)
# DoubleTensor
test(lambda: torch.randn(5, requires_grad=True, dtype=torch.double), cuda, pin_memory)
# Sparse tensor
x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
test(lambda: x, cuda, pin_memory)
# TODO(soulitzer): Fix _get_tid for sparse tensors
# # Sparse tensor
# x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
# test(lambda: x, cuda, pin_memory)

@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_graph_save_on_cpu_cuda(self):
188 changes: 168 additions & 20 deletions torch/autograd/graph.py
@@ -14,6 +14,31 @@
"allow_mutation_on_saved_tensors",
]


def _get_tid(t) -> Tuple[int, int, int, Any]:
# Returns a unique identifier for a particular version of a tensor. This is
# currently used in saved_tensor_hook context managers such as save_on_cpu()
# and allow_mutation_on_saved_tensors().
#
# id() and t._version are not sufficient because id() can sometimes get reused,
# so we also compare t.data_ptr() and the first value in the tensor.
# Technically this does not guarantee correctness either, but it makes things
# somewhat safer at least.
#
# TODO: support tensors that don't have storage
item = t.item() if t.ndim == 0 else t[(0,) * t.ndim].item()
return (id(t), t.data_ptr(), t._version, item)

def _get_sid(t) -> Tuple[int, int]:
# Returns a tuple that uniquely identifies a tensor's storage
#
# NB: two tensors that share the same storage may have different
# sid if their storage offset is different
return (t.data_ptr(), t._version)

class _Handle():
pass
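
# Illustrative note (not part of this commit): an in-place update gives a tensor
# a new tid even though id() and data_ptr() are unchanged, because t._version is
# bumped (and the sampled first element may change), e.g.:
#
#   t = torch.randn(3)
#   before = _get_tid(t)
#   t.add_(1.)                   # in-place op increments t._version
#   assert _get_tid(t) != before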

class saved_tensors_hooks():
"""Context-manager that sets a pair of pack / unpack hooks for saved tensors.
@@ -84,6 +109,80 @@ def __enter__(self):
def __exit__(self, *args: Any):
torch._C._autograd._pop_saved_tensors_default_hooks()

def _is_grad_wrapper(wrapped):
# There's probably a better way lol
current_level = torch._C._functorch.maybe_get_level(wrapped)
if current_level == -1:
return False
try:
unwrapped = torch._C._functorch._unwrap_for_grad(wrapped, current_level)
assert wrapped is not unwrapped
return True
except:
return False

def _unwrap(tensor):
current_level = torch._C._functorch.maybe_get_level(tensor)
assert current_level >= 1

# Assume the tensor is either a grad wrapper or a batched wrapper
if _is_grad_wrapper(tensor):
unwrapped = torch._C._functorch._unwrap_for_grad(tensor, current_level)
assert tensor is not unwrapped
inner_level = torch._C._functorch.maybe_get_level(unwrapped)

# Clone and replace the inner tensor of the TensorWrapper with an empty tensor
# This is so that we can capture the autograd metadata without capturing the
# data as well.
# TODO: figure out a way to do this without actually cloning the data.
with torch.autograd.enable_grad():
# Why do we need to reenable grad?
captured_wrapper = tensor.clone()
new_inner = torch.empty_like(unwrapped)
# Undo lifting
new_inner = _functorch_unwrap_to_level_no_rewrap(new_inner, inner_level)
captured_wrapper = _functorch_unwrap_to_level_no_rewrap(captured_wrapper, current_level)
captured_wrapper.data = new_inner

def rewrap_fn(new_value):
nonlocal captured_wrapper
captured_wrapper.data = new_value
return captured_wrapper
else:
bdim = torch._C._functorch.maybe_get_bdim(tensor)
unwrapped = torch._C._functorch.get_unwrapped(tensor)

def rewrap_fn(new_value):
return torch._C._functorch._add_batch_dim(new_value, bdim, current_level)

return unwrapped, rewrap_fn

def _functorch_unwrap_to_level_no_rewrap(tensor: torch.Tensor, target_level: int) -> torch.Tensor:
current_level = torch._C._functorch.maybe_get_level(tensor)
while current_level > target_level:
tensor = torch._C._functorch._unwrap_for_grad(tensor, current_level)
current_level = torch._C._functorch.maybe_get_level(tensor)
assert current_level == target_level, (current_level, target_level)
return tensor

# It might be better to do more things in cpp:
# https://github.com/pytorch/pytorch/pull/88976
def _functorch_unwrap_to_level(tensor: torch.Tensor, target_level: int) -> Tuple[torch.Tensor, list]:
assert target_level != 0, "level 0 is not supported, you should pass -1 instead"
current_level = torch._C._functorch.maybe_get_level(tensor)
assert current_level >= target_level, (current_level, target_level)
rewrap_fns = []
for _ in range(max(current_level, 0), max(target_level, 0), -1):
current_level = torch._C._functorch.maybe_get_level(tensor)
if current_level == target_level:
# Sometimes wrappers can skip levels
break
tensor, rewrap_fn = _unwrap(tensor)
rewrap_fns.append(rewrap_fn)

result_level = torch._C._functorch.maybe_get_level(tensor)
assert result_level == target_level, (result_level, target_level)
return tensor, rewrap_fns
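
# Round-trip sketch (illustrative, not part of this commit): unwrapping to level -1
# and re-applying the returned rewrap_fns in reverse restores the original wrapper
# nesting, which is exactly how unpack_from_cpu below uses them:
#
#   inner, rewrap_fns = _functorch_unwrap_to_level(wrapped, -1)
#   for rewrap_fn in reversed(rewrap_fns):
#       inner = rewrap_fn(inner)
#   # inner is now wrapped back to the level `wrapped` had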

class save_on_cpu(saved_tensors_hooks):
"""Context-manager under which tensors saved by the forward pass will be
@@ -127,24 +226,82 @@ class save_on_cpu(saved_tensors_hooks):
"""
def __init__(self, pin_memory=False):
# We use weak references here to make sure that the only owning references to
# the offloaded tensors have lifetimes linked to that of the autograd graph.
self.unwrapped_copies: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

def pack_to_cpu(tensor):
if not pin_memory:
return (tensor.device, tensor.cpu())
device = tensor.device

# NB: Always unwrap the outer layer, which is guaranteed to be a TensorWrapper
# or a vanilla tensor. Unwrapping if the level is -1 is a no-op (I think).
level = torch._C._functorch.maybe_get_level(tensor)
inner = torch._C._functorch._unwrap_for_grad(tensor, level)
# NB: level after unwrapping isn't always level - 1, so query again
level = torch._C._functorch.maybe_get_level(inner)

# Unwrap all the way, but remember how to restore the wrappers
# including their autograd metadata
unwrapped, rewrap_fns = _functorch_unwrap_to_level(inner, -1)
tid = _get_tid(unwrapped)

if tid in self.tid_to_weakhandle:
handle = self.tid_to_weakhandle[tid]
else:
if not pin_memory:
unwrapped_copy = unwrapped.cpu()
else:
unwrapped_copy = torch.empty(
inner.size(),
dtype=inner.dtype,
layout=inner.layout,
pin_memory=(torch.cuda.is_available() and not inner.is_sparse))
unwrapped_copy.copy_(inner)

handle = _Handle()
self.unwrapped_copies[handle] = unwrapped_copy
self.tid_to_weakhandle[tid] = handle

packed = torch.empty(
tensor.size(),
dtype=tensor.dtype,
layout=tensor.layout,
pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
packed.copy_(tensor)
return (tensor.device, packed)
return (device, handle, rewrap_fns, level)

def unpack_from_cpu(packed):
device, tensor = packed
return tensor.to(device, non_blocking=pin_memory)
device, handle, rewrap_fns, original_level = packed
ret = self.unwrapped_copies[handle]
for rewrap_fn in reversed(rewrap_fns):
ret = rewrap_fn(ret)

new_level = torch._C._functorch.maybe_get_level(ret)
assert new_level == original_level, (new_level, original_level)

ret = ret.to(device, non_blocking=pin_memory)
# Ideally we would be working entirely with unwrapped tensors during backward,
# but for a grad transform the wrapping level is the same as that of forward.
# I think that should not be the case. (Can we fix this?)
#
# If the saved-tensor logic properly detached, we wouldn't have to unwrap and
# rewrap here at all; the rewrapping here happens implicitly due to lifting.
assert torch._C._functorch.maybe_get_level(ret) > original_level or original_level == -1, (
"lifting should happen unless we've left transforms entirely")

return ret

super().__init__(pack_to_cpu, unpack_from_cpu)

def __enter__(self):
torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
return self

def __exit__(self, *args: Any):
# We cannot clear here because that would be bc-breaking: people sometimes run
# forward using save_on_cpu, but exit the ctx before running backward.
# Note that this behavior is inconsistent with that of allow_mutation_on_saved ctx
# which requires that backward also be run inside the ctx.
#
# self.unwrapped_copies.clear()
# self.tid_to_weakhandle.clear()
torch._C._autograd._pop_saved_tensors_default_hooks()

@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
@@ -289,15 +446,6 @@ def __setstate__(self, state):
# - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled = False

def _get_tid(t) -> Tuple[int, int, int]:
return (id(t), t.data_ptr(), t._version)

def _get_sid(t) -> Tuple[int, int]:
return (t.data_ptr(), t._version)

class _Handle():
pass

class _swap_with_cloned(saved_tensors_hooks):
def __init__(self, ctx):
def pack_hook(t):
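To summarize the dedup scheme save_on_cpu now uses, here is a stripped-down, self-contained sketch that ignores the functorch unwrapping and pinned memory: each distinct saved tensor (keyed by a simplified tid) is offloaded to CPU once and shared by every autograd node that saved it. The class name dedup_to_cpu and the _KeepAlive handle are hypothetical names for illustration only.

import weakref
import torch

class _KeepAlive:
    # hypothetical stand-in for _Handle: a trivial object used only for identity
    pass

class dedup_to_cpu(torch.autograd.graph.saved_tensors_hooks):
    # Simplified sketch of the dedup pattern used by save_on_cpu above.
    def __init__(self):
        # handle -> CPU copy; entries die with the last packed reference to the handle
        self.copies = weakref.WeakKeyDictionary()
        # simplified tid -> handle; entries die together with the handle
        self.tid_to_handle = weakref.WeakValueDictionary()

        def pack(t):
            tid = (id(t), t.data_ptr(), t._version)  # simplified key
            handle = self.tid_to_handle.get(tid)
            if handle is None:
                handle = _KeepAlive()
                self.copies[handle] = t.cpu()
                self.tid_to_handle[tid] = handle
            return (t.device, handle)

        def unpack(packed):
            device, handle = packed
            return self.copies[handle].to(device)

        super().__init__(pack, unpack)

# Usage: every distinct tensor saved during forward is copied to CPU only once.
#   with dedup_to_cpu():
#       out = model(inp)
#   out.sum().backward()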
