Skip to content

Commit

Permalink
Port all and any full reductions to structured kernels. (#64642)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #64642

Tracking issue: #55070

This PR creates out overloads for both `all` and `any` kernels (full reduction overload),
and ports them to structured kernels.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30867354

Pulled By: ezyang

fbshipit-source-id: 46bccaf6c94a09ed77cc6c724d1183c82f801751
  • Loading branch information
ysiraichi authored and facebook-github-bot committed Sep 15, 2021
1 parent 54cdf65 commit 54d060a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 79 deletions.
126 changes: 52 additions & 74 deletions aten/src/ATen/native/ReduceOps.cpp
Expand Up @@ -92,13 +92,7 @@ ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result
}
}

void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
self.layout() == Layout::Strided,
name, " only supports strided layout, got: ",
self.layout());

void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
if (result.defined()) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
Expand All @@ -109,20 +103,42 @@ void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
}
}

// Note [all, any : uint8 compatibility]:
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// For NumPy compatibility, `all` and `any` return a
// Tensor of dtype `bool`. However, for compatibility reasons,
// for `uint8`, they return a Tensor of the same dtype `uint8`.
// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
// Shared meta step for the `all` / `any` structured kernels.
//
// Looks up the (possibly undefined) output held by `meta`, validates it with
// check_result_is_bytebool (`name` is forwarded for that check), picks the
// output dtype with get_result_or_bytebool_dtype, and finally calls
// resize_reduction for the requested `dims` / `keepdim`.
static void allany_meta(
    impl::MetaBase& meta,
    const char* name,
    const Tensor& self,
    IntArrayRef dims,
    bool keepdim) {
  const auto& maybe_out = meta.maybe_get_output();

  // Validate the output first, then derive the dtype from it.
  check_result_is_bytebool(name, self, maybe_out);
  const auto reduced_dtype = get_result_or_bytebool_dtype(self, maybe_out);

  resize_reduction(meta, self, dims, keepdim, reduced_dtype);
}

TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("all", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
allany_meta(*this, "all", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}

// Meta function for the full-reduction `all` overload (only `self` is given):
// forwards to allany_meta with an empty dims list and keepdim = false.
TORCH_META_FUNC(all)(const Tensor& self) {
allany_meta(*this, "all", self, {}, false);
}

TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("any", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
allany_meta(*this, "any", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}

// Meta function for the full-reduction `any` overload (only `self` is given):
// forwards to allany_meta with an empty dims list and keepdim = false.
TORCH_META_FUNC(any)(const Tensor& self) {
allany_meta(*this, "any", self, {}, false);
}

void check_argmax_argmin(
const char* name,
const Tensor& self,
Expand Down Expand Up @@ -1323,22 +1339,6 @@ Tensor norm(const Tensor& self, const Scalar& p) {
return at::norm(self, p, IntArrayRef{}, false);
}

// Note [all, any : uint8 compatibility]:
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// For NumPy compatibility, `all` and `any` return a
// Tensor of dtype `bool`. However, for compatibility reasons,
// for `uint8`, they return a Tensor of the same dtype `uint8`.
// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
// Runs the `all` reduction described by `iter`, writing into `result`.
// When the iterator has no elements, `result` is filled with 1 instead of
// invoking and_stub. Returns `result`.
inline const Tensor & _all(const Tensor & result, TensorIterator & iter) {
  const bool has_elements = iter.numel() != 0;
  if (!has_elements) {
    result.fill_(1);
    return result;
  }

  and_stub(iter.device_type(), iter);
  return result;
}

inline TensorIterator get_allany_iter(
const Tensor& self,
const Tensor& result,
Expand All @@ -1355,61 +1355,39 @@ inline TensorIterator get_allany_iter(
self, result, dims, keepdim, result.scalar_type());
}

Tensor all(const Tensor& self) {
Tensor result;

meta::check_all_any("all", self, result);
auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _all(result, iter);
// Common implementation behind the `all` / `any` structured kernels.
//
//   identity - value written to `result` when `self` has no elements.
//   stub     - reduction stub invoked as stub(device_type, iter) otherwise.
//
// A single-element `self` is handled without building an iterator: `result`
// is filled with that element converted to bool.
template <int identity, typename Stub>
inline void allany_impl(
    const Tensor& self,
    const Tensor& result,
    IntArrayRef dims,
    bool keepdim,
    Stub& stub) {
  // Empty input: nothing to reduce, so write the identity value.
  if (self.numel() == 0) {
    result.fill_(identity);
    return;
  }

  // One element: its truth value is the answer.
  if (self.numel() == 1) {
    result.fill_(self.item().toBool());
    return;
  }

  // General case: build the reduction iterator and run the stub on it.
  auto reduce_iter = get_allany_iter(self, result, dims, keepdim);
  stub(reduce_iter.device_type(), reduce_iter);
}

TORCH_IMPL_FUNC(all_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
_all(mut_result, iter);
}
allany_impl<1>(self, result, dim, keepdim, and_stub);
}

inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
if (iter.numel() == 0) {
result.fill_(0);
} else {
or_stub(iter.device_type(), iter);
}

return result;
// Implementation for the full-reduction `all` out overload (only `self` and
// the output): calls allany_impl with identity 1 and and_stub, passing an
// empty dims list and keepdim = false.
TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
allany_impl<1>(self, result, {}, false, and_stub);
}

Tensor any(const Tensor& self) {
Tensor result;

meta::check_all_any("any", self, result);
auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _any(result, iter);
// Implementation for the `any` overload that takes `dim` and `keepdim`:
// calls allany_impl with identity 0 and or_stub, forwarding `dim` and
// `keepdim` unchanged.
TORCH_IMPL_FUNC(any_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
allany_impl<0>(self, result, dim, keepdim, or_stub);
}

TORCH_IMPL_FUNC(any_out)
(const Tensor& self,
int64_t dim,
bool keepdim,
const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
_any(mut_result, iter);
}
// Implementation for the full-reduction `any` out overload (only `self` and
// the output): calls allany_impl with identity 0 and or_stub, passing an
// empty dims list and keepdim = false.
TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
allany_impl<0>(self, result, {}, false, or_stub);
}

Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
Expand Down
5 changes: 2 additions & 3 deletions aten/src/ATen/native/ReduceOpsUtils.h
Expand Up @@ -51,17 +51,16 @@ static inline Tensor restride_dim(
return src.as_strided(replacement_shape, strides);
}

inline Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
int64_t dim) {
IntArrayRef self_sizes = self.sizes();
std::vector<int64_t> result_sizes;
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
result_sizes[dim] = 1;
result.resize_(result_sizes);
return result;
}

inline bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
const Scalar& ident, int64_t dim, bool keepdim) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
Expand Down
15 changes: 13 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -7375,17 +7375,28 @@

- func: all(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: all.all_out
variants: method, function

- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
dispatch:
CPU, CUDA: all
CPU, CUDA: all_all_out

- func: any(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: any.all_out
variants: method, function
dispatch:
CPU, CUDA: any
SparseCPU, SparseCUDA: any_sparse

- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
dispatch:
CPU, CUDA: any_all_out

- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
Expand Down

0 comments on commit 54d060a

Please sign in to comment.