Skip to content

Commit

Permalink
Add propagate_real_tensors mode for unbacked
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 1f97e037a55711b1db7ce2f805d6b24be3ef8f5c
Pull Request resolved: #125115
  • Loading branch information
ezyang committed Apr 29, 2024
1 parent 06b845d commit e08e813
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 44 deletions.
11 changes: 11 additions & 0 deletions benchmarks/dynamo/common.py
Expand Up @@ -75,6 +75,7 @@
graph_break_reasons,
maybe_enable_compiled_autograd,
)
import torch._functorch.config
from torch._functorch.aot_autograd import set_model_name
from torch._inductor import config as inductor_config, metrics
from torch._subclasses.fake_tensor import FakeTensorMode
Expand Down Expand Up @@ -3155,6 +3156,11 @@ def get_example_inputs(self):
action="store_true",
help="Runs a dynamic shapes version of the benchmark, if available.",
)
parser.add_argument(
"--propagate-real-tensors",
action="store_true",
help="Capture as much data dependent as you can by unsoundly propagating real tensors",
)
parser.add_argument(
"--dynamic-batch-only",
action="store_true",
Expand Down Expand Up @@ -3603,6 +3609,11 @@ def run(runner, args, original_dir=None):
if args.dynamic_shapes:
if not args.dynamic_batch_only:
torch._dynamo.config.assume_static_by_default = False
if args.propagate_real_tensors:
# TODO: Separate flag for data dependent
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._functorch.config.fake_tensor_propagate_real_tensors = True
if args.specialize_int:
torch._dynamo.config.specialize_int = True
if args.ci:
Expand Down
17 changes: 17 additions & 0 deletions test/dynamo/test_misc.py
Expand Up @@ -10518,6 +10518,23 @@ def fn(x, d):
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
fn(torch.randn(4), d)

@unittest.skipIf(not TEST_CUDA, "requires cuda")
@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
)
@torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
def test_interpolate_propagate_real_tensors(self):
@torch.compile(backend="eager", fullgraph=True)
def f(mask, box):
# u0, u1 = mask.tolist()
mask = torch.randn(1, 1, 30, 30, device="cuda")
h, w = box.tolist()
return torch.nn.functional.interpolate(
mask, (h, w), mode="bilinear", align_corners=False
)

f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda"))

def test_custom_iter_dict(self):
class ReversedDict(dict):
def __iter__(self):
Expand Down
6 changes: 6 additions & 0 deletions test/test_dynamic_shapes.py
Expand Up @@ -512,6 +512,12 @@ def test_data_dependent_guard(self):
s0 = shape_env.create_unbacked_symint()
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

