Workarounds for cudnn_batch_norm with TorchRefsNvfuserCapabilityMode (pytorch#86796)

This PR adds workarounds to support AOT Autograd's graphs containing `aten.cudnn_batch_norm` and `aten.cudnn_batch_norm_backward` with `TorchRefsNvfuserCapabilityMode`.

The problem with the decomposition of `aten.cudnn_batch_norm` is that it uses a `new_empty` call, which is not supported by nvFuser, and by default we are conservative about lowering functions to nvprims.

The problem with the decomposition of `aten.cudnn_batch_norm_backward` is described in pytorch#86115 (comment), but changing the decomposition directly in that PR makes many tests fail.
Pull Request resolved: pytorch#86796
Approved by: https://github.com/mruberry
IvanYashchuk authored and pytorchmergebot committed Oct 17, 2022
1 parent 33343de commit 3193151
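
For illustration only (not part of this commit): a minimal sketch of what the workaround enables, modeled on the new test added below. The tensor shapes, momentum, and eps values are placeholders, and a CUDA build with nvFuser is assumed.

import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import make_fx

def func(input, weight, bias, running_mean, running_var, training, momentum, eps):
    return torch.ops.aten.cudnn_batch_norm.default(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    )

# Placeholder inputs; any CUDA floating-point tensors with a matching channel
# dimension would do.
x = torch.randn(2, 3, 4, 4, device="cuda")
weight = torch.randn(3, device="cuda")
bias = torch.randn(3, device="cuda")
running_mean = torch.zeros(3, device="cuda")
running_var = torch.ones(3, device="cuda")

# With the context manager active, tracing rewrites the aten op into
# torch.ops.nvprims.native_batch_norm (plus an empty reserve placeholder).
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(x, weight, bias, running_mean, running_var, True, 0.1, 1e-5)

# The resulting FX graph contains only nvprims calls and can be run by nvFuser.
out = execute(gm, x, weight, bias, running_mean, running_var, True, 0.1, 1e-5,
              executor="nvfuser")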
Showing 2 changed files with 155 additions and 9 deletions.
86 changes: 77 additions & 9 deletions test/test_prims.py
@@ -611,6 +611,63 @@ def func(
out = execute(gm, sample.input, *sample.args, executor="strictly_nvfuser")
self.assertEqual(out, gm(sample.input, *sample.args))

@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.float32, torch.float64)
def test_cudnn_batch_norm_nvprims(self, device, dtype):
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute

# This test verifies that cudnn_batch_norm is translated into nvprims
# and can be executed with nvFuser
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_methods_invocations import (
sample_inputs_native_batch_norm,
)

samples = sample_inputs_native_batch_norm(
None, device, dtype, requires_grad=False
)
for sample in samples:
if sample.input.numel() == 0:
continue

def func(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return torch.ops.aten.cudnn_batch_norm.default(
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)

with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(sample.input, *sample.args)

call_function_nodes = list(
filter(lambda n: n.op == "call_function", gm.graph.nodes)
)
includes_aten_batch_norm = any(
torch.ops.aten.cudnn_batch_norm.default == node.target
for node in call_function_nodes
)
self.assertFalse(includes_aten_batch_norm)

includes_nvprims_batch_norm = any(
torch.ops.nvprims.native_batch_norm.default == node.target
for node in call_function_nodes
)
self.assertTrue(includes_nvprims_batch_norm)

# Check that the graph can be executed with nvFuser
out = execute(gm, sample.input, *sample.args, executor="nvfuser")
self.assertEqual(out, gm(sample.input, *sample.args))

    # The decomposition of native_batch_norm_backward uses a cast, which prevents nvprims lowering on CPU builds
@onlyCUDA
@dtypes(torch.float32, torch.float16)
@@ -624,23 +681,34 @@ def test_batch_norm_backward_nvprims(self, device, dtype):
sample = next(samples_iter)
grad = torch.randn_like(sample.input)

-        def func(grad, input, weight, rm, rv, eps, train):
+        def func1(grad, input, weight, rm, rv, eps, train):
            return torch.ops.aten.native_batch_norm_backward.default(
                grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True]
            )

+        def func2(grad, input, weight, rm, rv, eps, train):
+            return torch.ops.aten.cudnn_batch_norm_backward.default(
+                input, grad, weight, rm, rv, rm, rv, eps, grad
+            )
+
        args = sample.args
        kwargs = sample.kwargs
        all_args = [grad, sample.input, args[2], args[0], args[1], kwargs['eps'], kwargs['training']]
-        with TorchRefsNvfuserCapabilityMode():
-            gm = make_fx(func)(*all_args)

-        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
-        includes_batch_norm_backward = any(
-            torch.ops.aten.native_batch_norm_backward.default == node.target
-            for node in call_function_nodes
-        )
-        self.assertFalse(includes_batch_norm_backward)
+        for func in (func1, func2):
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(func)(*all_args)
+
+            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+            includes_batch_norm_backward = any(
+                torch.ops.aten.native_batch_norm_backward.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertFalse(includes_batch_norm_backward)
+            all_nvprims = all(
+                str(node.target).startswith("nvprims") for node in call_function_nodes
+            )
+            self.assertTrue(all_nvprims)

@onlyCUDA
@skipCUDAIfRocm
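As a usage note (assuming a CUDA build with nvFuser available, which this commit does not state explicitly), the two tests above can be run selectively with pytest's -k filter, for example:

pytest test/test_prims.py -k "cudnn_batch_norm_nvprims or batch_norm_backward_nvprims"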
78 changes: 78 additions & 0 deletions torch/_prims/context.py
@@ -267,6 +267,68 @@ def __init__(self, *, skip_ops=()):
prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
)

    # TODO: remove this once the version from _decomp/decompositions.py works
    # with this context manager
# This is a workaround for AOT Autograd graphs
def _cudnn_batch_norm(
self,
input,
weight,
bias,
running_mean,
running_var,
training,
exponential_average_factor,
epsilon,
):
a, b, c = torch.ops.nvprims.native_batch_norm(
input,
weight,
bias,
running_mean,
running_var,
training,
exponential_average_factor,
epsilon,
)
if training:
return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
return (
a,
weight.new_zeros((0,)),
weight.new_zeros((0,)),
input.new_zeros((0,), dtype=torch.uint8),
)

# This is a workaround for AOT Autograd graphs
def _cudnn_batch_norm_backward(
self,
input,
grad_output,
weight,
running_mean,
running_var,
save_mean,
save_var,
epsilon,
reserveSpace,
):
func = torch._decomp.decomposition_table[
torch.ops.aten.native_batch_norm_backward.default
]
return func(
grad_output,
input,
weight,
running_mean,
running_var,
save_mean,
save_var,
True,
epsilon,
[True, True, True],
)

def _is_var_mean(self, func):
return "torch.var_mean" == torch.overrides.resolve_name(func) or (
(
@@ -313,6 +375,22 @@ def __torch_function__(
if self._is_var_mean(orig_func):
return torch.ops.nvprims.var_mean(*args, **kwargs)

if (
orig_func == torch.ops.aten.cudnn_batch_norm.default
or orig_func == torch.ops.aten.cudnn_batch_norm
):
with self:
return self._cudnn_batch_norm(*args, **kwargs)

# A workaround for AOT Autograd graphs
# See https://github.com/pytorch/pytorch/pull/86115#issue-1394883782
if (
orig_func == torch.ops.aten.cudnn_batch_norm_backward.default
or orig_func == torch.ops.aten.cudnn_batch_norm_backward
):
with self:
return self._cudnn_batch_norm_backward(*args, **kwargs)

if self._is_view_or_reshape(orig_func):
a, *shape = args
shape = torch._prims_common.extract_shape_from_varargs(
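For context, the forward workaround above mimics the four-output contract of `aten.cudnn_batch_norm`: (output, save_mean, save_invstd, reserve), where `reserve` is an opaque uint8 buffer that `nvprims.native_batch_norm` does not produce, hence the empty placeholder. Below is a minimal standalone sketch of the same idea; it uses the public `aten.native_batch_norm` op instead of the nvprims one so it runs without nvFuser, and the shapes, device, momentum, and eps are placeholders.

import torch

x = torch.randn(2, 3, 8, 8, device="cuda")
weight = torch.randn(3, device="cuda")
bias = torch.randn(3, device="cuda")
running_mean = torch.zeros(3, device="cuda")
running_var = torch.ones(3, device="cuda")

# native_batch_norm returns three tensors: (output, save_mean, save_invstd).
out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
)

# cudnn_batch_norm additionally returns an opaque reserve buffer; the workaround
# stands in an empty uint8 tensor, matching the training branch shown above.
reserve = x.new_zeros((0,), dtype=torch.uint8)
result = (out, save_mean, save_invstd, reserve)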
