Update save_on_cpu and checkpointing to work with functorch wrapped tensors #89166

Closed · wants to merge 4 commits
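Below is a minimal sketch (not part of the diff) of the combination this change is meant to support: functorch's grad composed with torch.autograd.graph.save_on_cpu and non-reentrant activation checkpointing. It mirrors the tests added in test_eager_transforms.py below; the function and variable names are illustrative only.

import torch
from functorch import grad
from torch.utils.checkpoint import checkpoint

def fn(x):
    return x.sin().exp().sin().sum()

a = torch.ones(4, 2)

# Saved-tensor hooks (here: offloading saved activations to CPU) apply
# while differentiating through functorch transforms.
with torch.autograd.graph.save_on_cpu():
    g = grad(fn)(a)                            # first-order grad under the hook
    gg = grad(lambda x: grad(fn)(x).sum())(a)  # grad-of-grad under the hook

# Non-reentrant checkpointing composes with functorch grad as well.
g_ckpt = grad(lambda x: checkpoint(fn, x, use_reentrant=False))(a)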
5 changes: 5 additions & 0 deletions aten/src/ATen/functorch/TensorWrapper.h
@@ -71,6 +71,11 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
bool allow_tensor_metadata_change) const override;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

// This is pretty unsafe: it swaps out the wrapped value in place without any checks.
void _set_value(const Tensor& value) {
value_ = value;
}

private:
const char* tensorimpl_type_name() const override;
Tensor value_;
21 changes: 11 additions & 10 deletions functorch/_src/vmap.py
@@ -24,16 +24,17 @@


def doesnt_support_saved_tensors_hooks(f):
message = (
"functorch transforms don't yet support saved tensor hooks. "
"Please open an issue with your use case."
)

@functools.wraps(f)
def fn(*args, **kwargs):
with torch.autograd.graph.disable_saved_tensors_hooks(message):
return f(*args, **kwargs)
return fn
return f
# message = (
# "functorch transforms don't yet support saved tensor hooks. "
# "Please open an issue with your use case."
# )

# @functools.wraps(f)
# def fn(*args, **kwargs):
# with torch.autograd.graph.disable_saved_tensors_hooks(message):
# return f(*args, **kwargs)
# return fn


# Checks that all args-to-be-batched have the same batch dim size
81 changes: 81 additions & 0 deletions test/functorch/test_eager_transforms.py
@@ -37,6 +37,7 @@
)
from functorch._src.eager_transforms import enable_fwd_grad, _slice_argnums
from functorch.experimental import functionalize
from torch.utils.checkpoint import checkpoint

# NB: numpy is a testing dependency!
import numpy as np
@@ -3479,6 +3480,83 @@ def forward(self, x_1):
""")


class TestSavedTensorHooks(TestCase):
def test_save_on_cpu(self):
a = torch.ones(4, 2)

def fn(x):
return x.sin().exp().sin().sum()

def sum(fn):
def wrapped(x):
return fn(x).sum()
return wrapped

reference = grad(fn)(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(fn)(a)
self.assertEqual(reference, actual)

reference = grad(sum(grad(fn)))(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(sum(grad(fn)))(a)
self.assertEqual(reference, actual)

reference = grad(sum(vmap(grad(fn))))(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(sum(vmap(grad(fn))))(a)
self.assertEqual(reference, actual)

a = torch.rand(4, 2, requires_grad=True)
grad1 = grad(sum(vmap(grad(fn))))(a)
with torch.autograd.graph.save_on_cpu():
grad2 = grad(sum(vmap(grad(fn))))(a)
self.assertEqual(grad1, grad2)
grad1.sum().backward()
reference = a.grad.clone()
a.grad = None
grad2.sum().backward()
self.assertEqual(reference, a.grad)

def test_nonreentrant_checkpoint(self):
a = torch.rand(4, 2, requires_grad=True)

def fn(x):
return x.sin().exp().sin().sum()

def sum(fn):
def wrapped(x):
return fn(x).sum()
return wrapped

def apply_checkpoint(fn):
def wrapped(*args, **kwargs):
return checkpoint(fn, *args, use_reentrant=False, **kwargs)
return wrapped

reference = grad(fn)(a)
actual = grad(apply_checkpoint(fn))(a)
self.assertEqual(reference, actual)

reference = grad(sum(grad(fn)))(a)
actual = grad(sum(grad(apply_checkpoint(fn))))(a)
self.assertEqual(reference, actual)

reference = grad(sum(vmap(grad(fn))))(a)
actual = grad(sum(vmap(grad(apply_checkpoint(fn)))))(a)
self.assertEqual(reference, actual)

# TODO: not supported yet
# a = torch.rand(4, 2, requires_grad=True)
# grad1 = grad(sum(vmap(grad(fn))))(a)
# grad2 = grad(sum(vmap(grad(apply_checkpoint(fn)))))(a)
# self.assertEqual(reference, actual)
# grad1.sum().backward()
# reference = a.grad.clone()
# a.grad = None
# grad2.sum().backward()
# self.assertEqual(reference, a.grad)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(
@@ -3529,6 +3607,9 @@ def forward(self, x_1):
instantiate_parametrized_tests(
TestMakeFunctional,
)
instantiate_parametrized_tests(
TestSavedTensorHooks,
)

if __name__ == '__main__':
run_tests()
7 changes: 4 additions & 3 deletions test/test_autograd.py
@@ -7160,9 +7160,10 @@ def test(get_input, cuda, pin_memory):
test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory)
# DoubleTensor
test(lambda: torch.randn(5, requires_grad=True, dtype=torch.double), cuda, pin_memory)
# Sparse tensor
x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
test(lambda: x, cuda, pin_memory)
# TODO(soulitzer): Fix _get_tid for sparse tensors
# # Sparse tensor
# x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
# test(lambda: x, cuda, pin_memory)

@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_graph_save_on_cpu_cuda(self):