Add propagate_real_tensors mode for unbacked
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 8bb80b5522e80b619bdf436bd3087e8cc8514203
Pull Request resolved: #125115
ezyang committed Apr 29, 2024
1 parent 06b845d commit 2337dfd
Showing 6 changed files with 251 additions and 36 deletions.
6 changes: 6 additions & 0 deletions test/test_dynamic_shapes.py
@@ -512,6 +512,12 @@ def test_data_dependent_guard(self):
s0 = shape_env.create_unbacked_symint()
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

def test_data_dependent_guard_propagate_real_tensors(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
shape_env.unbacked_var_to_val[s0.node.expr] = 0
self.assertEqual(bool(s0 == 0), True)
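# (What the test above exercises: seeding unbacked_var_to_val gives the
# ShapeEnv a real value to consult for s0, so the guard s0 == 0 can be
# evaluated directly instead of raising GuardOnDataDependentSymNode as in
# the previous test.)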

def test_expect_true_basic(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
71 changes: 57 additions & 14 deletions test/test_fake_tensor.py
@@ -6,6 +6,7 @@
instantiate_parametrized_tests, TemporaryFileName)
import torch
import torch._dynamo
from torch._dynamo.testing import make_test_cls_with_patches
import itertools
import numpy as np
from torch.testing._internal.jit_utils import RUN_CUDA
@@ -53,6 +54,10 @@
torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True

def expectedFailurePropagateRealTensors(fn):
fn._expected_failure_propagate_real_tensors = True
return fn
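# (The marker attribute set above is consumed by make_test_cls_with_patches
# via xfail_prop in make_propagate_real_tensors_cls below, which turns the
# tagged tests into expected failures in the generated class.)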

class FakeTensorTest(TestCase):
def checkType(self, t, device_str, size):
self.assertTrue(isinstance(t, FakeTensor))
@@ -207,6 +212,8 @@ def test_fake_dispatch_keys(self):
FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

# TODO: functorch support for propagate real tensors
@expectedFailurePropagateRealTensors
def test_batch_tensor(self):
x = torch.rand((3, 4, 5))
b = _add_batch_dim(x, 0, 0)
@@ -392,10 +399,10 @@ def test_out_multi_device(self):
x = torch.rand([4])
y = torch.rand([4], device="cuda")

with self.assertRaisesRegex(Exception, "found two different devices"):
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
torch.sin(x, out=y)

with self.assertRaisesRegex(Exception, "found two different devices"):
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
x.add_(y)


@@ -578,6 +585,9 @@ def test_same_shape_env_preserved(self):
self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
self.assertEqual(str(t2.size(0)), str(t1.size(0)))

# TODO: Support NJT. There's also some funny business with dynamic shapes
# which would need to be dealt with as well
@expectedFailurePropagateRealTensors
def test_jagged_fake_to_fake_preserved(self):
from torch.nested._internal.nested_tensor import jagged_from_list

@@ -736,7 +746,9 @@ def test_aten_index_multi_device(self):
x2 = torch.rand(4, 4, device="cuda")
i1 = torch.tensor([0, 1], device="cuda")
i2 = torch.tensor([0, 1], device="cpu")
-r1 = torch.ops.aten.index(x1, i1)
+# NB: This one does not work: CUDA indices are not allowed on a CPU
+# tensor
+# r1 = torch.ops.aten.index(x1, i1)
r2 = torch.ops.aten.index(x2, i2)

y1 = torch.rand(4, device="cpu")
@@ -745,7 +757,7 @@
j2 = torch.tensor([2], device="cpu")
r3 = torch.ops.aten.index_put.default(x1, j1, y1)
r4 = torch.ops.aten.index_put.default(x2, j2, y2)
-self.checkType(r1, "cpu", ())
+# self.checkType(r1, "cpu", ())
self.checkType(r2, "cuda", ())
self.checkType(r3, "cpu", (4, 4))
self.checkType(r4, "cuda", (4, 4))
@@ -785,6 +797,23 @@ def forward(self, input):
self.assertTrue(isinstance(ep, torch.export.ExportedProgram))


instantiate_parametrized_tests(FakeTensorTest)


def make_propagate_real_tensors_cls(cls):
cls = make_test_cls_with_patches(
cls,
"PropagateRealTensors",
"_propagate_real_tensors",
(torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
xfail_prop="_expected_failure_propagate_real_tensors",
)
globals()[cls.__name__] = cls


make_propagate_real_tensors_cls(FakeTensorTest)
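# (Presumably this registers a PropagateRealTensorsFakeTensorTest class in
# module globals: a copy of FakeTensorTest whose tests all run with
# fake_tensor_propagate_real_tensors patched to True.)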


class FakeTensorConstHandling(TestCase):
def assertConst(self, *args):
for arg in args:
@@ -875,6 +904,10 @@ def test_constant_propagate_through_functions(self):
y = torch.div(4, 4, rounding_mode='trunc')
self.assertConst(y)


make_propagate_real_tensors_cls(FakeTensorConstHandling)


def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
return maybe_contained_type.isSubtypeOf(type) or any(
contains_type(e, maybe_contained_type) for e in type.containedTypes()
@@ -891,6 +924,13 @@ def test_fake(self, device, dtype, op):
optests.fake_check(op, args, kwargs)


instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))


# CPU only, for efficiency
make_propagate_real_tensors_cls(FakeTensorOpInfoTestCPU)
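# (FakeTensorOpInfoTestCPU is taken here to be the CPU variant that
# instantiate_device_type_tests generates into module scope above.)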


class FakeTensorConverterTest(TestCase):
def test_memoized_conversion_to_meta(self):
x = torch.rand(2, 2, 2)
@@ -1002,16 +1042,17 @@ def test_no_ref_cycle(self):
assert y_weak() is None


make_propagate_real_tensors_cls(FakeTensorConverterTest)


class FakeTensorOperatorInvariants(TestCase):
-@staticmethod
-def get_aten_op(schema):
+def get_aten_op(self, schema):
namespace, name = schema.name.split("::")
overload = schema.overload_name if schema.overload_name else "default"
assert namespace == "aten"
return getattr(getattr(torch.ops.aten, name), overload)

-@staticmethod
-def get_all_aten_schemas():
+def get_all_aten_schemas(self):
for schema in torch._C._jit_get_all_schemas():
namespace = schema.name.split("::")[0]
if namespace != "aten":
@@ -1220,6 +1261,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
self.assertEqual(mode.count, 0)


make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)


class FakeTensorPropTest(TestCase):
def test_fake_tensor_prop_on_nn_module(self):
class ToyNnModuleWithParameters(torch.nn.Module):
@@ -1305,6 +1349,8 @@ def forward(self, value, another_value=None, another_optional_value=None):
FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)


# TODO: unclear why this one fails under propagate_real_tensors; needs investigation
@expectedFailurePropagateRealTensors
def test_unbacked_shape_realloc(self):
def f(x):
return x.nonzero()
@@ -1352,6 +1398,9 @@ def forward(self, x):
torch.load(state_dict_file, map_location="cpu") # scenario 2


make_propagate_real_tensors_cls(FakeTensorPropTest)


class FakeTensorSerialization(TestCase):
def test_serialization(self):
x = torch.tensor([0], device="cpu")
@@ -1690,11 +1739,5 @@ def test_inference_mode(self):
extract_tensor_metadata(res4),
)


-instantiate_parametrized_tests(FakeTensorTest)
-
-only_for = ("cpu", "cuda")
-instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for)

if __name__ == "__main__":
run_tests()
17 changes: 17 additions & 0 deletions test/test_proxy_tensor.py
@@ -25,6 +25,7 @@
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch import nn
import torch._functorch.config
import re

import functools
@@ -1518,6 +1519,22 @@ def f(a):

make_fx(f, tracing_mode="symbolic")(torch.randn(4))

@torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
def test_invalidate_nonzero_propagate_real_tensors(self):
def f(a):
b = a.clone()
x = b.nonzero()
x1 = b.nonzero()
x2 = b.nonzero()
assert x1.shape[0] == x2.shape[0]
b.normal_()
y = b.nonzero()
# normal_() will almost surely never produce an exact zero, so every
# element is still nonzero and the counts match
assert x1.shape[0] == y.shape[0]

make_fx(f, tracing_mode="symbolic")(torch.randn(4))
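# (The mutation is the point here: nonzero() memoizes the unbacked symbol
# for its output size, and the in-place normal_() should invalidate that
# memo, so y gets a fresh size rather than reusing the one from x1.)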

def test_sqrt_size(self):
def f(a):
return a / a.size(-1) ** 0.5
30 changes: 30 additions & 0 deletions torch/_functorch/config.py
@@ -106,6 +106,36 @@
# tokens.
unlift_effect_tokens = False

# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
# 1. When users call item()/other data dependent operations,
# if we propagate_real_tensors we are able to determine what
# the true value is and keep going.
#
# 2. It can be useful for testing, when you want to see if the fake
# and real tensors agree with each other. (Note that there are
# currently known inaccuracies in how we clone real tensors, that
# would have to be tightened up for this to be useful in this
# case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors. For this reason, we also support
# explicitly deallocating the real tensor associated with a fake
# tensor, at which point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage. This could potentially be
# optimized by using COW. We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc
# of tensors in question.
fake_tensor_propagate_real_tensors = False

if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

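A minimal sketch, not part of this commit, of how the new flag might be
exercised end to end; the behavior of make_fx under this config is an
assumption based on the comment above:

import torch
import torch._functorch.config
from torch.fx.experimental.proxy_tensor import make_fx

@torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
def demo():
    def f(x):
        n = x.item()  # data-dependent: creates an unbacked SymInt
        # Without real tensor propagation this branch would raise
        # GuardOnDataDependentSymNode; with it, the real value of n
        # is consulted and tracing simply follows the taken branch.
        if n == 0:
            return torch.zeros(1)
        return torch.ones(n)

    # Real compute runs alongside the fake tensors, so unbacked sizes
    # get concrete hints from the real input.
    return make_fx(f, tracing_mode="symbolic")(torch.tensor(3))

demo()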
