Skip to content

Commit

Permalink
handle scatter(Scalar) overload in inductor (pytorch#88894)
Browse files Browse the repository at this point in the history
Relanding pytorch#88210

Pull Request resolved: pytorch#88894
Approved by: https://github.com/desertfire
  • Loading branch information
Krovatkin authored and kulinseth committed Dec 9, 2022
1 parent 17c05c8 commit 5a26cb9
Showing 1 changed file with 35 additions and 9 deletions.
44 changes: 35 additions & 9 deletions torch/_inductor/lowering.py
Expand Up @@ -2029,14 +2029,37 @@ def scatter(x, dim: int, index, src, **kwargs):
return scatter_(clone(x), dim, index, src, **kwargs)


def scatter_fallback(
fn, self, dim: int, index, src, *, reduce: str = None, include_self: bool = True
):

if reduce not in {None, "sum"} or (
reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64}
):
self.realize()
return fallback_handler(fn)(
self, dim, index, src, reduce=reduce, include_self=include_self
)

return None


@register_lowering(aten.scatter_, type_promotion_kind=None)
def scatter_(self, dim: int, index, src, *, reduce: str = None):

if reduce == "add":
reduce = "sum"
elif reduce == "multiply":
reduce = "prod"
else:
assert reduce is None

fallback_result = scatter_fallback(
aten.scatter_, self, dim, index, src, reduce=reduce
)

if fallback_result:
return fallback_result
return scatter_reduce_(self, dim, index, src, reduce)


Expand All @@ -2062,15 +2085,18 @@ def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
assert reduce in {None, "sum", "prod", "mean", "amax", "amin"}

# TODO: Need to support more reduction type
# For reduction of "sum", tl.atomic_add doesn't support bool or int64
if reduce not in {None, "sum"} or (
reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64}
):
self.realize()
return fallback_scatter_reduce_(
self, dim, index, src, reduce, include_self=include_self
)
fallback_result = scatter_fallback(
aten.scatter_reduce_,
self,
dim,
index,
src,
reduce=reduce,
include_self=include_self,
)

if fallback_result:
return fallback_result

assert isinstance(self, TensorBox)
assert "int" in str(index.get_dtype())
Expand Down

0 comments on commit 5a26cb9

Please sign in to comment.