Change capture_scalar_outputs to use SymInt/SymFloat rather than Tensor to model scalars (#93150)

Previously, Dynamo faked support for item() when `capture_scalar_outputs` was True by representing it internally as a Tensor. With dynamic shapes, this is no longer necessary; we can represent it directly as a SymInt/SymFloat. Do so. Doing this requires you to use dynamic shapes; in principle we could support scalar outputs WITHOUT dynamic shapes but I won't do this unless someone hollers for it.
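
For illustration, a minimal sketch of the behavior this enables (not part of this commit; it mirrors the updated `test_tensor_item_capture` test below):

```python
# Sketch only: with both flags set, Tensor.item() is traced as a SymInt/SymFloat
# instead of graph-breaking or being modeled as a fake Tensor.
import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

torch._dynamo.config.dynamic_shapes = True          # now required for scalar capture
torch._dynamo.config.capture_scalar_outputs = True

def fn(a, b):
    return a + b.item()  # item() becomes a SymInt in the captured graph

cnts = CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn(torch.zeros(3, dtype=torch.int64), torch.tensor(5))
print(cnts.frame_count)  # expect 1, i.e. no graph break on item()
```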

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Differential Revision: [D42885775](https://our.internmc.facebook.com/intern/diff/D42885775)
Pull Request resolved: #93150
Approved by: https://github.com/voznesenskym
ezyang authored and pytorchmergebot committed Jan 31, 2023
1 parent 76b683b commit 902b4db
Showing 8 changed files with 38 additions and 40 deletions.
9 changes: 9 additions & 0 deletions test/dynamo/test_export.py
@@ -320,6 +320,7 @@ def func(x, z, k):

self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_dupes_and_bypass_with_non_tensor_output(self):
inp = torch.tensor([0.1, 0.1])
@@ -366,6 +367,7 @@ def func(a, b, c):

self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_zeroes_in_new_shape_scalar_out(self):
inp = torch.zeros(10)
@@ -390,6 +392,7 @@ def func(a, b, c):

self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_zeroes_in_new_shape_scalar_out_permute(self):
inp = torch.zeros(10)
@@ -414,6 +417,7 @@ def func(a, b, c):

self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
inp = torch.zeros(10)
@@ -771,6 +775,7 @@ def func(x, z, k):

self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
inp = torch.tensor([0.1, 0.1])
@@ -1421,6 +1426,7 @@ def nop(x):
f, (torch.randn(5)), aten_graph=False, tracing_mode="symbolic"
)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_export_with_module_layer(self):
from functorch.experimental.control_flow import cond
@@ -1634,6 +1640,7 @@ def g(x, y):
)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_dynamic_slicing_simple(self):
def f(x):
return x[slice(None, None, None)]
@@ -1645,6 +1652,8 @@ def f(x):
inp = torch.randn(6, 7)
self.assertEqual(gm(inp), f(inp))

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_export_cond_in_aten_symbolic(self):
class ConditionOp(torch.nn.Module):
def __init__(self):
5 changes: 5 additions & 0 deletions test/dynamo/test_misc.py
@@ -448,6 +448,7 @@ def fn(a):
self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8
)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_tensor_item_capture(self):
def fn(a, b):
@@ -462,6 +463,7 @@ def fn(a, b):
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 3)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
def test_tensor_item_no_capture(self):
def fn(a, b):
@@ -2035,6 +2037,7 @@ def f(x, n):
opt_f(x, n)
self.assertEqual(cnts.frame_count, 1)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_item(self):
class MyMod(torch.nn.Module):
@@ -2048,6 +2051,7 @@ def forward(self, x):

self.assertEqual(y, 11)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_item_changes(self):
class MyMod(torch.nn.Module):
@@ -2064,6 +2068,7 @@ def forward(self, x):
self.assertEqual(y, 11)
self.assertEqual(z, 61)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_item_changes_new_shape(self):
class MyMod(torch.nn.Module):
12 changes: 4 additions & 8 deletions test/dynamo/test_repros.py
@@ -29,6 +29,7 @@
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
from torch._dynamo.utils import ifdyn
from torch.nn import functional as F


@@ -42,13 +43,6 @@ def is_fx_tracing_test() -> bool:
return torch.nn.Module.__call__ is not _orig_module_call


def ifdyn(count1, count2):
if torch._dynamo.config.dynamic_shapes:
return count1
else:
return count2


def has_detectron2():
try:
from detectron2.layers.mask_ops import _paste_masks_tensor_shape
@@ -948,6 +942,7 @@ def test_chunk_reformer_ff(self):
# uncomment/adjust the assertEqual below
@unittest.expectedFailure
@patch.object(torch._dynamo.config, "fake_tensor_propagation", True)
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_maml_item_capture(self):
a = torch.randn(5, 1, 28, 28)
@@ -966,6 +961,7 @@ def test_maml_item_capture(self):
self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))

# see: https://github.com/pytorch/pytorch/issues/80067
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
def test_maml_no_item_capture(self):
a = torch.randn(5, 1, 28, 28)
@@ -979,7 +975,7 @@ def test_maml_no_item_capture(self):
for _ in range(10):
self.assertTrue(same(opt_model(a, b, c, d), correct))

self.assertEqual(cnt.frame_count, ifdyn(5, 4))
self.assertEqual(cnt.frame_count, 5)
# TODO(jansel): figure out why op count depends on imports
self.assertIn(cnt.op_count, (31, 36, 35, 34, 29, 28))

2 changes: 2 additions & 0 deletions test/dynamo/test_subgraphs.py
@@ -439,6 +439,7 @@ def fn(a):
self.assertEqual(opt_fn(x), fn(x))
self.assertEqual(cnt_dynamic.frame_count, 2)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_no_graph_break_on_item(self):
def fn(a, b):
@@ -450,6 +451,7 @@ def fn(a, b):

self._common(fn, 1, 6)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
def test_graph_break_on_item(self):
def fn(a, b):
1 change: 1 addition & 0 deletions torch/_dynamo/config.py
@@ -141,6 +141,7 @@

# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
capture_scalar_outputs = False

# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
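
As the new config comment above notes, `capture_scalar_outputs` now requires `dynamic_shapes`. A hypothetical sketch (not part of this diff) of the failure mode when only the former is set, per the assertion added in torch/_dynamo/variables/tensor.py:

```python
# Sketch only: capture_scalar_outputs without dynamic_shapes now trips an assertion.
import torch
import torch._dynamo

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.dynamic_shapes = False  # mismatched configuration

def f(x):
    return x.sum().item() + 1

opt_f = torch._dynamo.optimize("eager")(f)
# Tracing item() here raises (possibly wrapped by Dynamo's error handling):
#   AssertionError: To capture_scalar_outputs, you must also set dynamic_shapes = True
opt_f(torch.ones(4))
```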
7 changes: 7 additions & 0 deletions torch/_dynamo/utils.py
@@ -1278,3 +1278,10 @@ def fqn(obj: Any):
Returns the fully qualified name of the object.
"""
return f"{obj.__module__}.{obj.__qualname__}"


def ifdyn(count1, count2):
if torch._dynamo.config.dynamic_shapes:
return count1
else:
return count2
16 changes: 0 additions & 16 deletions torch/_dynamo/variables/builder.py
@@ -3,8 +3,6 @@
import enum
import functools
import inspect
import math
import numbers
import operator
import re
import types
@@ -90,7 +88,6 @@
from .nn_module import UnspecializedNNModuleVariable
from .tensor import (
DynamicShapeVariable,
FakeItemVariable,
TensorVariable,
TensorWithTFOverrideVariable,
UnspecializedPythonVariable,
@@ -930,19 +927,6 @@ def _clone_input(value):
):
proxy.node.meta["example_value"] = example_value
return ConstantVariable(example_value, **options)
elif (
isinstance(example_value, numbers.Number)
and (proxy.node.target == "item" or proxy.node.target in {math.sqrt, math.pow})
and config.capture_scalar_outputs
):
# item raw value should not be accessed
return wrap_fx_proxy_cls(
FakeItemVariable,
tx=tx,
proxy=proxy,
example_value=torch.tensor(example_value),
**options,
)
elif isinstance(example_value, (torch.SymInt, torch.SymFloat)):
proxy.node.meta["example_value"] = example_value
return DynamicShapeVariable(proxy, example_value, **options)
26 changes: 10 additions & 16 deletions torch/_dynamo/variables/tensor.py
@@ -319,22 +319,16 @@ def call_method(
unimplemented(f"Tensor.{name}")
elif name == "nonzero" and not config.dynamic_shapes:
unimplemented(f"Tensor.{name}")
elif name == "item":
if config.capture_scalar_outputs:
example_value = get_fake_value(self.proxy.node, tx)
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
"item",
(self.as_proxy(),),
{},
),
example_value=example_value,
**options,
)
else:
unimplemented(f"Tensor.{name}")
elif name == "item" and not config.capture_scalar_outputs:
unimplemented(f"Tensor.{name}")
elif (
name == "item"
and config.capture_scalar_outputs
and not config.dynamic_shapes
):
raise AssertionError(
"To capture_scalar_outputs, you must also set dynamic_shapes = True"
)
elif name == "__len__":
return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
elif name == "__setitem__":
