Port all and any full reductions to structured kernels.
This PR creates `out` overloads for both the `all` and `any` kernels (their full-reduction
overloads), and ports them to structured kernels.

ghstack-source-id: b0c2e5b5ba78cdaed8e1259c99671a81e4defeaf
Pull Request resolved: #64642
ysiraichi committed Sep 8, 2021
1 parent 32fbeb1 commit c663c47
Showing 3 changed files with 50 additions and 52 deletions.
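
For orientation, here is a minimal, self-contained sketch of the meta/impl split that structured kernels enforce, which is what this port applies to the full-reduction `all`/`any` ops. The names and simplified structure below are illustrative assumptions, not the generated PyTorch code:

// Toy illustration of the structured-kernel split (assumed/simplified; the
// real codegen'd wrappers look different).
#include <cstdint>
#include <vector>

struct ToyTensor {
  std::vector<uint8_t> data;
};

// "meta" step: decide the output's shape (and dtype) without touching data.
ToyTensor toy_all_meta(const ToyTensor& /*self*/) {
  return ToyTensor{std::vector<uint8_t>(1)};  // full reduction -> one element
}

// "impl" step: fill an already-allocated output.
void toy_all_impl(const ToyTensor& self, ToyTensor& out) {
  uint8_t acc = 1;
  for (uint8_t v : self.data) acc = static_cast<uint8_t>(acc && v);
  out.data[0] = acc;
}

// The functional form delegates to meta + impl, which is roughly what
// `structured_delegate: all.full_out` arranges for the real op (see the
// native_functions.yaml changes below).
ToyTensor toy_all(const ToyTensor& self) {
  ToyTensor out = toy_all_meta(self);
  toy_all_impl(self, out);
  return out;
}
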
82 changes: 35 additions & 47 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -92,13 +92,7 @@ ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result
}
}

void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
self.layout() == Layout::Strided,
name, " only supports strided layout, got: ",
self.layout());

void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
if (result.defined()) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
@@ -109,20 +103,36 @@ void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
}
}

void allany_meta(
impl::MetaBase& meta,
const char* name,
const Tensor& self,
IntArrayRef dims,
bool keepdim) {
const auto& result = meta.maybe_get_output();
check_result_is_bytebool(name, self, result);
auto out_dtype = get_result_or_bytebool_dtype(self, result);
resize_reduction(meta, self, dims, keepdim, out_dtype);
}
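
The `allany_meta` helper above leans on `check_result_is_bytebool` and `get_result_or_bytebool_dtype`, whose bodies are not part of this hunk. As a hedged sketch (assumed behavior based on the [all, any : uint8 compatibility] note, not code from this diff), the dtype rule they encode is roughly:

// Illustrative only -- assumed behavior, written against the ATen headers and
// namespace of the surrounding file:
// all/any produce a kBool result unless a user-supplied `out` dictates
// otherwise; for backward compatibility a kByte input keeps a kByte result,
// and an explicit `out` tensor must already be kBool or kByte.
ScalarType result_or_bytebool_dtype_sketch(const Tensor& self, const Tensor& result) {
  if (result.defined()) {
    return result.scalar_type();  // trust the dtype of the provided out tensor
  }
  return self.scalar_type() == ScalarType::Byte ? ScalarType::Byte : ScalarType::Bool;
}
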

TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("all", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
allany_meta(*this, "all", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}

TORCH_META_FUNC(all)(const Tensor& self) {
allany_meta(*this, "all", self, {}, false);
}

TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("any", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
allany_meta(*this, "any", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}

TORCH_META_FUNC(any)(const Tensor& self) {
allany_meta(*this, "any", self, {}, false);
}

void check_argmax_argmin(
const char* name,
const Tensor& self,
@@ -1325,26 +1335,15 @@ inline TensorIterator get_allany_iter(
self, result, dims, keepdim, result.scalar_type());
}

Tensor all(const Tensor& self) {
Tensor result;

meta::check_all_any("all", self, result);
auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _all(result, iter);
}

TORCH_IMPL_FUNC(all_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
_all(mut_result, iter);
}
_all(result, iter);
}

TORCH_IMPL_FUNC(all_full_out)(const Tensor& self, const Tensor& result) {
auto iter = get_allany_iter(self, result, {}, false);
_all(result, iter);
}

inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
@@ -1357,29 +1356,18 @@ inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
return result;
}

Tensor any(const Tensor& self) {
Tensor result;

meta::check_all_any("any", self, result);
auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _any(result, iter);
}

TORCH_IMPL_FUNC(any_out)
(const Tensor& self,
int64_t dim,
bool keepdim,
const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
_any(mut_result, iter);
}
_any(result, iter);
}

TORCH_IMPL_FUNC(any_full_out)(const Tensor& self, const Tensor& result) {
auto iter = get_allany_iter(self, result, {}, false);
_any(result, iter);
}

Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
5 changes: 2 additions & 3 deletions aten/src/ATen/native/ReduceOpsUtils.h
@@ -51,17 +51,16 @@ static inline Tensor restride_dim(
return src.as_strided(replacement_shape, strides);
}

inline Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
int64_t dim) {
IntArrayRef self_sizes = self.sizes();
std::vector<int64_t> result_sizes;
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
result_sizes[dim] = 1;
result.resize_(result_sizes);
return result;
}

inline bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
const Scalar& ident, int64_t dim, bool keepdim) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
15 changes: 13 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -7360,17 +7360,28 @@

- func: all(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: all.full_out
variants: method, function

- func: all.full_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
dispatch:
CPU, CUDA: all
CPU, CUDA: all_full_out

- func: any(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: any.full_out
variants: method, function
dispatch:
CPU, CUDA: any
SparseCPU, SparseCUDA: any_sparse

- func: any.full_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
structured: True
dispatch:
CPU, CUDA: any_full_out
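
For reference, a small usage sketch under the new schema. The functional and method variants are unchanged for callers; the overload name `full_out` comes straight from the entries above, while the exact generated C++ binding for the out variant is not shown here and is left as an assumption:

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::ones({2, 3}, at::kBool);
  at::Tensor a = at::all(x);  // functional variant; now delegates to all.full_out
  at::Tensor b = x.any();     // method variant; now delegates to any.full_out
  // The new out overloads are registered with the dispatcher as
  // aten::all.full_out and aten::any.full_out (see the schema entries above).
  return 0;
}
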

- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
