Skip to content

Commit

Permalink
Add propagate_real_tensors mode for unbacked
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: d90e89f0eabad52bb912c2a511b0e854e7809b9c
Pull Request resolved: #125115
  • Loading branch information
ezyang committed May 1, 2024
1 parent c1a3fcf commit 12eb77b
Show file tree
Hide file tree
Showing 11 changed files with 367 additions and 61 deletions.
11 changes: 11 additions & 0 deletions benchmarks/dynamo/common.py
Expand Up @@ -75,6 +75,7 @@
graph_break_reasons,
maybe_enable_compiled_autograd,
)
import torch._functorch.config
from torch._functorch.aot_autograd import set_model_name
from torch._inductor import config as inductor_config, metrics
from torch._subclasses.fake_tensor import FakeTensorMode
Expand Down Expand Up @@ -3155,6 +3156,11 @@ def get_example_inputs(self):
action="store_true",
help="Runs a dynamic shapes version of the benchmark, if available.",
)
parser.add_argument(
"--propagate-real-tensors",
action="store_true",
help="Capture as much data dependent as you can by unsoundly propagating real tensors",
)
parser.add_argument(
"--dynamic-batch-only",
action="store_true",
Expand Down Expand Up @@ -3603,6 +3609,11 @@ def run(runner, args, original_dir=None):
if args.dynamic_shapes:
if not args.dynamic_batch_only:
torch._dynamo.config.assume_static_by_default = False
if args.propagate_real_tensors:
# TODO: Separate flag for data dependent
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._functorch.config.fake_tensor_propagate_real_tensors = True
if args.specialize_int:
torch._dynamo.config.specialize_int = True
if args.ci:
Expand Down
17 changes: 17 additions & 0 deletions test/dynamo/test_misc.py
Expand Up @@ -10516,6 +10516,23 @@ def fn(x, d):
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
fn(torch.randn(4), d)

@unittest.skipIf(not TEST_CUDA, "requires cuda")
@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
)
@torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
def test_interpolate_propagate_real_tensors(self):
@torch.compile(backend="eager", fullgraph=True)
def f(mask, box):
# u0, u1 = mask.tolist()
mask = torch.randn(1, 1, 30, 30, device="cuda")
h, w = box.tolist()
return torch.nn.functional.interpolate(
mask, (h, w), mode="bilinear", align_corners=False
)

f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda"))

def test_custom_iter_dict(self):
class ReversedDict(dict):
def __iter__(self):
Expand Down
6 changes: 6 additions & 0 deletions test/test_dynamic_shapes.py
Expand Up @@ -512,6 +512,12 @@ def test_data_dependent_guard(self):
s0 = shape_env.create_unbacked_symint()
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

