From 735169f9498c20c30b9906e16f8d20298dc736e0 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Fri, 11 Nov 2022 04:25:11 +0000 Subject: [PATCH] Take input striding for conv forward based on eager output (#88706) From discussion with @Chillee and @ngimel we'll likely need further fixes to ensure that we hit channels last kernels but this is still worth landing in its own right. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88706 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 26 +++++++++++ torch/_inductor/ir.py | 72 +++++++++++++++++------------ 2 files changed, 69 insertions(+), 29 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 121f3d31f39c..aea8013bdfac 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4601,6 +4601,8 @@ def fn(a): CommonTemplate.install(CudaTests, "cuda") class CudaReproTests(TestCase): + common = check_model_cuda + def test_index_put_issue(self): def forward( self, @@ -4637,6 +4639,30 @@ def forward( compiled = compile_fx_inner(mod, inps) compiled(inps) + @requires_cuda() + def test_input_channels_last(self): + m = torch.nn.Sequential( + torch.nn.Conv2d(3, 3, 1, 1), + ToTuple(), + ).cuda() + inp = ( + torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda() + ) + + self.common( + m, + (inp,), + check_lowp=False, + ) + + @torch._dynamo.optimize() + def foo(m, inp): + return m(inp) + + self.assertTrue( + foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last) + ) + # https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527 @requires_cuda() def test_unspec_inputs_interop(self): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 448c057ecb0e..240c196a73b6 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -19,7 +19,12 @@ import torch.fx import torch.utils._pytree as pytree -from torch._prims_common import is_boolean_dtype, is_float_dtype +from torch._prims_common import ( + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + make_contiguous_strides_for, +) from torch._subclasses.fake_tensor import FakeTensorMode from . import config, dependencies @@ -133,7 +138,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] else: - stride = torch._prims_common.make_contiguous_strides_for(size) + stride = make_contiguous_strides_for(size) dtype = x.get_dtype() device = x.get_device() t = torch.empty_strided( @@ -2462,6 +2467,9 @@ def require_stride_order(cls, x, order): x.get_layout(), FixedLayout ) and x.get_layout().is_stride_ordered(order): return x + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): + return x x = cls.copy_input(x) as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order) assert is_stride_order_storage_and_layout(x, order) @@ -3052,9 +3060,32 @@ def create( output_padding_: List[int], groups: int, ): + with torch._subclasses.FakeTensorMode(): + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride_, + padding_, + dilation_, + transposed, + output_padding_, + groups, + ) + req_stride_order = get_stride_order(output.stride()) + + if config.triton.convolution == "aten": + weight = cls.require_stride_order(weight, req_stride_order) + x = cls.require_stride_order(x, req_stride_order) + else: + x = cls.require_stride1(cls.realize_input(x)) + weight = cls.require_stride1(cls.realize_input(weight)) - weight = cls.require_stride1(cls.realize_input(weight)) - x = cls.require_stride_order(x, get_stride_order(weight.get_stride())) stride = tuple(stride_) padding = tuple(padding_) dilation = tuple(dilation_) @@ -3062,22 +3093,6 @@ def create( output_padding = tuple(output_padding_) assert isinstance(groups, int) - # TODO - enable FakeTensorMode for propagation more globally. incorrect stride metas for fallback - # kernels will lead to runtime failures - with FakeTensorMode(): - output, *_ = cls.process_kernel( - torch.ops.aten.convolution, - x, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) - output_size = output.shape weight_shape = [ @@ -3122,6 +3137,7 @@ def create( # for conv2d or conv3d, prefer channels last format if kernel == "triton_ops.conv": output_layout_str = "torch.channels_last" + elif config.tune_layout and len(x.get_size()) == 4: from .codegen.autotuner import tuned_conv_layout @@ -3151,14 +3167,19 @@ def create( if len(stride_order) < len(output_size): # add batch dim if it exists stride_order = [len(stride_order)] + stride_order + strides = make_channels_last_strides_for(output_size) else: stride_order = list(reversed(range(len(output_size)))) + strides = make_contiguous_strides_for(output_size) - output_layout = FlexibleLayout( + if config.triton.convolution != "aten": + x = cls.require_stride_order(x, stride_order) + + output_layout = FixedLayout( x.get_device(), x.get_dtype(), output_size, - stride_order, + strides, ) if bias is not None: @@ -3178,13 +3199,6 @@ def create( kernel, ) - def apply_constraint(self): - x = self.inputs[0] - # FixedLayout of input - x = self.require_stride_order(x, self.layout.preferred_stride_order) - self.inputs[0] = x - self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) - def map_args(self): # x, w, bias in_args = [x.codegen_reference() for x in self.inputs]