Skip to content

Commit

Permalink
Revert "Support SymBool input to torch.compile (#107850)"
Browse files Browse the repository at this point in the history
This reverts commit 9f6d70b.

Reverted #107850 on behalf of https://github.com/huydhn due to Sorry for reverting this, but test_export_with_symbool_inputs is failing in trunk https://hud.pytorch.org/pytorch/pytorch/commit/a08e1370ef8cb13cfbf18d9663427a57fa8657f2 ([comment](#107850 (comment)))
  • Loading branch information
pytorchmergebot committed Sep 14, 2023
1 parent de76c88 commit 47f79e9
Show file tree
Hide file tree
Showing 9 changed files with 11 additions and 283 deletions.
77 changes: 3 additions & 74 deletions test/dynamo/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,11 @@
from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch._dynamo.exc import UserError
from torch._dynamo.testing import normalize_gm
from torch._export import dynamic_dim
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._higher_order_ops.out_dtype import out_dtype
from torch._subclasses import fake_tensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
DimDynamic,
ShapeEnv,
)
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
from torch.testing._internal import common_utils


Expand Down Expand Up @@ -3161,6 +3155,8 @@ def forward(self, x):

def test_capture_symbolic_tracing_simple_within_fake_mode(self):
from torch._dynamo.output_graph import config
from torch._subclasses import fake_tensor
from torch.fx.experimental.symbolic_shapes import ShapeEnv

def f(x):
y = torch.randn(3)
Expand All @@ -3182,73 +3178,6 @@ def f(x):
+ str(aten_graph),
)

def test_export_with_symbool_inputs(self):
def f(pred: bool, x: torch.Tensor):
if pred:
return x.sin()
else:
return x.cos()

x = torch.randn([3, 4])

def test_symbool_guards(
f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards
):
shape_env = ShapeEnv()
with fake_tensor.FakeTensorMode(
shape_env=shape_env,
) as fake_mode:
fake_x = fake_mode.from_tensor(
x, dynamic_dims=[DimDynamic.DYNAMIC for _ in range(x.dim())]
)
for i, size in enumerate(size_tests):
pred = fake_x.size(0) == size
gm, guards = torch._dynamo.export(f)(pred, x)
actual = normalize_gm(gm.print_readable(print_output=False))
self.assertExpectedInline(actual, exp_graph[i])
dynamo_shape_env_guards = [
guard for guard in guards if "SHAPE_ENV" in guard.guard_types
]
self.assertEqual(len(dynamo_shape_env_guards), 1)
guard_code_on_predicate = [
code
for code in dynamo_shape_env_guards[0].code_list
if "L['pred']" in code
]
self.assertEqual(guard_code_on_predicate, exp_guard_code[i])
outter_shape_env_guards = [
str(guard.expr) for guard in shape_env.guards
]
self.assertEqual(outter_shape_env_guards, exp_shape_env_guards[i])

true_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg0, arg1: f32[s1, s2], = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
sin = arg1.sin(); arg1 = None
return pytree.tree_unflatten([sin], self._out_spec)
"""
false_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg0, arg1: f32[s1, s2], = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
cos = arg1.cos(); arg1 = None
return pytree.tree_unflatten([cos], self._out_spec)
"""
true_guard_code = ["cast_symbool_to_symint_guardless(L['pred']) == 1"]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
]
test_symbool_guards(
f,
[3, 3, 4, 5],
[true_graph, true_graph, false_graph, false_graph],
[true_guard_code, true_guard_code, false_guard_code, false_guard_code],
# Outter shape env should have no guards in it because we never specialize on the outter symbool.
[[], [], [], []],
)

def test_invalid_input_global(self) -> None:
global bulbous_bouffant
bulbous_bouffant = torch.randn(3)
Expand Down
90 changes: 0 additions & 90 deletions test/dynamo/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,96 +436,6 @@ def fn(x):
self.assertEqual(lower_bound_str, expected_lower_bound)
self.assertEqual(upper_bound_str, expected_upper_bound)

def test_recompile_with_symbool_inputs(self):
def f(pred: bool):
if pred:
return torch.ones([3, 4])
else:
return torch.ones([4, 3])

