diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 22f975d0f9d68c2..fd5396981b74082 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -50,9 +50,6 @@ class UnspecTest(cls): UnspecReproTests = make_unspec_cls(test_repros.ReproTests) UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests) -# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. -unittest.expectedFailure(UnspecReproTests.test_batch_norm_act_unspec) - @patch.object(torch._dynamo.config, "specialize_int_float", False) class UnspecTests(torch._dynamo.test_case.TestCase): diff --git a/test/test_meta.py b/test/test_meta.py index 997e4224654369d..26f9103b6e864a9 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -6,7 +6,7 @@ from enum import Enum from torch.overrides import resolve_name from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten -from torch._subclasses.meta_utils import MetaConverter +from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq import torch.utils._python_dispatch from torch._dispatch.python import enable_python_dispatcher from torch.testing._internal.common_utils import ( @@ -66,6 +66,9 @@ def assertSameVersionCounter(self, m1, m2): self.assertNotEqual(m1._version, vc) self.assertEqual(m2._version, m1._version) + def assertMetadataMatches(self, m1, m2): + assert_metadata_eq(self.assertEqual, m1, m2) + def test_view_of_non_leaf(self): x = torch.randn(4, requires_grad=True) y = x.neg() @@ -74,9 +77,14 @@ def test_view_of_non_leaf(self): to_meta = MetaConverter() m1 = to_meta(z1) m2 = to_meta(z2) - self.assertEqual(m1.shape, z1.shape) + + # check the test is actually testing what it claims self.assertTrue(m1._is_view()) self.assertFalse(m1._base.is_leaf) + + self.assertIsNot(m1, m2) + self.assertMetadataMatches(m1, z1) + self.assertMetadataMatches(m2, z2) self.assertSameVersionCounter(m1, m2) def test_view_of_leaf(self): @@ -86,35 +94,133 @@ def test_view_of_leaf(self): to_meta = MetaConverter() m1 = to_meta(z1) m2 = to_meta(z2) - self.assertEqual(m1.shape, z1.shape) + + # check the test is actually testing what it claims self.assertTrue(m1._is_view()) self.assertTrue(m1._base.is_leaf) + + self.assertIsNot(m1, m2) + self.assertMetadataMatches(m1, z1) + self.assertMetadataMatches(m2, z2) self.assertSameVersionCounter(m1, m2) + def test_view_of_view_of_leaf(self): + x = torch.randn(8) + y = x.view(2, 4) + y.requires_grad = True + z = y.view(2, 2, 2) + + to_meta = MetaConverter() + mx = to_meta(x) + mz = to_meta(z) + + self.assertFalse(z.is_leaf) + + self.assertMetadataMatches(mx, x) + self.assertMetadataMatches(mz, z) + def test_leaf(self): x = torch.randn(4, requires_grad=True) to_meta = MetaConverter() m = to_meta(x) - self.assertEqual(m.shape, x.shape) + + # check the test is actually testing what it claims self.assertTrue(m.is_leaf) self.assertTrue(m.requires_grad) + self.assertMetadataMatches(m, x) + def test_non_leaf(self): x = torch.randn(4, requires_grad=True) y = x.neg() to_meta = MetaConverter() m = to_meta(y) - self.assertEqual(m.shape, y.shape) + + # check the test is actually testing what it claims self.assertFalse(m.is_leaf) self.assertTrue(m.requires_grad) + self.assertMetadataMatches(m, y) + def test_requires_grad_false(self): x = torch.randn(4, requires_grad=False) to_meta = MetaConverter() m = to_meta(x) - self.assertEqual(m.shape, x.shape) + + # check the test is actually testing what it claims self.assertFalse(m.requires_grad) + self.assertMetadataMatches(m, 
x) + + def test_channels_last(self): + x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last) + to_meta = MetaConverter() + m = to_meta(x) + + # check the test is actually testing what it claims + self.assertTrue(m.is_leaf) + + self.assertMetadataMatches(m, x) + + def test_channels_last_leaf(self): + x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True) + to_meta = MetaConverter() + m = to_meta(x) + + # check the test is actually testing what it claims + self.assertTrue(m.requires_grad) + self.assertTrue(m.is_leaf) + + self.assertMetadataMatches(m, x) + + def test_channels_last_non_leaf(self): + x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True) + y = x + 2 + + # sanity + self.assertEqual(x.stride(), y.stride()) + self.assertFalse(y.is_leaf) + + to_meta = MetaConverter() + m = to_meta(y) + + # check the test is actually testing what it claims + self.assertTrue(m.requires_grad) + self.assertFalse(m.is_leaf) + + self.assertMetadataMatches(m, y) + + # Check that we can autograd with m as input without erroring; + # see https://github.com/pytorch/pytorch/issues/87956 + loss = m.sum() + torch.autograd.grad(loss, m) + + def test_empty_strided_non_dense_leaf(self): + x = torch.empty_strided((2, 2), (4, 2), requires_grad=True) + + to_meta = MetaConverter() + m = to_meta(x) + + # check the test is actually testing what it claims + self.assertTrue(m.requires_grad) + self.assertTrue(m.is_leaf) + + self.assertMetadataMatches(m, x) + + def test_non_leaf_torture(self): + x = torch.empty(20, requires_grad=True) + with torch.no_grad(): + x.set_(x.storage(), 10, (2,), (2,)) + + to_meta = MetaConverter() + m = to_meta(x) + + # check the test is actually testing what it claims + self.assertTrue(m.requires_grad) + self.assertTrue(m.is_leaf) + + self.assertMetadataMatches(m, x) + # NB: complex stuff is not actually exercised right now because # we have a blanket exclusion for complex conversion @@ -122,41 +228,30 @@ def test_view_as_real(self): x = torch.randn(4, dtype=torch.complex64) y = torch.view_as_real(x) m = MetaConverter()(y) - self.assertEqual(m.shape, y.shape) - self.assertEqual(m.stride(), y.stride()) - self.assertEqual(m.dtype, y.dtype) + self.assertMetadataMatches(m, y) def test_complex_noncontiguous_bug(self): x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :] m = MetaConverter()(x) - self.assertEqual(m.shape, x.shape) - self.assertEqual(m.stride(), x.stride()) - self.assertEqual(m.dtype, x.dtype) + self.assertMetadataMatches(m, x) def test_view_as_complex(self): x = torch.randn((4, 2), dtype=torch.float32) y = torch.view_as_complex(x) m = MetaConverter()(y) - self.assertEqual(m.shape, y.shape) - self.assertEqual(m.stride(), y.stride()) - self.assertEqual(m.dtype, y.dtype) + self.assertMetadataMatches(m, y) def test_view_dtype(self): x = torch.randn(4, dtype=torch.float32) y = x.view(dtype=torch.int32) m = MetaConverter()(y) - self.assertEqual(m.shape, y.shape) - self.assertEqual(m.stride(), y.stride()) - self.assertEqual(m.dtype, y.dtype) + self.assertMetadataMatches(m, y) def test_imag(self): x = torch.randn(4, dtype=torch.complex64) y = x.imag m = MetaConverter()(y) - self.assertEqual(m.shape, y.shape) - self.assertEqual(m.dtype, y.dtype) - self.assertEqual(m.stride(), y.stride()) - self.assertEqual(m.storage_offset(), y.storage_offset()) + self.assertMetadataMatches(m, y) def test_weakref(self): x = torch.randn(4, 4, 4) @@ -742,7 +837,12 @@ def __init__(self, test_case, *, device, dtype, inplace): def 
__torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod): + if ( + torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod) or + # meta converter doesn't work correctly when no_dispatch() is on, so + # skip running the crossref test in this case + torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python) + ): return func(*args, **kwargs) if self.dtype in meta_function_skips.get(func, set()): diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 7122532a54410f8..e3c0a8b987bd68f 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -1135,7 +1135,7 @@ static PyObject* THPVariable_set_( { "set_()", "set_(Storage source)", - "set_(Storage source, int64_t storage_offset, IntArrayRef size, IntArrayRef stride=None)", + "set_(Storage source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)", "set_(Tensor source)", "set_(Tensor source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)", }, @@ -1181,14 +1181,14 @@ static PyObject* THPVariable_set_( " for argument 1 'storage'"); auto dispatch_set_ = [](const Tensor& self, Storage source, - int64_t storage_offset, - IntArrayRef size, - IntArrayRef stride) -> Tensor { + c10::SymInt storage_offset, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) -> Tensor { pybind11::gil_scoped_release no_gil; - return self.set_(source, storage_offset, size, stride); + return self.set__symint(source, storage_offset, size, stride); }; return wrap(dispatch_set_( - self, storage, _r.toInt64(1), _r.intlist(2), _r.intlist(3))); + self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3))); } case 3: { // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 3e5cbdb65226400..d6e6f79647fd3b5 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -2,7 +2,6 @@ import functools import itertools import sys -import warnings import weakref from dataclasses import dataclass from functools import partial @@ -139,15 +138,17 @@ def tree_flatten_only(ty: Type[T], pytree: PyTree): # structure. Like `MetaConverter`, it uses `WeakTensorRefKey` to # hold a weak reference for all memoized tensors. class FakeTensorConverter(object): - tensor_memo: weakref.WeakValueDictionary + @property + def tensor_memo(self): + return self.meta_converter.tensor_memo + meta_converter: MetaConverter constant_storage_mapping: Dict[StorageWeakRef, List[TensorWeakRef]] def __init__(self): - # FakeTensors store the FakeTensorMode which in turn stores a - # FakeTensor, so we need to hold a weak reference to the FakeTensor - # otherwise we would induce a circular reference - self.tensor_memo = weakref.WeakValueDictionary() + # In principle preserving views should be OK, but in practice + # AOTAutograd (or maybe autograd) seems to do the wrong thing. 
See + # https://github.com/pytorch/torchdynamo/issues/1815 self.meta_converter = MetaConverter() # map from to storage to corresponding constant tensors @@ -214,28 +215,31 @@ def from_real_tensor(self, fake_mode, t, make_constant=False, shape_env=None): # not yet supported in metatensors if t.is_quantized: raise UnsupportedFakeTensorException("quantized nyi in meta tensors") - with no_dispatch(): - meta_t = self.meta_converter(t, shape_env=shape_env) - if meta_t.device.type != "meta": - raise UnsupportedFakeTensorException("meta converter nyi") - out = FakeTensor( - fake_mode, - meta_t, - existing_device, - constant=t if make_constant else None, - ) - out.requires_grad_(t.requires_grad) - if make_constant: - self.add_constant_storage_mapping(out) if type(t) is torch.nn.Parameter: assert not make_constant - out = torch.nn.Parameter(out, requires_grad=out.requires_grad) # type: ignore[assignment] - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") - grad_not_none = t.grad is not None - if grad_not_none: - out.grad = self.from_real_tensor(fake_mode, t.grad, shape_env=shape_env) - self.set_tensor_memo(t, out) + + def mk_fake_tensor(make_meta_t): + # NB: don't use in_kernel_invocation_manager. to + # ensure FakeTensor can internally do constant computation + # as necessary. Invocation manager is "more correct" as + # it works for more operators in make_meta_t, but + # invariant is that make_meta_t only calls factories + # for which it is not strictly necessary to use the + # invocation manager (I think!) + with no_dispatch(): + return FakeTensor( + fake_mode, + make_meta_t(), + existing_device, + constant=t if make_constant else None, + ) + + out = self.meta_converter(t, shape_env=shape_env, callback=mk_fake_tensor) + if out is NotImplemented: + raise UnsupportedFakeTensorException("meta converter nyi") + if make_constant: + self.add_constant_storage_mapping(out) + # NB: meta_converter set the memo return out # If you specify the device, it MUST be a meta tensor. 
@@ -296,7 +300,9 @@ def constructors(fake_mode, func, *args, **kwargs): out_device = new_kwargs.pop("device", None) out_device = out_device if out_device is not None else default_device new_kwargs["device"] = torch.device("meta") - r = func(*args, **new_kwargs) + # Not in_kernel_invocation_manager as no fake tensor inputs + with no_dispatch(): + r = func(*args, **new_kwargs) return FakeTensor(fake_mode, r, out_device) @@ -309,7 +315,8 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs): out_device = input_device if input_device else new_kwargs["input"].device new_kwargs["device"] = torch.device("meta") inp = new_kwargs.pop("input") - r = func(inp, **new_kwargs) + with in_kernel_invocation_manager(fake_mode): + r = func(inp, **new_kwargs) # TODO: I think this does the wrong thing if r is inp return fake_mode.fake_tensor_converter.from_meta_and_device( fake_mode, r, out_device @@ -320,7 +327,8 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs): # since the device of `the_template` is ignored @register_op_impl(aten.resize_as_.default) def resize_as_(fake_mode, func, *args, **kwargs): - return func(*args, **kwargs) + with in_kernel_invocation_manager(fake_mode): + return func(*args, **kwargs) @register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default) @@ -710,6 +718,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): else: return args[0].fake_device + # Some attribute queries that can be serviced directly + # See Note [is_coalesced is dispatched] + if func in [torch.ops.aten.is_coalesced.default]: + # NB: no_dispatch is ok here too, this func is very simple + with in_kernel_invocation_manager(self): + return func(*args, **kwargs) + flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs)) flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs)) has_symbolic_sizes = ( @@ -725,38 +740,38 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if func in self.lift_fns: out = func(*args, **kwargs) if self.may_turn_const(out): + # NB: not in_kernel_invocation_manager because we're doing real + # compute here with no_dispatch(): - return converter(self, out.clone(), make_constant=True) - - with no_dispatch(): - flat_arg_tensors = tree_flatten_only(torch.Tensor, (args, kwargs)) - # See [subclass inputs] below - # NB: If you're seeing a mysterious infinite loop involving fake - # tensor, it might be related to this line. Though I'm not sure - # how you'll know to read this comment, as this line won't show up - # in the stack trace. - if self.check_for_subclass(flat_arg_tensors): - return NotImplemented - - # if we are in the dispatch mode, we will enter this function even if the inputs - # are not FakeTensors. For now, throw if any non-Fake Tensor inputs - # and just support constructors. - - # this is generated from torch.tensor(), which does not use the - # dispatcher, to allow wrapper subclasses to wrap the new tensor - if func in self.lift_fns: - assert ( - len(kwargs) == 0 - and len(args) == 1 - and type(args[0]) is torch.Tensor - ), f"{args} {kwargs}" - return converter(self, args[0]) - - if self.check_for_non_fake(flat_arg_tensors): - raise Exception( - "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. " - f"Please convert all Tensors to FakeTensors first. 
Found in {func}(*{args}, **{kwargs})" - ) + out = out.clone() + return converter(self, out, make_constant=True) + + flat_arg_tensors = tree_flatten_only(torch.Tensor, (args, kwargs)) + # See [subclass inputs] below + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. + if self.check_for_subclass(flat_arg_tensors): + return NotImplemented + + # if we are in the dispatch mode, we will enter this function even if the inputs + # are not FakeTensors. For now, throw if any non-Fake Tensor inputs + # and just support constructors. + + # this is generated from torch.tensor(), which does not use the + # dispatcher, to allow wrapper subclasses to wrap the new tensor + if func in self.lift_fns: + assert ( + len(kwargs) == 0 and len(args) == 1 and type(args[0]) is torch.Tensor + ), f"{args} {kwargs}" + return converter(self, args[0]) + + if self.check_for_non_fake(flat_arg_tensors): + raise Exception( + "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. " + f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})" + ) # The current constant handling only support tracing systems # (aot autograd, torchdynamo) where each operation is run consecutively. @@ -776,27 +791,30 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): and len(flat_arg_fake_tensors) != 0 and not has_symbolic_sizes ): + const_args, const_kwargs = pytree.tree_map_only( + FakeTensor, lambda t: t.constant, (args, kwargs) + ) + + # NB: not in_kernel_invocation_manager(self) as we want to do REAL + # compute with no_dispatch(): - const_args, const_kwargs = pytree.tree_map_only( - FakeTensor, lambda t: t.constant, (args, kwargs) - ) out = func(*const_args, **const_kwargs) - all_constant = pytree.tree_all_only( - torch.Tensor, lambda t: self.may_turn_const(t), out - ) + all_constant = pytree.tree_all_only( + torch.Tensor, lambda t: self.may_turn_const(t), out + ) - if all_constant: - return pytree.tree_map_only( - torch.Tensor, - lambda t: converter(self, t, make_constant=True), - out, - ) + if all_constant: + return pytree.tree_map_only( + torch.Tensor, + lambda t: converter(self, t, make_constant=True), + out, + ) - # we weren't able to turn outputs to constants, - # so invalidate all constants that might be aliases of the outputs - for ten in tree_flatten_only(torch.Tensor, out): - converter.invalidate_constant_aliases(ten) + # we weren't able to turn outputs to constants, + # so invalidate all constants that might be aliases of the outputs + for ten in tree_flatten_only(torch.Tensor, out): + converter.invalidate_constant_aliases(ten) # we are falling through to running non constant tensors, any input constant that # is written to must be invalidated @@ -817,14 +835,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): ): from torch._decomp import meta_table as meta_table - with no_dispatch(): - if func == aten.size.default: - sys.stderr.write( - "Trying to call aten.size on a tensor with symbolic shapes. " - "It's likely that this is from calling tensor.shape in C++" - ) - # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` - return None + if func == aten.size.default: + sys.stderr.write( + "Trying to call aten.size on a tensor with symbolic shapes. 
" + "It's likely that this is from calling tensor.shape in C++" + ) + # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` + return None with self: if func in meta_table: @@ -860,32 +877,27 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): f"{func} - couldn't find symbolic meta function/decomposition" ) - with no_dispatch(): - # special handling for funcs registered through `register_op_impl`, - # e.g., manipulating args on constructor calls to construct meta tensors - # and then afterwards wrapping them to a FakeTensor - for run_impl_check, op_impl in op_implementations: - if run_impl_check(func): - op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out != NotImplemented: - return op_impl_out - - # run kernel registered to meta for func, which include - # python meta registrations, prims, decomps, and c++ meta fns (structured kernels) - try: - with in_kernel_invocation_manager(self): - r = func(*args, **kwargs) - except NotImplementedError as not_implemented_error: - # no meta kernel registered, fallback to kernel for the device - if not self.allow_fallback_kernels: - raise not_implemented_error - return run_fallback_kernel( - self, func, args, kwargs, not_implemented_error - ) - - return self.wrap_meta_outputs_with_default_device_logic( - r, func, args, kwargs - ) + # special handling for funcs registered through `register_op_impl`, + # e.g., manipulating args on constructor calls to construct meta tensors + # and then afterwards wrapping them to a FakeTensor + for run_impl_check, op_impl in op_implementations: + if run_impl_check(func): + op_impl_out = op_impl(self, func, *args, **kwargs) + if op_impl_out != NotImplemented: + return op_impl_out + + # run kernel registered to meta for func, which include + # python meta registrations, prims, decomps, and c++ meta fns (structured kernels) + try: + with in_kernel_invocation_manager(self): + r = func(*args, **kwargs) + except NotImplementedError as not_implemented_error: + # no meta kernel registered, fallback to kernel for the device + if not self.allow_fallback_kernels: + raise not_implemented_error + return run_fallback_kernel(self, func, args, kwargs, not_implemented_error) + + return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs) # [subclass inputs] # Suppose we enable fake tensor mode. 
This means that fake tensor @@ -959,6 +971,7 @@ def functions_with_cpp_meta_impl_that_support_symint(self): aten.as_strided.default, aten.zeros.default, aten.detach.default, + aten.set_.source_Storage_storage_offset, ] @property @@ -1004,8 +1017,11 @@ def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exce if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] raise orig_not_implemented_exception + inp_impls = {} + + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) with no_dispatch(): - inp_impls = {} def to_real_tensor(e): if isinstance(e, FakeTensor): @@ -1021,25 +1037,25 @@ def to_real_tensor(e): r = func(*args, **kwargs) - tensor_impls = set() - storages = set() - - for e in tree_flatten((args, kwargs))[0]: - if isinstance(e, torch.Tensor): - if not e.is_sparse: - storages.add(e.storage()._cdata) - - # TODO: also check metadata change on inputs - # proper aliasing/metadata relationship between outputs and inputs will - # not be set up, bc of conversion to device, unless we can reuse an - # input impl - for e in tree_flatten(r)[0]: - if id(e) not in inp_impls and ( - isinstance(e, torch.Tensor) - and not e.is_sparse - and e.storage()._cdata in storages - ): - raise orig_not_implemented_exception + tensor_impls = set() + storages = set() + + for e in tree_flatten((args, kwargs))[0]: + if isinstance(e, torch.Tensor): + if not e.is_sparse: + storages.add(e.storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + for e in tree_flatten(r)[0]: + if id(e) not in inp_impls and ( + isinstance(e, torch.Tensor) + and not e.is_sparse + and e.storage()._cdata in storages + ): + raise orig_not_implemented_exception def map_out(e): if isinstance(e, torch.Tensor): diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 51231811631bc5a..0e2bbe49dd22644 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1,8 +1,10 @@ +import contextlib +import warnings import weakref +from typing import ContextManager import torch from torch.multiprocessing.reductions import StorageWeakRef -from torch.utils._mode_utils import no_dispatch def safe_is_leaf(t): @@ -13,6 +15,47 @@ def safe_is_leaf(t): return False +def safe_grad(t): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + return t.grad + + +def assert_eq(a, b): + assert a == b, f"{a} != {b}" + + +def assert_metadata_eq(assert_eq, m1, m2): + def go(m1, m2): + assert_eq(m1.dtype, m2.dtype) + assert_eq(m1.shape, m2.shape) + assert_eq(m1.requires_grad, m2.requires_grad) + assert_eq(m1.is_leaf, m2.is_leaf) + assert_eq(m1.grad_fn is None, m2.grad_fn is None) + assert_eq(m1.is_sparse, m2.is_sparse) + assert_eq(m1.is_inference(), m2.is_inference()) + assert_eq(m1.is_conj(), m2.is_conj()) + assert_eq(m1.is_neg(), m2.is_neg()) + assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None) + if safe_grad(m1) is not None: + go(m1.grad, m2.grad) + if m1.is_sparse: + assert_eq(m1.dense_dim(), m2.dense_dim()) + assert_eq(m1.sparse_dim(), m2.sparse_dim()) + assert_eq(m1.is_coalesced(), m2.is_coalesced()) + else: + assert_eq(m1.stride(), m2.stride()) + assert_eq(m1.storage_offset(), m2.storage_offset()) + assert_eq(m1._is_view(), m2._is_view()) + if m1._is_view(): + go(m1._base, m2._base) + # 
TODO: test if is resizable (no direct query for this atm) + # TODO: audit AutogradMeta to see if it matches + # TODO: test forward AD + + return go(m1, m2) + + # torch.Tensors cannot be used as a key in a dictionary # because they define a custom __eq__ function which when used # to resolve hash collisions will throw when comparing tensors: @@ -127,18 +170,31 @@ def del_ten(): # NB: doesn't actually return a storage, because meta storage is # not supported - def meta_storage(self, s): + def meta_storage(self, s, callback): # NB: TypedStorage is freshly allocated and cannot be used as hash # key index. # Use a Weak Ref to s in order to not leak memory swr = StorageWeakRef(s) if swr not in self.storage_memo: - self.storage_memo[swr] = torch.empty(s.size(), dtype=s.dtype, device="meta") + self.storage_memo[swr] = ( + callback( + lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta") + ) + .storage() + .untyped() + ) return self.storage_memo[swr] # This function assumes that it's possible to do the conversion - def meta_tensor(self, t, shape_env=None): + def meta_tensor(self, t, shape_env=None, callback=lambda t: t()): + # This indicates you set no_dispatch() before calling into this + # function. This is an error: we may be creating fake tensors and + # will perform operations on them which need fake tensor mode to + # be active. You will segfault if you are in a no_dispatch() block. + assert not torch._C._dispatch_tls_local_exclude_set().has( + torch._C.DispatchKey.Python + ) arg_cnt = self.arg_cnt self.arg_cnt += 1 @@ -166,14 +222,22 @@ def sym_sizes_strides(t): if t.is_sparse: assert shape_env is None, "symbolic on sparse NYI" is_leaf = safe_is_leaf(t) - r = torch.ops.aten._sparse_coo_tensor_with_dims( - t.sparse_dim(), - t.dense_dim(), - t.shape, - dtype=t.dtype, - layout=torch.sparse_coo, - device="meta", + r = callback( + lambda: torch.ops.aten._sparse_coo_tensor_with_dims( + t.sparse_dim(), + t.dense_dim(), + t.shape, + dtype=t.dtype, + layout=torch.sparse_coo, + device="meta", + ) ) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + # Note [is_coalesced is dispatched] + # Strangely enough, is_coalesced() is a dispatched operator, + # which means that it will get caught by fake tensor mode. + # Ordinarily this would error, but there's some logic in + # fake tensor ensure this doesn't happen. r._coalesced_(t.is_coalesced()) if t.requires_grad: r.requires_grad = True @@ -184,11 +248,12 @@ def sym_sizes_strides(t): elif t._is_view(): # Construct views in two steps: recursively meta-fy their - # base, and then create the view off that. NB: doing it + # base, and then create view(s) off that. NB: doing it # directly from storage is WRONG because this won't cause # version counters to get shared. assert t._is_view() - base = self.meta_tensor(t._base) + + base = self.meta_tensor(t._base, shape_env, callback) def is_c_of_r(complex_dtype, real_dtype): return ( @@ -209,43 +274,135 @@ def is_c_of_r(complex_dtype, real_dtype): # that hasn't been handled here base = base.view(t.dtype) - with torch.enable_grad(): - sizes, strides = sym_sizes_strides(t) - r = base.as_strided(sizes, strides, sym(t.storage_offset())) + # This is very tricky. Naively, you might expect this + # to hold: + # + # if t.requires_grad and not safe_is_leaf(t) + # assert t._base.requires_grad + # + # But it's not true! 
As you can see in the following + # program: + # + # x = torch.zeros(4) + # y = x.view(1, 4) + # y.requires_grad = True + # z = y.view(1, 1, 4) + # assert z._base is x + # + # So we may have to do *two* views out of the base to + # recreate this situation. + + sizes, strides = sym_sizes_strides(t) + if safe_is_leaf(t): + # Leaf views that track view metadata are created by + # creating a view inside a no_grad block + with torch.no_grad(): + r = base.as_strided(sizes, strides, sym(t.storage_offset())) + # As it's a leaf, we can directly assign requires_grad + r.requires_grad = t.requires_grad + else: + if t._base.requires_grad == t.requires_grad: + # Easy case, just run the view op + with torch.enable_grad(): + r = base.as_strided( + sizes, strides, sym(t.storage_offset()) + ) + else: + # Obscure case. Create a leaf view and give it the + # correct requires_grad, then do the final view. + # NB: Can't have a non-leaf without requiring grad! + assert t.requires_grad + with torch.no_grad(): + mid = base.view(base.shape) + mid.requires_grad = t.requires_grad + with torch.enable_grad(): + r = mid.as_strided( + sizes, strides, sym(t.storage_offset()) + ) + else: is_leaf = safe_is_leaf(t) - # Fake up some autograd history. - if t.requires_grad: - r = torch.empty( - (0,), dtype=t.dtype, device="meta", requires_grad=True + sizes, strides = sym_sizes_strides(t) + storage_offset = sym(t.storage_offset()) + r = callback( + lambda: torch.empty_strided( + sizes, strides, dtype=t.dtype, device="meta" ) + ) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = t.requires_grad if not is_leaf: + # Fake up some autograd history. with torch.enable_grad(): - # The backward function here will be wrong, but - # that's OK; our goal is just to get the metadata - # looking as close as possible; we're not going to - # actually try to backward() on these produced - # metas. TODO: would be safer to install some - # sort of unsupported grad_fn here - r = r.clone() + # preserve_format is the default, but we want to + # emphasize how important it is to preserve + # format here + r = r.clone(memory_format=torch.preserve_format) + + s = t.storage().untyped() + swr = StorageWeakRef(s) + if ( + swr not in self.storage_memo + and r.stride() == strides + and r.storage_offset() == storage_offset + ): + # You're normal and happy, install the fresh storage into the memo + self.storage_memo[swr] = r.storage().untyped() else: - r = torch.empty((0,), dtype=t.dtype, device="meta") - # As long as meta storage is not supported, need to prevent - # redispatching on set_(Storage, ...) which will choke with - # meta storage - s = self.meta_storage(t.storage()) - with no_dispatch(): - sizes, strides = sym_sizes_strides(t) - with torch.no_grad(): - r.set_(s, sym(t.storage_offset()), sizes, strides) + # You're in crazy town; somehow you gave us a tensor + # that wasn't a view, but had nonzero storage offset, + # nontrivial strides (such that clone() couldn't + # preserve them), or already aliases with another + # tensor's storage. The most typical way to end + # up here is with set_. So use set_ to bludgeon this + # in. + r_s = self.meta_storage(s, callback=callback) + # NB: In principle, this should always work, but there + # is some subtle difference in the autograd metadata + # that means we will backprop the set_ call, even if + # r is declared as an input to grad. + # See https://github.com/pytorch/pytorch/issues/87956 + # for the reproducer. 
+ # NB: The in_kernel_invocation_manager here is necessary + # for fake tensor. If we run the set_ call with fake + # tensor on, r will improperly report that it is NOT a + # meta tensor but a cpu tensor, and then the set_ call + # will fail due to device mismatch. no_dispatch() is + # not enough, because the fake tensor will still claim + # to be a CPU tensor and you'll end up in the CPU + # kernel. Arguably this is a hack; a cleaner way to + # solve this is to have a FakeStorage concept which + # would report it's CPU device--no problem now! But + # this is difficult to do because we don't have storage + # subclasses. Relevant test is + # DynamicShapesFunctionTests::test_add_dynamic_shapes in + # test/dynamo/test_dynamic_shapes.py + maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext() + from torch._subclasses.fake_tensor import ( + FakeTensor, + in_kernel_invocation_manager, + ) + if isinstance(r, FakeTensor): + maybe_fake_mgr = in_kernel_invocation_manager(r.fake_mode) + with maybe_fake_mgr, torch.no_grad(): + r.set_(r_s, storage_offset, sizes, strides) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + grad_not_none = t.grad is not None + if grad_not_none: + r.grad = self.meta_tensor(t.grad, shape_env, callback) torch._C._set_conj(r, t.is_conj()) torch._C._set_neg(r, t.is_neg()) + # This can be skipped if necessary for performance reasons + # assert_metadata_eq(assert_eq, t, r) self.set_tensor_memo(t, r) return self.get_tensor_memo(t) - def __call__(self, t, shape_env=None): + def __call__(self, t, shape_env=None, *, callback=lambda t: t()): # TODO: zero tensors? We appear to have eliminated them by # excluding complex for now from torch._subclasses.fake_tensor import FakeTensor @@ -280,10 +437,11 @@ def __call__(self, t, shape_env=None): # tests all break so we just exclude this. In any case # the to conversion isn't really right anyhow. self.miss += 1 - return t + return NotImplemented else: self.hit += 1 - r = self.meta_tensor(t, shape_env=shape_env) + r = self.meta_tensor(t, shape_env=shape_env, callback=callback) + # TODO: this is suspicious, now that we have callback argument if type(t) is torch.nn.Parameter: r = torch.nn.Parameter(r, requires_grad=r.requires_grad) return r @@ -294,7 +452,7 @@ def __call__(self, t, shape_env=None): # support meta. Trying to YOLO this is more trouble than it's # worth. self.miss += 1 - return t + return NotImplemented else: # non-Tensor types don't count as hit or miss return t diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2f85b8af1d81f5e..9903e95228fc846 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1367,6 +1367,8 @@ def freeze_rng_state(): # # In the long run torch.cuda.set_rng_state should probably be # an operator. + # + # NB: Mode disable is to avoid running cross-ref tests on thes seeding with no_dispatch(), disable_functorch(): if torch.cuda.is_available(): torch.cuda.set_rng_state(cuda_rng_state)
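
The new `assert_metadata_eq` helper in meta_utils.py is what the rewritten test_meta assertions delegate to. A minimal sketch of driving it directly against the patched `MetaConverter` (everything used here is defined in this diff):

    import torch
    from torch._subclasses.meta_utils import MetaConverter, assert_eq, assert_metadata_eq

    x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True)
    m = MetaConverter()(x)

    # Raises AssertionError if dtype, shape, strides, storage_offset,
    # requires_grad, leaf-ness, view-ness, conj/neg bits or .grad presence differ.
    assert_metadata_eq(assert_eq, x, m)
    assert m.device.type == "meta"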
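
The dense, non-view leaf path now materializes the meta tensor with `torch.empty_strided` instead of the old `torch.empty((0,), ...)` plus `set_` dance, which is why the new channels-last tests can expect strides to survive the conversion. A standalone illustration of that property (not the patched code itself):

    import torch

    x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
    # empty_strided with the real tensor's strides reproduces the memory format.
    m = torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device="meta")

    assert m.stride() == x.stride()
    assert m.is_contiguous(memory_format=torch.channels_last)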
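
`test_view_of_view_of_leaf` and the long comment in `meta_tensor` cover the case where a view requires grad but its root `_base` does not. A simplified, self-contained sketch of the two-step rebuild the patch performs in that branch (the names below are illustrative, not the actual code path):

    import torch

    # The situation being recreated:
    x = torch.zeros(4)
    y = x.view(1, 4)
    y.requires_grad = True
    z = y.view(1, 1, 4)
    assert z._base is x                          # base chain collapses to the root
    assert z.requires_grad and not x.requires_grad

    # Two-step rebuild on the meta device: a leaf intermediate view receives the
    # requires_grad flag, then the final as_strided records a grad_fn.
    base = torch.zeros(4, device="meta")
    with torch.no_grad():
        mid = base.view(base.shape)
    mid.requires_grad = True
    with torch.enable_grad():
        r = mid.as_strided(z.size(), z.stride(), z.storage_offset())
    assert r._is_view() and r.requires_grad and not r.is_leaf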
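
The `set_()` fallback (the "crazy town" branch) and the new `test_non_leaf_torture` / `test_empty_strided_non_dense_leaf` tests deal with tensors whose layout `empty_strided` alone cannot reproduce. A rough sketch of what that branch effectively does for such a tensor, using only operations the patch itself relies on (an untyped meta storage plus `set_` under `no_grad`); this is a simplification, not a drop-in for the real logic:

    import torch

    # A leaf with nonzero storage offset and non-dense strides (the torture test).
    x = torch.empty(20, requires_grad=True)
    with torch.no_grad():
        x.set_(x.storage(), 10, (2,), (2,))

    # Rebuild on the meta device: fresh meta tensor, then bludgeon the real
    # offset/size/stride in with set_() against a byte-sized meta storage.
    nbytes = x.storage().untyped().size()
    s = torch.empty(nbytes, dtype=torch.uint8, device="meta").storage().untyped()
    r = torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device="meta")
    with torch.no_grad():
        r.set_(s, x.storage_offset(), x.size(), x.stride())
    r.requires_grad = True

    assert r.storage_offset() == x.storage_offset() == 10
    assert r.stride() == x.stride() == (2,)
    assert r.is_leaf and r.requires_grad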
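
`FakeTensorConverter.from_real_tensor` now threads a callback through `MetaConverter` instead of post-processing its output: `MetaConverter` calls `callback(make_meta_t)` for every fresh allocation so the caller can wrap it (fake_tensor.py builds a `FakeTensor`) before strides, storage and autograd history are filled in. A toy sketch of that protocol; `demo_callback` and the `_demo_wrapped` marker are invented for illustration:

    import torch
    from torch._subclasses.meta_utils import MetaConverter

    def demo_callback(make_meta_t):
        t = make_meta_t()        # zero-argument factory supplied by MetaConverter
        t._demo_wrapped = True   # stand-in for FakeTensor(...) wrapping
        return t                 # must remain a leaf (the "doesn't detach" assert)

    x = torch.randn(4, requires_grad=True)
    m = MetaConverter()(x, callback=demo_callback)

    assert m.device.type == "meta"
    assert m.requires_grad and m.is_leaf
    assert m._demo_wrapped       # the wrapped tensor is what lands in the memo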
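
The new assert at the top of `meta_tensor` and the extra condition in the crossref `__torch_function__` rest on the same fact: inside `no_dispatch()` the Python dispatch key sits in the thread-local exclude set, so torch dispatch modes (fake tensor mode included) never see the call. A quick check of that premise, using only the APIs the patch touches:

    import torch
    from torch.utils._mode_utils import no_dispatch

    def python_key_excluded():
        return torch._C._dispatch_tls_local_exclude_set().has(
            torch._C.DispatchKey.Python
        )

    assert not python_key_excluded()   # normal state: meta_tensor() may proceed
    with no_dispatch():
        assert python_key_excluded()   # here meta_tensor() would assert; crossref skips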