Skip to content

Commit

Permalink
Intercept aten._reshape_alias for nvFuser (#87072)
Browse files Browse the repository at this point in the history
This would help forming larger fusion groups. If this won't end up executed by nvFuser then eager mode implementation would call into `.reshape`: https://github.com/pytorch/pytorch/blob/37e9e89afbc3554258545a026fab4cd9e1a4b85d/torch/_prims/nvfuser_prims.py#L552-L553

cc @kevinstephano @jjsjann123
Pull Request resolved: #87072
Approved by: https://github.com/ngimel
  • Loading branch information
IvanYashchuk authored and pytorchmergebot committed Oct 25, 2022
1 parent a3d495b commit ff2569b
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torch/_prims/context.py
Expand Up @@ -405,6 +405,12 @@ def __torch_function__(
warn("view has ignored kwargs!")
return torch.ops.nvprims.view(a, shape)

if orig_func == torch.ops.aten._reshape_alias.default:
a, shape, stride = args
if len(kwargs) > 0:
warn("view has ignored kwargs!")
return torch.ops.nvprims.view(a, shape)

if self._is_native_batch_norm(orig_func):
return torch.ops.nvprims.native_batch_norm(*args, **kwargs)

Expand Down

0 comments on commit ff2569b

Please sign in to comment.