def test_data_dependent_guard_propagate_real_tensors(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
self.assertEqual(bool(s0 == 0), True)

def test_expect_true_basic(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
Expand Down
106 changes: 82 additions & 24 deletions test/test_fake_tensor.py
Expand Up @@ -6,6 +6,7 @@
instantiate_parametrized_tests, TemporaryFileName)
import torch
import torch._dynamo
from torch._dynamo.testing import make_test_cls_with_patches
import itertools
import numpy as np
from torch.testing._internal.jit_utils import RUN_CUDA
Expand Down Expand Up @@ -53,6 +54,10 @@
torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True

def expectedFailurePropagateRealTensors(fn):
fn._expected_failure_propagate_real_tensors = True
return fn

class FakeTensorTest(TestCase):
def checkType(self, t, device_str, size):
self.assertTrue(isinstance(t, FakeTensor))
Expand Down Expand Up @@ -83,18 +88,22 @@ def test_basic(self):
def test_custom_op_fallback(self):
from torch.library import Library, impl

test_lib = Library("my_test_op", "DEF") # noqa: TOR901
test_lib.define('foo(Tensor self) -> Tensor')
try:
test_lib = Library("my_test_op", "DEF") # noqa: TOR901
test_lib.define('foo(Tensor self) -> Tensor')

@impl(test_lib, 'foo', 'CPU')
def foo_impl(self):
return self.cos()
@impl(test_lib, 'foo', 'CPU')
def foo_impl(self):
return self.cos()

x = torch.empty(2, 2, device="cpu")
with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
with FakeTensorMode(allow_fallback_kernels=True) as mode:
x = mode.from_tensor(x)
torch.ops.my_test_op.foo(x)
x = torch.empty(2, 2, device="cpu")
with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
with FakeTensorMode(allow_fallback_kernels=True) as mode:
x = mode.from_tensor(x)
torch.ops.my_test_op.foo(x)

finally:
test_lib._destroy()

def test_parameter_instantiation(self):
with FakeTensorMode():
Expand Down Expand Up @@ -207,6 +216,8 @@ def test_fake_dispatch_keys(self):
FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

# TODO: functorch support for propagate real tensors
@expectedFailurePropagateRealTensors
def test_batch_tensor(self):
x = torch.rand((3, 4, 5))
b = _add_batch_dim(x, 0, 0)
Expand Down Expand Up @@ -392,10 +403,10 @@ def test_out_multi_device(self):
x = torch.rand([4])
y = torch.rand([4], device="cuda")

with self.assertRaisesRegex(Exception, "found two different devices"):
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
torch.sin(x, out=y)

with self.assertRaisesRegex(Exception, "found two different devices"):
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
x.add_(y)


Expand Down Expand Up @@ -559,6 +570,8 @@ def test_tolist(self):
x = torch.rand([10])
x.tolist()

# Propagate real tensors doesn't work with fake-on-fake
@expectedFailurePropagateRealTensors
def test_same_shape_env_preserved(self):
shape_env = ShapeEnv()
mode1 = FakeTensorMode(shape_env=shape_env)
Expand All @@ -578,6 +591,9 @@ def test_same_shape_env_preserved(self):
self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
self.assertEqual(str(t2.size(0)), str(t1.size(0)))

# TODO: Support NJT. There's also some funny business with dynamic shapes
# which would need to be dealt with as well
@expectedFailurePropagateRealTensors
def test_jagged_fake_to_fake_preserved(self):
from torch.nested._internal.nested_tensor import jagged_from_list

Expand Down Expand Up @@ -736,7 +752,9 @@ def test_aten_index_multi_device(self):
x2 = torch.rand(4, 4, device="cuda")
i1 = torch.tensor([0, 1], device="cuda")
i2 = torch.tensor([0, 1], device="cpu")
r1 = torch.ops.aten.index(x1, i1)
# NB: This one does not work: cuda indices not allowed on cpu
# tensor
# r1 = torch.ops.aten.index(x1, i1)
r2 = torch.ops.aten.index(x2, i2)

y1 = torch.rand(4, device="cpu")
Expand All @@ -745,7 +763,7 @@ def test_aten_index_multi_device(self):
j2 = torch.tensor([2], device="cpu")
r3 = torch.ops.aten.index_put.default(x1, j1, y1)
r4 = torch.ops.aten.index_put.default(x2, j2, y2)
self.checkType(r1, "cpu", ())
# self.checkType(r1, "cpu", ())
self.checkType(r2, "cuda", ())
self.checkType(r3, "cpu", (4, 4))
self.checkType(r4, "cuda", (4, 4))
Expand Down Expand Up @@ -774,6 +792,9 @@ def test__adaptive_avg_pool2d_backward(self):
grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)

# Propagate real tensors doesn't work when original input arguments are
# fake
@expectedFailurePropagateRealTensors
def test_export_numpy(self):
class MyNumpyModel(torch.nn.Module):
def forward(self, input):
Expand Down Expand Up @@ -801,6 +822,26 @@ def f(x):
self.assertEqual(r.size(), [3])


instantiate_parametrized_tests(FakeTensorTest)


def make_propagate_real_tensors_cls(cls):
cls = make_test_cls_with_patches(
cls,
"PropagateRealTensors",
"_propagate_real_tensors",
(torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
xfail_prop="_expected_failure_propagate_real_tensors",
decorator=skipIfTorchDynamo("propagate_real_tensors affects Dynamo"),
)
cls.__file__ = __file__
cls.__module__ = __name__
globals()[cls.__name__] = cls


make_propagate_real_tensors_cls(FakeTensorTest)


class FakeTensorConstHandling(TestCase):
def assertConst(self, *args):
for arg in args:
Expand Down Expand Up @@ -891,6 +932,10 @@ def test_constant_propagate_through_functions(self):
y = torch.div(4, 4, rounding_mode='trunc')
self.assertConst(y)


make_propagate_real_tensors_cls(FakeTensorConstHandling)


def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
return maybe_contained_type.isSubtypeOf(type) or any(
contains_type(e, maybe_contained_type) for e in type.containedTypes()
Expand All @@ -907,6 +952,11 @@ def test_fake(self, device, dtype, op):
optests.fake_check(op, args, kwargs)


make_propagate_real_tensors_cls(FakeTensorOpInfoTest)
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))
instantiate_device_type_tests(PropagateRealTensorsFakeTensorOpInfoTest, globals(), only_for=("cpu",)) # noqa: F821


