Skip to content

Commit

Permalink
Port all and any full reductions to structured kernels.
Browse files Browse the repository at this point in the history
This PR creates out overloads for both `all` and `any` kernels (full reduction overload),
and ports them to structured kernels.

ghstack-source-id: a924523fc9ceeea4fb98f67f6eac4b67c3c4dee9
Pull Request resolved: #64642
  • Loading branch information
ysiraichi committed Sep 12, 2021
1 parent 30a7c76 commit 58aa072
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 54 deletions.
90 changes: 41 additions & 49 deletions aten/src/ATen/native/ReduceOps.cpp
Expand Up @@ -92,13 +92,7 @@ ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result
}
}

void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
self.layout() == Layout::Strided,
name, " only supports strided layout, got: ",
self.layout());

void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
if (result.defined()) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
Expand All @@ -109,20 +103,36 @@ void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
}
}

void allany_meta(
impl::MetaBase& meta,
const char* name,
const Tensor& self,
IntArrayRef dims,
bool keepdim) {
const auto& result = meta.maybe_get_output();
check_result_is_bytebool(name, self, result);
auto out_dtype = get_result_or_bytebool_dtype(self, result);
resize_reduction(meta, self, dims, keepdim, out_dtype);
}

TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("all", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
allany_meta(*this, "all", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}

TORCH_META_FUNC(all)(const Tensor& self) {
allany_meta(*this, "all", self, {}, false);
}

TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("any", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
allany_meta(*this, "any", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}

TORCH_META_FUNC(any)(const Tensor& self) {
allany_meta(*this, "any", self, {}, false);
}

void check_argmax_argmin(
const char* name,
const Tensor& self,
Expand Down Expand Up @@ -1299,9 +1309,11 @@ Tensor norm(const Tensor& self, const Scalar& p) {
// Tensor of dtype `bool`. However for compatibility reason,
// for `uint8`, they return Tensor of same dtype `uint8`.
// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
inline const Tensor & _all(const Tensor & result, TensorIterator & iter) {
inline const Tensor & _all(const Tensor & self, const Tensor & result, TensorIterator & iter) {
if (iter.numel() == 0) {
result.fill_(1);
} else if (iter.numel() == 1) {
result.fill_(self.item());
} else {
and_stub(iter.device_type(), iter);
}
Expand All @@ -1325,61 +1337,41 @@ inline TensorIterator get_allany_iter(
self, result, dims, keepdim, result.scalar_type());
}

Tensor all(const Tensor& self) {
Tensor result;

meta::check_all_any("all", self, result);
auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _all(result, iter);
}

TORCH_IMPL_FUNC(all_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
_all(mut_result, iter);
}
_all(self, result, iter);
}

inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
auto iter = get_allany_iter(self, result, {}, false);
_all(self, result, iter);
}

inline const Tensor & _any(const Tensor & self, const Tensor & result, TensorIterator & iter) {
if (iter.numel() == 0) {
result.fill_(0);
} else if (iter.numel() == 1) {
result.fill_(self.item());
} else {
or_stub(iter.device_type(), iter);
}

return result;
}

Tensor any(const Tensor& self) {
Tensor result;

meta::check_all_any("any", self, result);
auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _any(result, iter);
}

TORCH_IMPL_FUNC(any_out)
(const Tensor& self,
int64_t dim,
bool keepdim,
const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
_any(mut_result, iter);
}
_any(self, result, iter);
}

TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
auto iter = get_allany_iter(self, result, {}, false);
_any(self, result, iter);
}

Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
Expand Down
5 changes: 2 additions & 3 deletions aten/src/ATen/native/ReduceOpsUtils.h
Expand Up @@ -51,17 +51,16 @@ static inline Tensor restride_dim(
return src.as_strided(replacement_shape, strides);
}

inline Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
int64_t dim) {
IntArrayRef self_sizes = self.sizes();
std::vector<int64_t> result_sizes;
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
result_sizes[dim] = 1;
result.resize_(result_sizes);
return result;
}

inline bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
const Scalar& ident, int64_t dim, bool keepdim) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
Expand Down
15 changes: 13 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -7363,17 +7363,28 @@

- func: all(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: all.all_out
variants: method, function

- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
dispatch:
CPU, CUDA: all
CPU, CUDA: all_all_out

- func: any(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: any.all_out
variants: method, function
dispatch:
CPU, CUDA: any
SparseCPU, SparseCUDA: any_sparse

- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
dispatch:
CPU, CUDA: any_all_out

- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
Expand Down

0 comments on commit 58aa072

Please sign in to comment.