def test_recompilation(
f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
):
torch._dynamo.reset()
shape_env = ShapeEnv()
backend = torch._dynamo.testing.EagerAndRecordGraphs()
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
f_cond = torch.compile(f, backend=cnt, fullgraph=True)
with torch._subclasses.fake_tensor.FakeTensorMode(
shape_env=shape_env
) as fake_mode:
fake_inp = fake_mode.from_tensor(
x, dynamic_dims=[DimDynamic.DYNAMIC for i in range(x.dim())]
)
for i, size in enumerate(sizes):
pred = fake_inp.size(0) == size
f_cond(pred)
actual = normalize_gm(
backend.graphs[exp_frame_count[i] - 1].print_readable(
print_output=False
)
)
actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
self.assertExpectedInline(actual, exp_graphs[i])
self.assertEqual(cnt.frame_count, exp_frame_count[i])
self.assertEqual(actual_guard_str, exp_shape_env_guards[i])

true_graph = """\
class GraphModule(torch.nn.Module):
def forward(self):
ones = torch.ones([3, 4])
return (ones,)
"""
false_graph = """\
class GraphModule(torch.nn.Module):
def forward(self):
ones = torch.ones([4, 3])
return (ones,)
"""
test_recompilation(
f,
torch.randn([3, 4]),
[3, 3, 4, 5],
exp_graphs=[true_graph, true_graph, false_graph, false_graph],
exp_frame_count=[1, 1, 2, 2],
exp_shape_env_guards=[
[],
# s0 is specialized and guarded in outter shape_env when dynamo checks the guards
["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
[
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
"Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
],
[
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
"Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
],
],
)

test_recompilation(
f,
torch.randn([3, 4]),
[4, 5, 3, 3],
exp_graphs=[false_graph, false_graph, true_graph, true_graph],
exp_frame_count=[1, 1, 2, 2],
exp_shape_env_guards=[
[],
# s0 is specialized and guarded in outter shape_env when dynamo checks the guards
["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
[
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
],
[
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
],
],
)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
21 changes: 3 additions & 18 deletions test/functorch/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def f(x, pred, pred2):
graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False), torch.tensor(False))
self.assertEqual(graph(x, torch.tensor(True), torch.tensor(True)), f(x, torch.tensor(True), torch.tensor(True)))

@unittest.expectedFailure
def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()
Expand Down Expand Up @@ -312,6 +313,7 @@ def f(x):
gm_functional = make_fx(torch.func.functionalize(gm_non_functional), tracing_mode="real")(inp)
self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2)))

@unittest.expectedFailure
def test_cond_functionalized_nested(self):
def true_true_fn(x):
y = x.cos()
Expand Down Expand Up @@ -1214,6 +1216,7 @@ def main(p, pred, xs, y):
self.assertEqual(res, main(p, pred, xs, y))
self.check_map_count(gm, 2)

@unittest.expectedFailure
def test_cond_with_sym_pred(self):
def true_fn(x):
return x + x
Expand All @@ -1225,28 +1228,10 @@ def foo(x):
return cond(x.shape[0] == 4, true_fn, false_fn, [x])

gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
# The symbols in make_fx's shape_env should not be speciliazed.
self.assertEqual(len(gm.shape_env.guards), 0)

exp_code = """\
def forward(self, x_1):
sym_size = torch.ops.aten.sym_size(x_1, 0)
eq = sym_size == 4; sym_size = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); \
eq = true_graph_0 = false_graph_0 = x_1 = None
return conditional
"""
self._expected_inline_normalized(gm.code, exp_code)


# We expect the traced graph module to work even if input size changes.
x = torch.ones(4, 3, 2)
self.assertEqual(gm(x), true_fn(x))
self.assertEqual(foo(x), true_fn(x))


def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
assert isinstance(args, (tuple, list))
self.assertEqual(f(*args), exp_res)
Expand Down
15 changes: 0 additions & 15 deletions torch/_dynamo/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,21 +270,6 @@ def name(self):
return f"{self.base.name()}.__neg__()"


