Revert "Revert "Unify meta tensor and fake tensor converter conversion (
Browse files Browse the repository at this point in the history
#87943)"" (#88045)

This reverts commit bc64999.

See the "This is very tricky" comment in torch/_subclasses/meta_utils.py for an explanation of the bugfix.

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx
Pull Request resolved: #88045
Approved by: https://github.com/kit1980, https://github.com/Chillee
ezyang authored and pytorchmergebot committed Oct 31, 2022
1 parent 2e1199d commit ff94494
Showing 6 changed files with 475 additions and 202 deletions.
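
For readers skimming the diff: MetaConverter maps a real tensor to an equivalent tensor on the meta device, preserving metadata such as shape, strides, dtype, requires_grad, leaf-ness, and view structure, without copying any data. A minimal usage sketch, mirroring the test patterns below (the asserts are illustrative, not code from this PR):

import torch
from torch._subclasses.meta_utils import MetaConverter

to_meta = MetaConverter()
x = torch.randn(4, requires_grad=True)
m = to_meta(x)

# The meta tensor carries no storage data but mirrors the original's metadata.
assert m.device.type == "meta"
assert m.shape == x.shape
assert m.requires_grad and m.is_leaf
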
3 changes: 0 additions & 3 deletions test/dynamo/test_unspec.py
@@ -50,9 +50,6 @@ class UnspecTest(cls):
UnspecReproTests = make_unspec_cls(test_repros.ReproTests)
UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests)

# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
unittest.expectedFailure(UnspecReproTests.test_batch_norm_act_unspec)


@patch.object(torch._dynamo.config, "specialize_int_float", False)
class UnspecTests(torch._dynamo.test_case.TestCase):
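
The expectedFailure removed above was guarding autograd's standard error for in-place ops on leaf tensors; the unified conversion fixes the underlying issue, so the test now passes. For reference, a sketch of the error pattern itself (ordinary PyTorch behavior, not code from this PR):

import torch

x = torch.randn(4, requires_grad=True)  # a leaf tensor tracked by autograd
try:
    x.add_(1)  # in-place mutation of a leaf that requires grad
except RuntimeError as e:
    # "a leaf Variable that requires grad is being used in an in-place operation."
    print(e)

with torch.no_grad():
    x.add_(1)  # allowed: no autograd graph is recorded here
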
146 changes: 123 additions & 23 deletions test/test_meta.py
@@ -6,7 +6,7 @@
from enum import Enum
from torch.overrides import resolve_name
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch._subclasses.meta_utils import MetaConverter
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
import torch.utils._python_dispatch
from torch._dispatch.python import enable_python_dispatcher
from torch.testing._internal.common_utils import (
@@ -66,6 +66,9 @@ def assertSameVersionCounter(self, m1, m2):
self.assertNotEqual(m1._version, vc)
self.assertEqual(m2._version, m1._version)

def assertMetadataMatches(self, m1, m2):
assert_metadata_eq(self.assertEqual, m1, m2)

def test_view_of_non_leaf(self):
x = torch.randn(4, requires_grad=True)
y = x.neg()
@@ -74,9 +77,14 @@ def test_view_of_non_leaf(self):
to_meta = MetaConverter()
m1 = to_meta(z1)
m2 = to_meta(z2)
self.assertEqual(m1.shape, z1.shape)

# check the test is actually testing what it claims
self.assertTrue(m1._is_view())
self.assertFalse(m1._base.is_leaf)

self.assertIsNot(m1, m2)
self.assertMetadataMatches(m1, z1)
self.assertMetadataMatches(m2, z2)
self.assertSameVersionCounter(m1, m2)

def test_view_of_leaf(self):
@@ -86,77 +94,164 @@ def test_view_of_leaf(self):
to_meta = MetaConverter()
m1 = to_meta(z1)
m2 = to_meta(z2)
self.assertEqual(m1.shape, z1.shape)

# check the test is actually testing what it claims
self.assertTrue(m1._is_view())
self.assertTrue(m1._base.is_leaf)

self.assertIsNot(m1, m2)
self.assertMetadataMatches(m1, z1)
self.assertMetadataMatches(m2, z2)
self.assertSameVersionCounter(m1, m2)

def test_view_of_view_of_leaf(self):
x = torch.randn(8)
y = x.view(2, 4)
y.requires_grad = True
z = y.view(2, 2, 2)

to_meta = MetaConverter()
mx = to_meta(x)
mz = to_meta(z)

self.assertFalse(z.is_leaf)

self.assertMetadataMatches(mx, x)
self.assertMetadataMatches(mz, z)

def test_leaf(self):
x = torch.randn(4, requires_grad=True)
to_meta = MetaConverter()
m = to_meta(x)
self.assertEqual(m.shape, x.shape)

# check the test is actually testing what it claims
self.assertTrue(m.is_leaf)
self.assertTrue(m.requires_grad)

self.assertMetadataMatches(m, x)

def test_non_leaf(self):
x = torch.randn(4, requires_grad=True)
y = x.neg()
to_meta = MetaConverter()
m = to_meta(y)
self.assertEqual(m.shape, y.shape)

# check the test is actually testing what it claims
self.assertFalse(m.is_leaf)
self.assertTrue(m.requires_grad)

self.assertMetadataMatches(m, y)

def test_requires_grad_false(self):
x = torch.randn(4, requires_grad=False)
to_meta = MetaConverter()
m = to_meta(x)
self.assertEqual(m.shape, x.shape)

# check the test is actually testing what it claims
self.assertFalse(m.requires_grad)

self.assertMetadataMatches(m, x)

def test_channels_last(self):
x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
to_meta = MetaConverter()
m = to_meta(x)

# check the test is actually testing what it claims
self.assertTrue(m.is_leaf)

self.assertMetadataMatches(m, x)

def test_channels_last_leaf(self):
x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
to_meta = MetaConverter()
m = to_meta(x)

# check the test is actually testing what it claims
self.assertTrue(m.requires_grad)
self.assertTrue(m.is_leaf)

self.assertMetadataMatches(m, x)

def test_channels_last_non_leaf(self):
x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
y = x + 2

# sanity
self.assertEqual(x.stride(), y.stride())
self.assertFalse(y.is_leaf)

to_meta = MetaConverter()
m = to_meta(y)

# check the test is actually testing what it claims
self.assertTrue(m.requires_grad)
self.assertFalse(m.is_leaf)

self.assertMetadataMatches(m, y)

# Check that we can autograd with m as input without erroring;
# see https://github.com/pytorch/pytorch/issues/87956
loss = m.sum()
torch.autograd.grad(loss, m)

def test_empty_strided_non_dense_leaf(self):
x = torch.empty_strided((2, 2), (4, 2), requires_grad=True)

to_meta = MetaConverter()
m = to_meta(x)

# check the test is actually testing what it claims
self.assertTrue(m.requires_grad)
self.assertTrue(m.is_leaf)

self.assertMetadataMatches(m, x)

def test_non_leaf_torture(self):
x = torch.empty(20, requires_grad=True)
with torch.no_grad():
x.set_(x.storage(), 10, (2,), (2,))

to_meta = MetaConverter()
m = to_meta(x)

# check the test is actually testing what it claims
self.assertTrue(m.requires_grad)
self.assertTrue(m.is_leaf)

self.assertMetadataMatches(m, x)

# NB: complex stuff is not actually exercised right now because
# we have a blanket exclusion for complex conversion

def test_view_as_real(self):
x = torch.randn(4, dtype=torch.complex64)
y = torch.view_as_real(x)
m = MetaConverter()(y)
self.assertEqual(m.shape, y.shape)
self.assertEqual(m.stride(), y.stride())
self.assertEqual(m.dtype, y.dtype)
self.assertMetadataMatches(m, y)

def test_complex_noncontiguous_bug(self):
x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :]
m = MetaConverter()(x)
self.assertEqual(m.shape, x.shape)
self.assertEqual(m.stride(), x.stride())
self.assertEqual(m.dtype, x.dtype)
self.assertMetadataMatches(m, x)

def test_view_as_complex(self):
x = torch.randn((4, 2), dtype=torch.float32)
y = torch.view_as_complex(x)
m = MetaConverter()(y)
self.assertEqual(m.shape, y.shape)
self.assertEqual(m.stride(), y.stride())
self.assertEqual(m.dtype, y.dtype)
self.assertMetadataMatches(m, y)

def test_view_dtype(self):
x = torch.randn(4, dtype=torch.float32)
y = x.view(dtype=torch.int32)
m = MetaConverter()(y)
self.assertEqual(m.shape, y.shape)
self.assertEqual(m.stride(), y.stride())
self.assertEqual(m.dtype, y.dtype)
self.assertMetadataMatches(m, y)

def test_imag(self):
x = torch.randn(4, dtype=torch.complex64)
y = x.imag
m = MetaConverter()(y)
self.assertEqual(m.shape, y.shape)
self.assertEqual(m.dtype, y.dtype)
self.assertEqual(m.stride(), y.stride())
self.assertEqual(m.storage_offset(), y.storage_offset())
self.assertMetadataMatches(m, y)

def test_weakref(self):
x = torch.randn(4, 4, 4)
@@ -742,7 +837,12 @@ def __init__(self, test_case, *, device, dtype, inplace):
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}

if torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod):
if (
torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod) or
# meta converter doesn't work correctly when no_dispatch() is on, so
# skip running the crossref test in this case
torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python)
):
return func(*args, **kwargs)

if self.dtype in meta_function_skips.get(func, set()):
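
Two notes on the test changes above. First, the new assertMetadataMatches helper delegates to assert_metadata_eq from torch/_subclasses/meta_utils.py, replacing the piecemeal shape/stride/dtype asserts. A simplified sketch of what such a comparison covers (an illustrative approximation; the real helper in meta_utils.py is the source of truth):

def assert_metadata_matches_sketch(assert_eq, m, t):
    # Compare the metadata that a meta conversion is expected to preserve.
    assert_eq(m.shape, t.shape)
    assert_eq(m.stride(), t.stride())
    assert_eq(m.dtype, t.dtype)
    assert_eq(m.storage_offset(), t.storage_offset())
    assert_eq(m.requires_grad, t.requires_grad)
    assert_eq(m.is_leaf, t.is_leaf)
    assert_eq(m._is_view(), t._is_view())
    if m._is_view():
        # Views should agree about their base tensors as well.
        assert_metadata_matches_sketch(assert_eq, m._base, t._base)

Second, the new guard in __torch_function__ consults the thread-local dispatch exclude set: when the Python dispatch key is excluded (as under no_dispatch()), the meta converter cannot intercept ops correctly, so the crossref test simply calls through to the original function.
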
12 changes: 6 additions & 6 deletions tools/autograd/templates/python_variable_methods.cpp
@@ -1135,7 +1135,7 @@ static PyObject* THPVariable_set_(
{
"set_()",
"set_(Storage source)",
"set_(Storage source, int64_t storage_offset, IntArrayRef size, IntArrayRef stride=None)",
"set_(Storage source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)",
"set_(Tensor source)",
"set_(Tensor source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)",
},
@@ -1181,14 +1181,14 @@ static PyObject* THPVariable_set_(
" for argument 1 'storage'");
auto dispatch_set_ = [](const Tensor& self,
Storage source,
int64_t storage_offset,
IntArrayRef size,
IntArrayRef stride) -> Tensor {
c10::SymInt storage_offset,
c10::SymIntArrayRef size,
c10::SymIntArrayRef stride) -> Tensor {
pybind11::gil_scoped_release no_gil;
return self.set_(source, storage_offset, size, stride);
return self.set__symint(source, storage_offset, size, stride);
};
return wrap(dispatch_set_(
self, storage, _r.toInt64(1), _r.intlist(2), _r.intlist(3)));
self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3)));
}
case 3: {
// aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)
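
The binding change above switches set_'s storage_offset/size/stride arguments from concrete int64_t/IntArrayRef to SymInt/SymIntArrayRef and routes the call through set__symint, so it keeps working when sizes are symbolic (e.g. under fake tensors). From the Python side the call is unchanged; a sketch mirroring test_non_leaf_torture above:

import torch

x = torch.empty(20, requires_grad=True)
with torch.no_grad():
    # Re-point x at its own storage with a new offset, size, and stride;
    # under symbolic shapes these three arguments may be SymInts.
    x.set_(x.storage(), 10, (2,), (2,))

assert x.storage_offset() == 10
assert x.shape == (2,)
assert x.stride() == (2,)
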
