Take input striding for conv forward based on eager output (#88706)
From discussion with @Chillee and @ngimel, we'll likely need further fixes to ensure that we hit channels-last kernels, but this is still worth landing in its own right.
Pull Request resolved: #88706
Approved by: https://github.com/ngimel
eellison authored and pytorchmergebot committed Nov 11, 2022
1 parent adfbd83 commit 8ff2e34
Showing 2 changed files with 69 additions and 29 deletions.
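The intent is easiest to see in eager mode: a convolution fed a channels-last input returns a channels-last output, and after this change the compiled graph takes its conv striding from that eager result. A minimal sketch (not part of the diff; it assumes PyTorch's standard channels-last propagation for Conv2d and mirrors the new CUDA test below, run here on CPU):

import torch

conv = torch.nn.Conv2d(3, 3, 1, 1)
inp = torch.randn(2, 3, 16, 16).to(memory_format=torch.channels_last)

out = conv(inp)
# Eager convolution preserves the channels-last layout of its input;
# the inductor-compiled graph should now produce the same striding.
print(out.is_contiguous(memory_format=torch.channels_last))  # expected: True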
26 changes: 26 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -4601,6 +4601,8 @@ def fn(a):
CommonTemplate.install(CudaTests, "cuda")

class CudaReproTests(TestCase):
common = check_model_cuda

def test_index_put_issue(self):
def forward(
self,
@@ -4637,6 +4639,30 @@ def forward(
compiled = compile_fx_inner(mod, inps)
compiled(inps)

@requires_cuda()
def test_input_channels_last(self):
m = torch.nn.Sequential(
torch.nn.Conv2d(3, 3, 1, 1),
ToTuple(),
).cuda()
inp = (
torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda()
)

self.common(
m,
(inp,),
check_lowp=False,
)

@torch._dynamo.optimize()
def foo(m, inp):
return m(inp)

self.assertTrue(
foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last)
)

# https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527
@requires_cuda()
def test_unspec_inputs_interop(self):
72 changes: 43 additions & 29 deletions torch/_inductor/ir.py
@@ -19,7 +19,12 @@

import torch.fx
import torch.utils._pytree as pytree
from torch._prims_common import is_boolean_dtype, is_float_dtype
from torch._prims_common import (
is_boolean_dtype,
is_float_dtype,
make_channels_last_strides_for,
make_contiguous_strides_for,
)
from torch._subclasses.fake_tensor import FakeTensorMode

from . import config, dependencies
@@ -133,7 +138,7 @@ def ir_node_to_tensor(x, guard_shape=True):
if is_storage_and_layout(x):
stride = [shape_fn(s) for s in x.get_layout().stride]
else:
stride = torch._prims_common.make_contiguous_strides_for(size)
stride = make_contiguous_strides_for(size)
dtype = x.get_dtype()
device = x.get_device()
t = torch.empty_strided(
@@ -2462,6 +2467,9 @@ def require_stride_order(cls, x, order):
x.get_layout(), FixedLayout
) and x.get_layout().is_stride_ordered(order):
return x
# TODO - Storage to InputBuffer
if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
return x
x = cls.copy_input(x)
as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order)
assert is_stride_order_storage_and_layout(x, order)
@@ -3052,32 +3060,39 @@ def create(
output_padding_: List[int],
groups: int,
):
with torch._subclasses.FakeTensorMode():
x_fake = ir_node_to_tensor(x, guard_shape=True)
weight_fake = ir_node_to_tensor(weight, guard_shape=True)
bias_fake = (
ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
)
output = torch.ops.aten.convolution(
x_fake,
weight_fake,
bias_fake,
stride_,
padding_,
dilation_,
transposed,
output_padding_,
groups,
)
req_stride_order = get_stride_order(output.stride())

if config.triton.convolution == "aten":
weight = cls.require_stride_order(weight, req_stride_order)
x = cls.require_stride_order(x, req_stride_order)
else:
x = cls.require_stride1(cls.realize_input(x))
weight = cls.require_stride1(cls.realize_input(weight))

weight = cls.require_stride1(cls.realize_input(weight))
x = cls.require_stride_order(x, get_stride_order(weight.get_stride()))
stride = tuple(stride_)
padding = tuple(padding_)
dilation = tuple(dilation_)
assert isinstance(transposed, bool)
output_padding = tuple(output_padding_)
assert isinstance(groups, int)

# TODO - enable FakeTensorMode for propagation more globally. incorrect stride metas for fallback
# kernels will lead to runtime failures
with FakeTensorMode():
output, *_ = cls.process_kernel(
torch.ops.aten.convolution,
x,
weight,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
)

output_size = output.shape

weight_shape = [
@@ -3122,6 +3137,7 @@ def create(
# for conv2d or conv3d, prefer channels last format
if kernel == "triton_ops.conv":
output_layout_str = "torch.channels_last"

elif config.tune_layout and len(x.get_size()) == 4:
from .codegen.autotuner import tuned_conv_layout

@@ -3151,14 +3167,19 @@ def create(
if len(stride_order) < len(output_size):
# add batch dim if it exists
stride_order = [len(stride_order)] + stride_order
strides = make_channels_last_strides_for(output_size)
else:
stride_order = list(reversed(range(len(output_size))))
strides = make_contiguous_strides_for(output_size)

output_layout = FlexibleLayout(
if config.triton.convolution != "aten":
x = cls.require_stride_order(x, stride_order)

output_layout = FixedLayout(
x.get_device(),
x.get_dtype(),
output_size,
stride_order,
strides,
)

if bias is not None:
@@ -3178,13 +3199,6 @@ def create(
kernel,
)

def apply_constraint(self):
x = self.inputs[0]
# FixedLayout of input
x = self.require_stride_order(x, self.layout.preferred_stride_order)
self.inputs[0] = x
self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)

def map_args(self):
# x, w, bias
in_args = [x.codegen_reference() for x in self.inputs]
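For reference, a small sketch (not part of the diff) of the stride helpers the new FixedLayout relies on: make_contiguous_strides_for and make_channels_last_strides_for from torch._prims_common, shown here for a conv output of size [2, 3, 16, 16]:

from torch._prims_common import (
    make_channels_last_strides_for,
    make_contiguous_strides_for,
)

output_size = [2, 3, 16, 16]  # N, C, H, W
# Contiguous (NCHW) strides: W is the innermost dimension.
print(make_contiguous_strides_for(output_size))     # (768, 256, 16, 1)
# Channels-last (NHWC) strides: C is the innermost dimension.
print(make_channels_last_strides_for(output_size))  # (768, 1, 48, 3)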