def test_data_dependent_guard_propagate_real_tensors(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
shape_env.unbacked_var_to_val[s0.node.expr] = 0
self.assertEqual(bool(s0 == 0), True)

def test_expect_true_basic(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
Expand Down
75 changes: 61 additions & 14 deletions test/test_fake_tensor.py
Expand Up @@ -6,6 +6,7 @@
instantiate_parametrized_tests, TemporaryFileName)
import torch
import torch._dynamo
from torch._dynamo.testing import make_test_cls_with_patches
import itertools
import numpy as np
from torch.testing._internal.jit_utils import RUN_CUDA
Expand Down Expand Up @@ -53,6 +54,10 @@
torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True

def expectedFailurePropagateRealTensors(fn):
fn._expected_failure_propagate_real_tensors = True
return fn

class FakeTensorTest(TestCase):
def checkType(self, t, device_str, size):
self.assertTrue(isinstance(t, FakeTensor))
Expand Down Expand Up @@ -207,6 +212,8 @@ def test_fake_dispatch_keys(self):
FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

# TODO: functorch support for propagate real tensors
@expectedFailurePropagateRealTensors
def test_batch_tensor(self):
x = torch.rand((3, 4, 5))
b = _add_batch_dim(x, 0, 0)
Expand Down Expand Up @@ -392,10 +399,10 @@ def test_out_multi_device(self):
x = torch.rand([4])
y = torch.rand([4], device="cuda")

with self.assertRaisesRegex(Exception, "found two different devices"):
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
torch.sin(x, out=y)

with self.assertRaisesRegex(Exception, "found two different devices"):
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
x.add_(y)


Expand Down Expand Up @@ -578,6 +585,9 @@ def test_same_shape_env_preserved(self):
self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
self.assertEqual(str(t2.size(0)), str(t1.size(0)))

# TODO: Support NJT. There's also some funny business with dynamic shapes
# which would need to be dealt with as well
@expectedFailurePropagateRealTensors
def test_jagged_fake_to_fake_preserved(self):
from torch.nested._internal.nested_tensor import jagged_from_list

Expand Down Expand Up @@ -736,7 +746,9 @@ def test_aten_index_multi_device(self):
x2 = torch.rand(4, 4, device="cuda")
i1 = torch.tensor([0, 1], device="cuda")
i2 = torch.tensor([0, 1], device="cpu")
r1 = torch.ops.aten.index(x1, i1)
# NB: This one does not work: cuda indices not allowed on cpu
# tensor
# r1 = torch.ops.aten.index(x1, i1)
r2 = torch.ops.aten.index(x2, i2)

y1 = torch.rand(4, device="cpu")
Expand All @@ -745,7 +757,7 @@ def test_aten_index_multi_device(self):
j2 = torch.tensor([2], device="cpu")
r3 = torch.ops.aten.index_put.default(x1, j1, y1)
r4 = torch.ops.aten.index_put.default(x2, j2, y2)
self.checkType(r1, "cpu", ())
# self.checkType(r1, "cpu", ())
self.checkType(r2, "cuda", ())
self.checkType(r3, "cpu", (4, 4))
self.checkType(r4, "cuda", (4, 4))
Expand Down Expand Up @@ -785,6 +797,23 @@ def forward(self, input):
self.assertTrue(isinstance(ep, torch.export.ExportedProgram))


instantiate_parametrized_tests(FakeTensorTest)


def make_propagate_real_tensors_cls(cls):
cls = make_test_cls_with_patches(
cls,
"PropagateRealTensors",
"_propagate_real_tensors",
(torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
xfail_prop="_expected_failure_propagate_real_tensors",
)
globals()[cls.__name__] = cls


make_propagate_real_tensors_cls(FakeTensorTest)


class FakeTensorConstHandling(TestCase):
def assertConst(self, *args):
for arg in args:
Expand Down Expand Up @@ -875,6 +904,10 @@ def test_constant_propagate_through_functions(self):
y = torch.div(4, 4, rounding_mode='trunc')
self.assertConst(y)


make_propagate_real_tensors_cls(FakeTensorConstHandling)


def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
return maybe_contained_type.isSubtypeOf(type) or any(
contains_type(e, maybe_contained_type) for e in type.containedTypes()
Expand All @@ -891,6 +924,13 @@ def test_fake(self, device, dtype, op):
optests.fake_check(op, args, kwargs)


instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))


# CPU only for efficiency ig
make_propagate_real_tensors_cls(FakeTensorOpInfoTestCPU) # noqa: F821


class FakeTensorConverterTest(TestCase):
def test_memoized_conversion_to_meta(self):
x = torch.rand(2, 2, 2)
Expand Down Expand Up @@ -1002,16 +1042,17 @@ def test_no_ref_cycle(self):
assert y_weak() is None


make_propagate_real_tensors_cls(FakeTensorConverterTest)


class FakeTensorOperatorInvariants(TestCase):
@staticmethod
def get_aten_op(schema):
def get_aten_op(self, schema):
namespace, name = schema.name.split("::")
overload = schema.overload_name if schema.overload_name else "default"
assert namespace == "aten"
return getattr(getattr(torch.ops.aten, name), overload)

@staticmethod
def get_all_aten_schemas():
def get_all_aten_schemas(self):
for schema in torch._C._jit_get_all_schemas():
namespace = schema.name.split("::")[0]
if namespace != "aten":
Expand Down Expand Up @@ -1162,6 +1203,10 @@ def forward(self, arg1, arg2, arg3):

# IMPORTANT!!! Always run even if CUDA is not available
def test_fake_cuda_no_init(self):
# Skip this test, we will try to run CUDA operations to real prop so
# it clearly will not work on CPU runner
if torch._functorch.config.fake_tensor_propagate_real_tensors:
return
with FakeTensorMode():
torch.empty(10, device='cuda')
torch.ones(10, device='cuda')
Expand Down Expand Up @@ -1220,6 +1265,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
self.assertEqual(mode.count, 0)


make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)


