Skip to content

Commit

Permalink
Update save_on_cpu to dedupe tensors and work with functorch wrapped …
Browse files Browse the repository at this point in the history
…tensors

ghstack-source-id: f863257dc87892fd4ab9cc2351a72b87a46a1231
Pull Request resolved: #89166
  • Loading branch information
soulitzer committed Nov 28, 2022
1 parent c9e955c commit 26107e2
Show file tree
Hide file tree
Showing 7 changed files with 341 additions and 46 deletions.
5 changes: 5 additions & 0 deletions aten/src/ATen/functorch/TensorWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
bool allow_tensor_metadata_change) const override;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

// This is pretty unsafe
void _set_value(const Tensor& value) {
value_ = value;
}

private:
const char* tensorimpl_type_name() const override;
Tensor value_;
Expand Down
21 changes: 11 additions & 10 deletions functorch/_src/vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@


def doesnt_support_saved_tensors_hooks(f):
message = (
"functorch transforms don't yet support saved tensor hooks. "
"Please open an issue with your use case."
)

@functools.wraps(f)
def fn(*args, **kwargs):
with torch.autograd.graph.disable_saved_tensors_hooks(message):
return f(*args, **kwargs)
return fn
return f
# message = (
# "functorch transforms don't yet support saved tensor hooks. "
# "Please open an issue with your use case."
# )

# @functools.wraps(f)
# def fn(*args, **kwargs):
# with torch.autograd.graph.disable_saved_tensors_hooks(message):
# return f(*args, **kwargs)
# return fn


# Checks that all args-to-be-batched have the same batch dim size
Expand Down
81 changes: 81 additions & 0 deletions test/functorch/test_eager_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from functorch._src.eager_transforms import enable_fwd_grad, _slice_argnums
from functorch.experimental import functionalize
from torch.utils.checkpoint import checkpoint

# NB: numpy is a testing dependency!
import numpy as np
Expand Down Expand Up @@ -3479,6 +3480,83 @@ def forward(self, x_1):
""")


class TestSavedTensorHooks(TestCase):
def test_save_on_cpu(self):
a = torch.ones(4, 2)

def fn(x):
return x.sin().exp().sin().sum()

def sum(fn):
def wrapped(x):
return fn(x).sum()
return wrapped

reference = grad(fn)(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(fn)(a)
self.assertEqual(reference, actual)

reference = grad(sum(grad(fn)))(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(sum(grad(fn)))(a)
self.assertEqual(reference, actual)

reference = grad(sum(vmap(grad(fn))))(a)
with torch.autograd.graph.save_on_cpu():
actual = grad(sum(vmap(grad(fn))))(a)
self.assertEqual(reference, actual)

a = torch.rand(4, 2, requires_grad=True)
grad1 = grad(sum(vmap(grad(fn))))(a)
with torch.autograd.graph.save_on_cpu():
grad2 = grad(sum(vmap(grad(fn))))(a)
self.assertEqual(reference, actual)
grad1.sum().backward()
reference = a.grad.clone()
a.grad = None
grad2.sum().backward()
self.assertEqual(reference, a.grad)

def test_nonreentrant_checkpoint(self):
a = torch.rand(4, 2, requires_grad=True)

def fn(x):
return x.sin().exp().sin().sum()

def sum(fn):
def wrapped(x):
return fn(x).sum()
return wrapped

def apply_checkpoint(fn):
def wrapped(*args, **kwargs):
return checkpoint(fn, *args, use_reentrant=False, **kwargs)
return wrapped

reference = grad(fn)(a)
actual = grad(apply_checkpoint((fn)))(a)
self.assertEqual(reference, actual)

reference = grad(sum(grad(fn)))(a)
actual = grad(sum(grad(apply_checkpoint((fn)))))(a)
self.assertEqual(reference, actual)

reference = grad(sum(vmap(grad(fn))))(a)
actual = grad(sum(vmap(grad(apply_checkpoint(fn)))))(a)
self.assertEqual(reference, actual)

# TODO: not supported yet
# a = torch.rand(4, 2, requires_grad=True)
# grad1 = grad(sum(vmap(grad(fn))))(a)
# grad2 = grad(sum(vmap(grad(apply_checkpoint(fn)))))(a)
# self.assertEqual(reference, actual)
# grad1.sum().backward()
# reference = a.grad.clone()
# a.grad = None
# grad2.sum().backward()
# self.assertEqual(reference, a.grad)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(
Expand Down Expand Up @@ -3529,6 +3607,9 @@ def forward(self, x_1):
instantiate_parametrized_tests(
TestMakeFunctional,
)
instantiate_parametrized_tests(
TestSavedTensorHooks,
)

if __name__ == '__main__':
run_tests()
7 changes: 4 additions & 3 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7160,9 +7160,10 @@ def test(get_input, cuda, pin_memory):
test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory)
# DoubleTensor
test(lambda: torch.randn(5, requires_grad=True, dtype=torch.double), cuda, pin_memory)
# Sparse tensor
x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
test(lambda: x, cuda, pin_memory)
# TODO(soulitzer): Fix _get_tid for sparse tensors
# # Sparse tensor
# x = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]), requires_grad=True)
# test(lambda: x, cuda, pin_memory)

@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_graph_save_on_cpu_cuda(self):
Expand Down

0 comments on commit 26107e2

Please sign in to comment.