Update save_on_cpu to dedupe tensors and work with functorch wrapped tensors

ghstack-source-id: d20523b6ba4f05d236621450d38e2686d10c24e4
Pull Request resolved: #89166
soulitzer committed Nov 16, 2022
1 parent c9e955c commit 8586ea2
Showing 5 changed files with 226 additions and 30 deletions.
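
For orientation (this sketch is not part of the commit's diff): the change teaches save_on_cpu to unwrap/rewrap functorch tensors and to dedupe saved tensors before offloading. A minimal usage sketch mirroring the new test added below, assuming the functorch package is available and exposes grad as the test suite uses it:

import torch
from functorch import grad

def fn(x):
    return x.sin().exp().sin().sum()

a = torch.ones(4, 2)
reference = grad(fn)(a)
# With this change, pack/unpack handles functorch's TensorWrapper, so the
# gradients computed while offloading saved tensors to CPU should match.
with torch.autograd.graph.save_on_cpu():
    actual = grad(fn)(a)
assert torch.allclose(reference, actual)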
5 changes: 5 additions & 0 deletions aten/src/ATen/functorch/TensorWrapper.h
@@ -71,6 +71,11 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
      bool allow_tensor_metadata_change) const override;
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

  // This is pretty unsafe
  void _set_value(const Tensor& value) {
    value_ = value;
  }

 private:
  const char* tensorimpl_type_name() const override;
  Tensor value_;
21 changes: 11 additions & 10 deletions functorch/_src/vmap.py
@@ -24,16 +24,17 @@


def doesnt_support_saved_tensors_hooks(f):
    message = (
        "functorch transforms don't yet support saved tensor hooks. "
        "Please open an issue with your use case."
    )

    @functools.wraps(f)
    def fn(*args, **kwargs):
        with torch.autograd.graph.disable_saved_tensors_hooks(message):
            return f(*args, **kwargs)
    return fn
    return f
    # message = (
    #     "functorch transforms don't yet support saved tensor hooks. "
    #     "Please open an issue with your use case."
    # )

    # @functools.wraps(f)
    # def fn(*args, **kwargs):
    #     with torch.autograd.graph.disable_saved_tensors_hooks(message):
    #         return f(*args, **kwargs)
    # return fn


# Checks that all args-to-be-batched have the same batch dim size
41 changes: 41 additions & 0 deletions test/functorch/test_eager_transforms.py
@@ -3479,6 +3479,44 @@ def forward(self, x_1):
""")


class TestSavedTensorHooks(TestCase):
    def test_save_on_cpu(self):
        a = torch.ones(4, 2)

        def fn(x):
            return x.sin().exp().sin().sum()

        def sum(fn):
            def wrapped(x):
                return fn(x).sum()
            return wrapped

        reference = grad(fn)(a)
        with torch.autograd.graph.save_on_cpu():
            actual = grad(fn)(a)
        self.assertEqual(reference, actual)

        reference = grad(sum(grad(fn)))(a)
        with torch.autograd.graph.save_on_cpu():
            actual = grad(sum(grad(fn)))(a)
        self.assertEqual(reference, actual)

        reference = grad(sum(vmap(grad(fn))))(a)
        with torch.autograd.graph.save_on_cpu():
            actual = grad(sum(vmap(grad(fn))))(a)
        self.assertEqual(reference, actual)

        a = torch.tensor(1., requires_grad=True)
        grad1 = grad(sum(fn))(a)
        with torch.autograd.graph.save_on_cpu():
            grad2 = grad(sum(fn))(a)
        self.assertEqual(reference, actual)
        grad1.backward()
        reference = a.grad.clone()
        a.grad = None
        grad2.backward()
        self.assertEqual(reference, a.grad)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(
@@ -3529,6 +3567,9 @@ def forward(self, x_1):
instantiate_parametrized_tests(
    TestMakeFunctional,
)
instantiate_parametrized_tests(
    TestSavedTensorHooks,
)

if __name__ == '__main__':
    run_tests()
180 changes: 160 additions & 20 deletions torch/autograd/graph.py
@@ -14,6 +14,40 @@
    "allow_mutation_on_saved_tensors",
]


def _get_tid(t) -> Tuple[int, int]:
    # Returns a unique identifier for a particular version of a tensor. This is
    # currently used in saved_tensor_hook context managers such as save_on_cpu()
    # and allow_mutation_on_saved_tensors().
    #
    # We claim that id() and t._version are necessary and sufficient to uniquely
    # identify a version of the tensor:
    # - id corresponds to the TensorImpl because we have pyobject persistence.
    #   On the other hand, keeping track of storage via something like data_ptr
    #   is not sufficient because there can be different views into the same storage.
    # - we don't need t.data_ptr() because the only way it can change from
    #   underneath a TensorImpl is 1) if someone uses .data (we're okay with
    #   silently wrong results being produced if .data is used) OR 2) if the version
    #   counter also changes, e.g. when we perform an in-place view.
    #   We choose to omit data_ptr from the identifier so as to support tensors that
    #   don't have storage, e.g. sparse tensors. It may be better to have
    #   special handling for sparse tensors, but that would also need to consider
    #   additional things like coalescing.
    #
    # TODO(soulitzer): check sparse correctness; make sure that if the contents of a
    # sparse tensor change, that is also reflected here.
    return (id(t), t._version)

def _get_sid(t) -> Tuple[int, int]:
    # Returns a tuple that uniquely identifies a tensor's storage
    #
    # NB: two tensors that share the same storage may have different
    # sids if their storage offsets are different
    return (t.data_ptr(), t._version)

class _Handle():
    pass

class saved_tensors_hooks():
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.
@@ -84,6 +118,63 @@ def __enter__(self):
    def __exit__(self, *args: Any):
        torch._C._autograd._pop_saved_tensors_default_hooks()

def _is_grad_wrapper(wrapped):
    # There's probably a better way lol
    current_level = torch._C._functorch.maybe_get_level(wrapped)
    if current_level == -1:
        return False
    try:
        unwrapped = torch._C._functorch._unwrap_for_grad(wrapped, current_level)
        assert wrapped is not unwrapped
        return True
    except:
        return False

def _unwrap(tensor):
    current_level = torch._C._functorch.maybe_get_level(tensor)
    assert current_level >= 1

    # Assume it is either a grad wrapper or a batched wrapper
    if _is_grad_wrapper(tensor):
        unwrapped = torch._C._functorch._unwrap_for_grad(tensor, current_level)
        assert tensor is not unwrapped

        def rewrap_fn(new_value):
            # This prematurely lifts to the current interpreter level, unfortunately.
            # We prefer to rewrap ourselves so we can actually restore the original
            # autograd metadata, so we have to manually unwrap again after the clone :(
            level = torch._C._functorch.maybe_get_level(tensor)
            cloned_wrapped = tensor.clone()
            cloned_wrapped, _ = _functorch_unwrap_to_level(cloned_wrapped, level)
            cloned_wrapped.data = new_value
            return cloned_wrapped
    else:
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        unwrapped = torch._C._functorch.get_unwrapped(tensor)

        def rewrap_fn(new_value):
            return torch._C._functorch._add_batch_dim(new_value, bdim, current_level)

    return unwrapped, rewrap_fn

# It might be better to do more things in cpp:
# https://github.com/pytorch/pytorch/pull/88976
def _functorch_unwrap_to_level(tensor: torch.Tensor, target_level: int) -> Tuple[torch.Tensor, Any]:
    assert target_level != 0, "level 0 is not supported, you should pass -1 instead"
    current_level = torch._C._functorch.maybe_get_level(tensor)
    assert current_level >= target_level, (current_level, target_level)
    rewrap_fns = []
    for _ in range(max(current_level, 0), max(target_level, 0), -1):
        current_level = torch._C._functorch.maybe_get_level(tensor)
        if current_level == target_level:
            # Sometimes wrappers can skip levels
            break
        tensor, rewrap_fn = _unwrap(tensor)
        rewrap_fns.append(rewrap_fn)

    result_level = torch._C._functorch.maybe_get_level(tensor)
    assert result_level == target_level, (result_level, target_level)
    return tensor, rewrap_fns

class save_on_cpu(saved_tensors_hooks):
    """Context-manager under which tensors saved by the forward pass will be
@@ -127,24 +218,82 @@ class save_on_cpu(saved_tensors_hooks):
    """
    def __init__(self, pin_memory=False):
        # We use weak references here to make sure that the only owning references
        # to any tensors that are offloaded have lifetimes that are linked to that
        # of the graph
        self.unwrapped_copies: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            device = tensor.device

            # NB: Always unwrap the outer layer, which is guaranteed to be a TensorWrapper
            # or a vanilla tensor. Unwrapping if the level is -1 is a no-op (I think).
            level = torch._C._functorch.maybe_get_level(tensor)
            inner = torch._C._functorch._unwrap_for_grad(tensor, level)
            # NB: the level after unwrapping isn't always level - 1, so query again
            level = torch._C._functorch.maybe_get_level(inner)

            # Unwrap all the way, but remember how to restore the wrappers,
            # including their autograd metadata
            unwrapped, rewrap_fns = _functorch_unwrap_to_level(inner, -1)
            tid = _get_tid(unwrapped)

            if tid in self.tid_to_weakhandle:
                handle = self.tid_to_weakhandle[tid]
            else:
                if not pin_memory:
                    unwrapped_copy = unwrapped.cpu()
                else:
                    unwrapped_copy = torch.empty(
                        inner.size(),
                        dtype=inner.dtype,
                        layout=inner.layout,
                        pin_memory=(torch.cuda.is_available() and not inner.is_sparse))
                    unwrapped_copy.copy_(inner)

                handle = _Handle()
                self.unwrapped_copies[handle] = unwrapped_copy
                self.tid_to_weakhandle[tid] = handle

            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
            packed.copy_(tensor)
            return (tensor.device, packed)
            return (device, handle, rewrap_fns, level)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)
            device, handle, rewrap_fns, original_level = packed
            ret = self.unwrapped_copies[handle]
            for rewrap_fn in reversed(rewrap_fns):
                ret = rewrap_fn(ret)

            new_level = torch._C._functorch.maybe_get_level(ret)
            assert new_level == original_level, (new_level, original_level)

            ret = ret.to(device, non_blocking=pin_memory)
            # Ideally we would be working entirely with unwrapped tensors during backward,
            # but for a grad transform the wrapping level is the same as that of forward.
            # I think that should not be the case. (Can we fix this?)
            #
            # If the saved tensor logic properly detached, we shouldn't have to unwrap and
            # rewrap here at all. The rewrapping here is implicit due to lifting.
            assert torch._C._functorch.maybe_get_level(ret) > original_level or original_level == -1, (
                "lifting should happen unless we've left transforms entirely")

            return ret

        super().__init__(pack_to_cpu, unpack_from_cpu)

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
        return self

    def __exit__(self, *args: Any):
        # We cannot clear here because that would be bc-breaking: people sometimes run
        # forward using save_on_cpu, but exit the ctx before running backward.
        # Note that this behavior is inconsistent with that of the allow_mutation_on_saved
        # ctx, which requires that backward also be run inside the ctx.
        #
        # self.unwrapped_copies.clear()
        # self.tid_to_weakhandle.clear()
        torch._C._autograd._pop_saved_tensors_default_hooks()

@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
@@ -289,15 +438,6 @@ def __setstate__(self, state):
# - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled = False

def _get_tid(t) -> Tuple[int, int, int]:
    return (id(t), t.data_ptr(), t._version)

def _get_sid(t) -> Tuple[int, int]:
    return (t.data_ptr(), t._version)

class _Handle():
    pass

class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
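A side note on the dedupe bookkeeping in the pack_to_cpu hook above (illustrative only, not part of the commit): saved tensors are keyed by _get_tid, i.e. (id(t), t._version), and the offloaded CPU copy is owned only through a weak handle, so packing the same tensor version twice offloads it once. A stripped-down, self-contained sketch of that pattern; the names _Handle and _DedupeOffloader here are hypothetical and used only for illustration:

import weakref
import torch

class _Handle:
    pass

class _DedupeOffloader:
    def __init__(self):
        # handle -> CPU copy; the copy stays alive only while some packed
        # result still holds the handle (mirrors unwrapped_copies above)
        self.copies = weakref.WeakKeyDictionary()
        # (id(t), t._version) -> handle (mirrors tid_to_weakhandle above)
        self.tid_to_handle = weakref.WeakValueDictionary()

    def pack(self, t):
        tid = (id(t), t._version)
        handle = self.tid_to_handle.get(tid)
        if handle is None:
            handle = _Handle()
            self.copies[handle] = t.cpu()
            self.tid_to_handle[tid] = handle
        return (t.device, handle)

    def unpack(self, packed):
        device, handle = packed
        return self.copies[handle].to(device)

offloader = _DedupeOffloader()
x = torch.randn(3)
p1 = offloader.pack(x)
p2 = offloader.pack(x)
assert p1[1] is p2[1]              # same tensor version -> same handle
assert len(offloader.copies) == 1  # only one offloaded copy exists
assert torch.equal(offloader.unpack(p1), x)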
9 changes: 9 additions & 0 deletions torch/csrc/autograd/variable.cpp
@@ -9,6 +9,7 @@
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <ATen/functorch/TensorWrapper.h>

#include <ATen/core/VariableHooksInterface.h>

@@ -479,6 +480,14 @@ void VariableHooks::set_data(
  at::OptionalTensorRef new_data_ref(new_data_base);
  const Tensor& new_data = *new_data_ref;

  // Store a dummy tensor in the meantime to avoid saving the entire original tensor
  if (self.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::FuncTorchGradWrapper)) {
    // If t is a TensorWrapper, setting t.data replaces the tensor it wraps
    at::functorch::TensorWrapper* impl =
        reinterpret_cast<at::functorch::TensorWrapper*>(self.unsafeGetTensorImpl());
    impl->_set_value(new_data);
    return;
  }

  // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields
  // from `new_data` to `var`. It requires that `new_data` and `var` have
  // compatible tensor type.