class FakeTensorPropTest(TestCase):
def test_fake_tensor_prop_on_nn_module(self):
class ToyNnModuleWithParameters(torch.nn.Module):
Expand Down Expand Up @@ -1305,9 +1353,11 @@ def forward(self, value, another_value=None, another_optional_value=None):
FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)


@expectedFailurePropagateRealTensors # TODO: not sure about this one, kinda strange
def test_unbacked_shape_realloc(self):
def f(x):
return x.nonzero()

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
with fake_mode:
Expand Down Expand Up @@ -1352,6 +1402,9 @@ def forward(self, x):
torch.load(state_dict_file, map_location="cpu") # scenario 2


make_propagate_real_tensors_cls(FakeTensorPropTest)


class FakeTensorSerialization(TestCase):
def test_serialization(self):
x = torch.tensor([0], device="cpu")
Expand Down Expand Up @@ -1690,11 +1743,5 @@ def test_inference_mode(self):
extract_tensor_metadata(res4),
)


instantiate_parametrized_tests(FakeTensorTest)

only_for = ("cpu", "cuda")
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for)

if __name__ == "__main__":
run_tests()
17 changes: 17 additions & 0 deletions test/test_proxy_tensor.py
Expand Up @@ -25,6 +25,7 @@
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch import nn
import torch._functorch.config
import re

import functools
Expand Down Expand Up @@ -1518,6 +1519,22 @@ def f(a):

make_fx(f, tracing_mode="symbolic")(torch.randn(4))

@torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
def test_invalidate_nonzero_propagate_real_tensors(self):
def f(a):
b = a.clone()
x = b.nonzero()
x1 = b.nonzero()
x2 = b.nonzero()
assert x1.shape[0] == x2.shape[0]
b.normal_()
y = b.nonzero()
# Because you're not actually going to generate exactly zero with
# normal_ lol
assert x1.shape[0] == y.shape[0]

make_fx(f, tracing_mode="symbolic")(torch.randn(4))

def test_sqrt_size(self):
def f(a):
return a / a.size(-1) ** 0.5
Expand Down
7 changes: 3 additions & 4 deletions torch/_dynamo/variables/tensor.py
Expand Up @@ -648,7 +648,6 @@ def wrap(i, sub_proxy):
return SymNodeVariable.create(
tx,
sub_proxy.item(),
sym_num=tx.output.shape_env.create_unbacked_symint(),
)

if tensor.dtype not in [
Expand Down Expand Up @@ -963,11 +962,11 @@ class SymNodeVariable(VariableTracker):
}

@classmethod
def create(cls, tx, proxy, sym_num, **options):
if "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == sym_num
def create(cls, tx, proxy, sym_num=None, **options):
if sym_num is None:
sym_num = get_fake_value(proxy.node, tx)
if "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == sym_num
set_example_value(proxy.node, sym_num)

if isinstance(sym_num, (sympy.Integer, int, bool)):
Expand Down
30 changes: 30 additions & 0 deletions torch/_functorch/config.py
Expand Up @@ -106,6 +106,36 @@
# tokens.
unlift_effect_tokens = False

# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
# 1. When users call item()/other data dependent operations,
# if we propagate_real_tensors we are able to determine what
# the true value is and keep going.
#
# 2. It can be useful for testing, when you want to see if the fake
# and real tensors agree with each other. (Note that there are
# currently known inaccuracies in how we clone real tensors, that
# would have to be tightened up for this to be useful in this
# case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors. So we also support you explicitly
# deallocating the real tensor associated with a fake tensor, at which
# point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage. This could potentially be
# optimized by using COW. We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc
# of tensors in question.
fake_tensor_propagate_real_tensors = False

if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

Expand Down

0 comments on commit e08e813

Please sign in to comment.