Port all and any full reductions to structured kernels. #64642

Closed
wants to merge 7 commits
126 changes: 52 additions & 74 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -92,13 +92,7 @@ ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result
}
}

void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
self.layout() == Layout::Strided,
name, " only supports strided layout, got: ",
self.layout());

void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
if (result.defined()) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
@@ -109,20 +103,42 @@ void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
}
}

// Note [all, any : uint8 compatibility]:
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// For NumPy compatibility, `all` and `any` return
// a Tensor of dtype `bool`. However, for backward compatibility,
// for `uint8` inputs they return a Tensor of the same dtype, `uint8`.
// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
static void allany_meta(
impl::MetaBase& meta,
const char* name,
const Tensor& self,
IntArrayRef dims,
bool keepdim) {
const auto& result = meta.maybe_get_output();
check_result_is_bytebool(name, self, result);
auto out_dtype = get_result_or_bytebool_dtype(self, result);
resize_reduction(meta, self, dims, keepdim, out_dtype);
}

TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("all", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
allany_meta(*this, "all", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}

TORCH_META_FUNC(all)(const Tensor& self) {
allany_meta(*this, "all", self, {}, false);
}

TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("any", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
allany_meta(*this, "any", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}

TORCH_META_FUNC(any)(const Tensor& self) {
allany_meta(*this, "any", self, {}, false);
}

void check_argmax_argmin(
const char* name,
const Tensor& self,
@@ -1293,22 +1309,6 @@ Tensor norm(const Tensor& self, const Scalar& p) {
return at::norm(self, p, IntArrayRef{}, false);
}

// Note [all, any : uint8 compatibility]:
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// For NumPy comptability, `all` and `any` return
// Tensor of dtype `bool`. However for compatibility reason,
// for `uint8`, they return Tensor of same dtype `uint8`.
// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
inline const Tensor & _all(const Tensor & result, TensorIterator & iter) {
if (iter.numel() == 0) {
result.fill_(1);
} else {
and_stub(iter.device_type(), iter);
}

return result;
}

inline TensorIterator get_allany_iter(
const Tensor& self,
const Tensor& result,
@@ -1325,61 +1325,39 @@ inline TensorIterator get_allany_iter(
self, result, dims, keepdim, result.scalar_type());
}

Tensor all(const Tensor& self) {
Tensor result;

meta::check_all_any("all", self, result);
auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _all(result, iter);
template <int identity, typename Stub>
inline void allany_impl(
const Tensor& self,
const Tensor& result,
IntArrayRef dims,
bool keepdim,
Stub& stub) {
if (self.numel() == 0) {
result.fill_(identity);
} else if (self.numel() == 1) {
result.fill_(self.item().toBool());
} else {
auto iter = get_allany_iter(self, result, dims, keepdim);
stub(iter.device_type(), iter);
}
}

TORCH_IMPL_FUNC(all_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
_all(mut_result, iter);
}
allany_impl<1>(self, result, dim, keepdim, and_stub);
}

inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
if (iter.numel() == 0) {
result.fill_(0);
} else {
or_stub(iter.device_type(), iter);
}

return result;
TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
allany_impl<1>(self, result, {}, false, and_stub);
}

Tensor any(const Tensor& self) {
Tensor result;

meta::check_all_any("any", self, result);
auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _any(result, iter);
TORCH_IMPL_FUNC(any_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
allany_impl<0>(self, result, dim, keepdim, or_stub);
}

TORCH_IMPL_FUNC(any_out)
(const Tensor& self,
int64_t dim,
bool keepdim,
const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
_any(mut_result, iter);
}
TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
allany_impl<0>(self, result, {}, false, or_stub);
}

Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
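
For reference, a minimal sketch of the user-visible contract these kernels keep (not part of the diff; it assumes a standard libtorch build with `ATen/ATen.h` on the include path): empty reductions are filled with the identity element in `allany_impl`, and, per Note [all, any : uint8 compatibility], the result dtype is `bool` unless the input is `uint8`.

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Empty input: allany_impl fills the result with the identity element.
  auto empty = at::empty({0});
  std::cout << at::all(empty).item<bool>() << "\n";  // 1 (all of nothing is true)
  std::cout << at::any(empty).item<bool>() << "\n";  // 0 (any of nothing is false)

  // Note [all, any : uint8 compatibility]: uint8 inputs keep dtype uint8,
  // every other dtype reduces to a bool tensor.
  auto u8  = at::ones({2, 2}, at::kByte);
  auto f32 = at::ones({2, 2});
  std::cout << (at::all(u8).scalar_type()  == at::kByte) << "\n";  // 1
  std::cout << (at::all(f32).scalar_type() == at::kBool) << "\n";  // 1
  return 0;
}
```
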
5 changes: 2 additions & 3 deletions aten/src/ATen/native/ReduceOpsUtils.h
@@ -51,17 +51,16 @@ static inline Tensor restride_dim(
return src.as_strided(replacement_shape, strides);
}

inline Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
int64_t dim) {
IntArrayRef self_sizes = self.sizes();
std::vector<int64_t> result_sizes;
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
result_sizes[dim] = 1;
result.resize_(result_sizes);
return result;
}

inline bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
const Scalar& ident, int64_t dim, bool keepdim) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
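
These helpers now take the output as `const Tensor&` because structured impls receive a const reference to the already-allocated output; in-place calls such as `resize_()` and `fill_()` still work because constness on an `at::Tensor` handle is shallow (the handle is const, not the storage it points to). A minimal illustrative sketch of that property, assuming a libtorch build; the helper `fill_ones` is hypothetical and not part of this PR:

```cpp
#include <ATen/ATen.h>
#include <iostream>

// Mutating methods like fill_() are callable through a const handle;
// only reassigning the handle itself would be forbidden.
void fill_ones(const at::Tensor& result) {
  result.fill_(1);
}

int main() {
  auto t = at::zeros({3});
  fill_ones(t);
  std::cout << t.sum().item<float>() << "\n";  // 3
  return 0;
}
```
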
15 changes: 13 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -7363,17 +7363,28 @@

- func: all(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: all.all_out
variants: method, function

- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
dispatch:
CPU, CUDA: all
CPU, CUDA: all_all_out

- func: any(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: any.all_out
variants: method, function
dispatch:
CPU, CUDA: any
SparseCPU, SparseCUDA: any_sparse

- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
dispatch:
CPU, CUDA: any_all_out

- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
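
With these schema changes, `all(Tensor self)` and `any(Tensor self)` become structured delegates of the new `all.all_out` / `any.all_out` kernels, which also exposes out= variants for the full reductions. A rough usage sketch, assuming the generated C++ overloads take the usual `at::all_out(out, self)` / `at::any_out(out, self)` form (verify against the generated `Functions.h`); note that `check_result_is_bytebool` only accepts `bool` or `uint8` out= tensors:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto self = at::rand({3, 4});

  // Full reduction through the out= overloads added by all.all_out / any.all_out.
  auto out = at::empty({}, at::kBool);
  at::all_out(out, self);
  std::cout << out.item<bool>() << "\n";
  at::any_out(out, self);
  std::cout << out.item<bool>() << "\n";

  // check_result_is_bytebool rejects any other out= dtype.
  auto bad = at::empty({}, at::kFloat);
  try {
    at::any_out(bad, self);
  } catch (const c10::Error&) {
    std::cout << "float out= rejected, as expected\n";
  }
  return 0;
}
```
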