class FakeTensorConverterTest(TestCase):
def test_memoized_conversion_to_meta(self):
x = torch.rand(2, 2, 2)
Expand Down Expand Up @@ -1018,16 +1068,17 @@ def test_no_ref_cycle(self):
assert y_weak() is None


make_propagate_real_tensors_cls(FakeTensorConverterTest)


class FakeTensorOperatorInvariants(TestCase):
@staticmethod
def get_aten_op(schema):
def get_aten_op(self, schema):
namespace, name = schema.name.split("::")
overload = schema.overload_name if schema.overload_name else "default"
assert namespace == "aten"
return getattr(getattr(torch.ops.aten, name), overload)

@staticmethod
def get_all_aten_schemas():
def get_all_aten_schemas(self):
for schema in torch._C._jit_get_all_schemas():
namespace = schema.name.split("::")[0]
if namespace != "aten":
Expand Down Expand Up @@ -1178,6 +1229,10 @@ def forward(self, arg1, arg2, arg3):

# IMPORTANT!!! Always run even if CUDA is not available
def test_fake_cuda_no_init(self):
# Skip this test, we will try to run CUDA operations to real prop so
# it clearly will not work on CPU runner
if torch._functorch.config.fake_tensor_propagate_real_tensors:
return
with FakeTensorMode():
torch.empty(10, device='cuda')
torch.ones(10, device='cuda')
Expand Down Expand Up @@ -1236,6 +1291,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
self.assertEqual(mode.count, 0)


make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)


class FakeTensorPropTest(TestCase):
def test_fake_tensor_prop_on_nn_module(self):
class ToyNnModuleWithParameters(torch.nn.Module):
Expand Down Expand Up @@ -1294,6 +1352,7 @@ def to_fake_tensor(x):
self.assertTrue(failed)


@expectedFailurePropagateRealTensors # Propagate real tensors doesn't work with fake-on-fake
def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
class OptionalArgumentInBetween(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -1321,9 +1380,11 @@ def forward(self, value, another_value=None, another_optional_value=None):
FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)


@expectedFailurePropagateRealTensors # TODO: not sure about this one, kinda strange
def test_unbacked_shape_realloc(self):
def f(x):
return x.nonzero()

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
with fake_mode:
Expand Down Expand Up @@ -1368,6 +1429,9 @@ def forward(self, x):
torch.load(state_dict_file, map_location="cpu") # scenario 2


make_propagate_real_tensors_cls(FakeTensorPropTest)


class FakeTensorSerialization(TestCase):
def test_serialization(self):
x = torch.tensor([0], device="cpu")
Expand Down Expand Up @@ -1706,11 +1770,5 @@ def test_inference_mode(self):
extract_tensor_metadata(res4),
)


instantiate_parametrized_tests(FakeTensorTest)

only_for = ("cpu", "cuda")
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for)

if __name__ == "__main__":
run_tests()
17 changes: 17 additions & 0 deletions test/test_proxy_tensor.py
Expand Up @@ -26,6 +26,7 @@
from torch.utils._pytree import tree_map
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch import nn
import torch._functorch.config
import re

import functools
Expand Down Expand Up @@ -1543,6 +1544,22 @@ def f(a):

make_fx(f, tracing_mode="symbolic")(torch.randn(4))

@torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
def test_invalidate_nonzero_propagate_real_tensors(self):
def f(a):
b = a.clone()
x = b.nonzero()
x1 = b.nonzero()
x2 = b.nonzero()
assert x1.shape[0] == x2.shape[0]
b.normal_()
y = b.nonzero()
# Because you're not actually going to generate exactly zero with
# normal_ lol
assert x1.shape[0] == y.shape[0]

make_fx(f, tracing_mode="symbolic")(torch.randn(4))

def test_sqrt_size(self):
def f(a):
return a / a.size(-1) ** 0.5
Expand Down
6 changes: 4 additions & 2 deletions torch/_dynamo/testing.py
Expand Up @@ -311,7 +311,9 @@ def _fn(*args, **kwargs):
return _fn


def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=None):
def make_test_cls_with_patches(
cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x
):
DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
DummyTestClass.__qualname__ = DummyTestClass.__name__

Expand All @@ -326,7 +328,7 @@ def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=
new_fn.__name__ = new_name
if xfail_prop is not None and hasattr(fn, xfail_prop):
new_fn = unittest.expectedFailure(new_fn)
setattr(DummyTestClass, new_name, new_fn)
setattr(DummyTestClass, new_name, decorator(new_fn))
# NB: Doesn't handle slots correctly, but whatever
elif not hasattr(DummyTestClass, name):
setattr(DummyTestClass, name, getattr(cls, name))
Expand Down

0 comments on commit 12eb77b

Please sign in to comment.