Skip to content

Commit

Permalink
Sparse support for ReLU (#86749)
Browse files Browse the repository at this point in the history
ReLU support for all sparse layouts, including backward.

Fixes #85208
Pull Request resolved: #86749
Approved by: https://github.com/cpuhrsch, https://github.com/nikitaved
  • Loading branch information
amjames authored and pytorchmergebot committed Oct 14, 2022
1 parent ef04569 commit 527ebed
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 4 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4374,6 +4374,8 @@
QuantizedCPU: relu_quantized_cpu
QuantizedCUDA: relu_quantized_cuda
NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu
SparseCPU, SparseCUDA: relu_sparse
SparseCsrCPU, SparseCsrCUDA: relu_sparse_csr
tags: canonical

- func: relu_(Tensor(a!) self) -> Tensor(a!)
Expand All @@ -4386,6 +4388,8 @@
QuantizedCPU: relu_quantized_cpu_
QuantizedCUDA: relu_quantized_cuda_
NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu_
SparseCPU, SparseCUDA: relu_sparse_
SparseCsrCPU, SparseCsrCUDA: relu_sparse_csr_
autogen: relu.out

- func: relu6(Tensor self) -> Tensor
Expand Down Expand Up @@ -5279,12 +5283,16 @@
dispatch:
CPU, CUDA: threshold_backward_out
MPS: threshold_backward_out_mps
SparseCPU, SparseCUDA: threshold_backward_sparse_out
SparseCsrCPU, SparseCsrCUDA: threshold_backward_sparse_compressed_out

- func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
variants: function
structured_delegate: threshold_backward.grad_input
dispatch:
MkldnnCPU: mkldnn_relu_backward
SparseCPU, SparseCUDA: threshold_backward_sparse
SparseCsrCPU, SparseCsrCUDA: threshold_backward_sparse_compressed

- func: tile(Tensor self, int[] dims) -> Tensor
variants: function, method
Expand Down
35 changes: 32 additions & 3 deletions aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
#else
#include <ATen/ops/_conj_physical_native.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
Expand Down Expand Up @@ -76,6 +76,8 @@
#include <ATen/ops/ones_like.h>
#include <ATen/ops/rad2deg.h>
#include <ATen/ops/rad2deg_native.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/relu_native.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/round.h>
Expand All @@ -91,15 +93,17 @@
#include <ATen/ops/sin_native.h>
#include <ATen/ops/sinh.h>
#include <ATen/ops/sinh_native.h>
#include <ATen/ops/sqrt.h>
#include <ATen/ops/sqrt_native.h>
#include <ATen/ops/sparse_mask.h>
#include <ATen/ops/sparse_mask_native.h>
#include <ATen/ops/sqrt.h>
#include <ATen/ops/sqrt_native.h>
#include <ATen/ops/tan.h>
#include <ATen/ops/tan_native.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/tanh_native.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/threshold_backward.h>
#include <ATen/ops/threshold_backward_native.h>
#include <ATen/ops/trunc.h>
#include <ATen/ops/trunc_native.h>
#include <ATen/ops/zero_native.h>
Expand Down Expand Up @@ -367,6 +371,7 @@ CREATE_UNARY_UFUNC(tan);
CREATE_UNARY_UFUNC(tanh);
CREATE_UNARY_UFUNC(trunc);
CREATE_UNARY_UFUNC(conj_physical);
CREATE_UNARY_UFUNC(relu);

// With addition of `round.decimals` overload, using CREATE_UNARY_UFUNC leads
// to unresolved overload.
Expand All @@ -384,6 +389,30 @@ Tensor& round_sparse_csr_(Tensor& self) {
return self;
}

// Functional threshold_backward registered for the SparseCsrCPU/SparseCsrCUDA
// dispatch keys. The actual computation is at::threshold_backward, invoked
// through get_result_tensor_for_unary_op with `grad_output` as the input and
// the values of `self` as the reference tensor.
// NOTE(review): the callback presumably receives the values of `grad_output`;
// this then assumes `grad_output` and `self` store the same specified
// elements in the same order -- confirm against the helper and callers.
Tensor threshold_backward_sparse_compressed(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& threshold) {
  // `self.values()` is looked up when the helper invokes the callback.
  const auto backward_on_values = [&](const Tensor& grad_values) {
    return at::threshold_backward(grad_values, self.values(), threshold);
  };
  return get_result_tensor_for_unary_op(backward_on_values, grad_output);
}

// Out-variant of threshold_backward registered for the SparseCsrCPU/
// SparseCsrCUDA dispatch keys. Delegates to unary_op_out, which is handed
// `grad_output` as the input and `grad_input` as the destination; the callback
// runs at::threshold_backward_outf against the values of `self`.
// NOTE(review): presumably the callback sees the values of `grad_output` and
// of `grad_input`, and `self` is expected to share their sparsity pattern --
// confirm against unary_op_out.
Tensor& threshold_backward_sparse_compressed_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& threshold,
    Tensor& grad_input) {
  const auto backward_into = [&](const Tensor& grad_values, Tensor& dst) {
    return at::threshold_backward_outf(
        grad_values, self.values(), threshold, dst);
  };
  return unary_op_out(backward_into, grad_output, grad_input);
}