@dataclasses.dataclass(frozen=True)
class ConvertIntSource(ChainedSource):
def __post_init__(self):
assert self.base is not None

def reconstruct(self, codegen):
return self.base.reconstruct(codegen)

def guard_source(self):
return self.base.guard_source()

def name(self):
return f"cast_symbool_to_symint_guardless({self.base.name()})"


@dataclasses.dataclass(frozen=True)
class DefaultsSource(ChainedSource):
idx_key: Union[int, str]
Expand Down
41 changes: 0 additions & 41 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from ..source import (
AttrSource,
ConstantSource,
ConvertIntSource,
GetItemSource,
GlobalWeakRefSource,
is_constant_source,
Expand Down Expand Up @@ -714,46 +713,6 @@ def index_source(key):
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif isinstance(value, torch.SymBool):
# Note: the idea here is to re-use the infra we've built for SymInt by simulating the
# user provided SymBool with a SymInt in dynamo.

# Concretely,
# 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source).
# so that guards on the SymInts can be effectively applied on the original SymBool in user program.
# 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program
# depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.

value_hint = value.node.require_hint()
new_source = ConvertIntSource(self.source)

new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol(
int(value_hint),
new_source,
dynamic_dim=DimDynamic.DYNAMIC,
)

sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(new_symint),
source=new_source,
)

sym_node_proxy.node.meta["grapharg"] = GraphArg(
new_source,
new_symint,
False,
None,
is_tensor=False,
example_strong_ref=new_symint,
)
self.tx.output.tracked_fakes.append(
TrackedFake(new_symint, new_source, None)
)
return SymNodeVariable(
sym_node_proxy,
new_symint == 1,
)
else:
result = UserDefinedObjectVariable(
value,
Expand Down
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def false_fn(x: torch.Tensor):
return cond_op(pred, true_fn, false_fn, operands)

def _validate_input(pred, true_fn, false_fn, operands):
if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
if not isinstance(pred, (bool, torch.Tensor)):
raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")

if isinstance(pred, torch.Tensor) and pred.numel() != 1:
Expand Down
18 changes: 2 additions & 16 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,10 +1286,6 @@ def method_to_operator(method):
op = getattr(operator, method_attr)
return op

def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
int_sym = sympy.Piecewise((1, symbool.node.expr), (0, True))
return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()))

SYMPY_INTERP = {
'Eq': operator.eq,
'Ne': operator.ne,
Expand All @@ -1305,7 +1301,6 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
'floor': math.floor,
'ceiling': math.ceil,
'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
}

always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"}
Expand Down Expand Up @@ -2728,10 +2723,7 @@ def create_unspecified_symbol(
) -> "sympy.Expr":
# 'positive' is None for unspecified symbols, since we can't
# assume that it will be neither positive nor negative.

# We don't want to specialize zero one val for unspecified symbol
# so that we can always get a new symbol despite val.
return self.create_symbol(val, source, dynamic_dim, constraint_dim, positive=None, do_not_specialize_zero_one=True)
return self.create_symbol(val, source, dynamic_dim, constraint_dim, positive=None)

@record_shapeenv_event()
def create_symbol(
Expand All @@ -2741,13 +2733,7 @@ def create_symbol(
dynamic_dim: DimDynamic = DimDynamic.DUCK,
constraint_dim: DimConstraint = None, # NB: includes None
positive: Optional[bool] = True,
do_not_specialize_zero_one: bool = False,
) -> "sympy.Expr":
if do_not_specialize_zero_one:
specialize_zero_one = False
else:
specialize_zero_one = self.specialize_zero_one

assert isinstance(source, Source), f"{type(source)} {source}"
assert not (positive and val < 0), f"positive set for negative value: {val}"
# It's always sound to allocate a symbol as DYNAMIC. If the user
Expand All @@ -2767,7 +2753,7 @@ def create_symbol(
else:
raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")

if val in (0, 1) and specialize_zero_one:
if val in (0, 1) and self.specialize_zero_one:
r = self.val_to_var[val]
elif not duck or val not in self.val_to_var:
# If we're not duck shaping, we always create a new symbol
Expand Down

0 comments on commit 47f79e9

Please sign in to comment.