// angle, isneginf, isposinf and signbit currently don't have an inplace variant
CREATE_UNARY_UFUNC_NO_INPLACE(angle);
CREATE_UNARY_UFUNC_NO_INPLACE(isneginf);
Expand Down
43 changes: 42 additions & 1 deletion aten/src/ATen/native/sparse/SparseUnaryOps.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
Expand Down Expand Up @@ -39,6 +39,8 @@
#include <ATen/ops/log1p_native.h>
#include <ATen/ops/nan_to_num.h>
#include <ATen/ops/nan_to_num_native.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/relu_native.h>
#include <ATen/ops/round.h>
#include <ATen/ops/round_native.h>
#include <ATen/ops/sgn.h>
Expand All @@ -58,6 +60,8 @@
#include <ATen/ops/tan_native.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/tanh_native.h>
#include <ATen/ops/threshold_backward.h>
#include <ATen/ops/threshold_backward_native.h>
#include <ATen/ops/trunc.h>
#include <ATen/ops/trunc_native.h>
#endif
Expand Down Expand Up @@ -175,6 +179,7 @@ COALESCED_UNARY_UFUNC(sqrt);
COALESCED_UNARY_UFUNC(tan);
COALESCED_UNARY_UFUNC(tanh);
COALESCED_UNARY_UFUNC(trunc);
COALESCED_UNARY_UFUNC(relu);

COALESCED_UNARY_UFUNC_NO_INPLACE(signbit);
COALESCED_UNARY_UFUNC_NO_INPLACE(isneginf);
Expand All @@ -187,6 +192,42 @@ Tensor isinf_sparse_meta(const Tensor& self) {
TORCH_CHECK_NOT_IMPLEMENTED(0, "nyi isinf for SparseMeta");
}

// threshold_backward takes two tensors, so it is not a unary op itself; it
// lives next to the unary ufuncs because it is the backward used for relu,
// which is unary.
//
// Functional variant for the SparseCPU/SparseCUDA dispatch keys: the values of
// `self` (coalescing first if it is not already coalesced) are paired with the
// tensor that coalesced_unary_ufunc passes to the callback for `grad_output`.
// NOTE(review): this presumably relies on `grad_output` and `self` having the
// same indices once coalesced -- confirm against callers.
Tensor threshold_backward_sparse(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& threshold) {
  // Only coalesce when needed; an already-coalesced tensor is read directly.
  const Tensor self_values =
      self.is_coalesced() ? self.values() : self.coalesce().values();
  const auto backward_on_values = [&](const Tensor& grad_values) {
    return at::threshold_backward(grad_values, self_values, threshold);
  };
  return coalesced_unary_ufunc(grad_output, backward_on_values);
}

// Out-variant of threshold_backward for the SparseCPU/SparseCUDA dispatch
// keys. The values of `self` (coalesced on demand) are fed, together with the
// tensors supplied by coalesced_unary_ufunc_out for `grad_output` and
// `grad_input`, to at::threshold_backward_outf.
// NOTE(review): presumably `grad_output` and `self` must share the same
// indices after coalescing -- confirm against callers.
Tensor& threshold_backward_sparse_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& threshold,
    Tensor& grad_input) {
  // Only coalesce when needed; an already-coalesced tensor is read directly.
  const Tensor self_values =
      self.is_coalesced() ? self.values() : self.coalesce().values();
  const auto backward_into = [&](const Tensor& grad_values, Tensor& dst) {
    return at::threshold_backward_outf(
        grad_values, self_values, threshold, dst);
  };
  return coalesced_unary_ufunc_out(grad_output, grad_input, backward_into);
}

Tensor nan_to_num_sparse(
const Tensor &self, c10::optional<double> nan,
c10::optional<double> posinf, c10::optional<double> neginf) {
Expand Down
5 changes: 5 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10549,6 +10549,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
aten_name="relu",
ref=lambda a: np.where(a <= 0, 0, a),
supports_autograd=True,
supports_sparse=True,
supports_sparse_csr=True,
supports_sparse_csc=True,
supports_sparse_bsr=True,
supports_sparse_bsc=True,
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_nn_activation_relu,
Expand Down

0 comments on commit 527ebed

Please sign in to comment.