From 1e05bf8a335156d15c3c59e7155c4c481a0f5f32 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:28:29 +0800 Subject: [PATCH 01/16] Create elementwise_heaviside_op.cc --- .../elementwise/elementwise_heaviside_op.cc | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 paddle/fluid/operators/elementwise/elementwise_heaviside_op.cc diff --git a/paddle/fluid/operators/elementwise/elementwise_heaviside_op.cc b/paddle/fluid/operators/elementwise/elementwise_heaviside_op.cc new file mode 100644 index 0000000000000..e003a43b5c56b --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_heaviside_op.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/operators/elementwise/elementwise_op.h" + +namespace paddle { +namespace operators { + +class ElementwiseHeavisideOpMaker : public ElementwiseOpMaker { + protected: + std::string GetName() const override { return "Heaviside"; } + std::string GetEquation() const override { return "Out = Heaviside(X, Y)"; } + + void AddInputX() override { + AddInput("X", + "(Tensor), The input tensor of Heaviside step function. " + "Its dtype can be int32, int64, float32 and float64"); + } + + void AddInputY() override { + AddInput("Y", + "(Tensor), The tensor determining a Heaviside step function, " + "which is the value when X = 0. Its dtype should be same as X."); + } + + std::string GetOpFuntionality() const override { + return "Computes the Heaviside step function determined by Y " + "for each element in X."; + } +}; + +template +class ElementwiseHeavisideGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("elementwise_heaviside_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + elementwise_heaviside, ops::ElementwiseOp, ops::ElementwiseHeavisideOpMaker, + ops::ElementwiseHeavisideGradOpMaker, + ops::ElementwiseHeavisideGradOpMaker); + +REGISTER_OPERATOR(elementwise_heaviside_grad, ops::ElementwiseOpGrad); From 02c4f1da8a2310e7e2e93f26be8f8c3145f133e7 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:29:13 +0800 Subject: [PATCH 02/16] add ElementwiseHeavisideFunctor --- paddle/phi/kernels/funcs/elementwise_functor.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index 8d9dd65786705..5cf7c5f190df8 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -539,6 +539,13 @@ struct InverseModuloFunctor< } }; +template +struct ElementwiseHeavisideFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + return a == static_cast(0) ? b : static_cast(a > 0); + } +}; + template struct FloorDivideFunctor { inline HOSTDEVICE T operator()(const T a, const T b) const { From 69592c7cd5b59809b8f3b984120b062b533601ed Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:29:46 +0800 Subject: [PATCH 03/16] Create test_elementwise_heaviside_op.py --- .../test_elementwise_heaviside_op.py | 169 ++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py new file mode 100644 index 0000000000000..8a8e74e28ec72 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py @@ -0,0 +1,169 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest +import paddle + + +class TestElementwiseOp(OpTest): + def setUp(self): + self.op_type = "elementwise_heaviside" + x = np.random.random((13, 17)).astype("float64") + y = np.random.random((13, 17)).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + self.check_grad(['Y'], 'Out', no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad(['X'], 'Out', no_grad_set=set('Y')) + + +class TestHeavisideBroadcast(unittest.TestCase): + def setUp(self): + self.input_1 = np.random.rand(2, 100, 13, 17).astype("float32") + self.input_2 = np.random.rand(100, 13, 17).astype("float32") + self.input_3 = np.random.rand(100, 13, 1).astype("float32") + self.input_4 = np.random.rand(13, 17).astype("float32") + self.input_5 = np.random.rand(1).astype("float32") + + self.np_expected1 = np.heaviside(self.input_1, self.input_2) + self.np_expected2 = np.heaviside(self.input_2, self.input_3) + self.np_expected3 = np.heaviside(self.input_2, self.input_4) + self.np_expected4 = np.heaviside(self.input_4, self.input_5) + + def test_broadcast(self): + paddle.disable_static() + self.tensor_1 = paddle.to_tensor(self.input_1) + self.tensor_2 = paddle.to_tensor(self.input_2) + self.tensor_3 = paddle.to_tensor(self.input_3) + self.tensor_4 = paddle.to_tensor(self.input_4) + self.tensor_5 = paddle.to_tensor(self.input_5) + + res = paddle.heaviside(self.tensor_1, self.tensor_2) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected1)) + + res = paddle.heaviside(self.tensor_2, self.tensor_3) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected2)) + + res = paddle.heaviside(self.tensor_2, self.tensor_4) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected3)) + + res = paddle.heaviside(self.tensor_4, self.tensor_5) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected4)) + + +class TestHeavisideAPI_float64(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random((13, 17)).astype("float64") + self.y_np = np.random.random((13, 17)).astype("float64") + self.out_np = np.heaviside(self.x_np, self.y_np) + self.dtype = "float64" + + def test_static(self): + for use_cuda in ([False, True] + if paddle.device.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.enable_static() + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.static.data( + name=f"x_{self.dtype}", shape=[13, 17], dtype=self.dtype) + y = paddle.static.data( + name=f"y_{self.dtype}", shape=[13, 17], dtype=self.dtype) + out = paddle.heaviside(x, y) + + exe = paddle.static.Executor(place=place) + res = exe.run(prog, + feed={ + f"x_{self.dtype}": self.x_np, + f"y_{self.dtype}": self.y_np + }, + fetch_list=out, + use_prune=True) + + self.assertTrue(np.allclose(res, self.out_np)) + + def test_dygraph(self): + for use_cuda in ([False, True] + if paddle.device.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + result = paddle.heaviside( + paddle.to_tensor(self.x_np), paddle.to_tensor(self.y_np)) + + self.assertTrue(np.allclose(result.numpy(), self.out_np)) + + +class TestHeavisideAPI_float32(TestHeavisideAPI_float64): + def setUp(self): + self.x_np = np.random.random((13, 17)).astype("float32") + self.y_np = np.random.random((13, 17)).astype("float32") + self.out_np = np.heaviside(self.x_np, self.y_np) + self.dtype = "float32" + + +class TestHeavisideAPI_int64(TestHeavisideAPI_float64): + def setUp(self): + self.x_np = np.random.random((13, 17)).astype("int64") + self.y_np = np.random.random((13, 17)).astype("int64") + self.out_np = np.heaviside(self.x_np, self.y_np) + self.dtype = "int64" + + +class TestHeavisideAPI_int32(TestHeavisideAPI_float64): + def setUp(self): + self.x_np = np.random.random((13, 17)).astype("int32") + self.y_np = np.random.random((13, 17)).astype("int32") + self.out_np = np.heaviside(self.x_np, self.y_np) + self.dtype = "int32" + + +class TestHeavisideError(unittest.TestCase): + def test_input(self): + paddle.disable_static() + + def test_input_x(): + paddle.heaviside(1, paddle.randn([100])) + + self.assertRaises(ValueError, test_input_x) + + def test_input_y(): + paddle.heaviside(paddle.randn([100]), 1) + + self.assertRaises(ValueError, test_input_y) + + def test_input_xy(): + paddle.heaviside( + paddle.randn([100], 'float32'), paddle.randn([100], 'float64')) + + self.assertRaises(ValueError, test_input_xy) + + +if __name__ == '__main__': + unittest.main() From bc18bf2dcf3f56f443ae2ac11a6248c6cb671a69 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:32:25 +0800 Subject: [PATCH 04/16] =?UTF-8?q?=E5=A2=9E=E5=8A=A0heaviside=E7=9A=84pytho?= =?UTF-8?q?n=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/__init__.py | 1 + python/paddle/tensor/__init__.py | 2 ++ python/paddle/tensor/math.py | 45 ++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 3578b9a1aaeea..434f19bbce066 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -268,6 +268,7 @@ from .tensor.math import fmin # noqa: F401 from .tensor.math import inner # noqa: F401 from .tensor.math import outer # noqa: F401 +from .tensor.math import heaviside # noqa: F401 from .tensor.math import frac # noqa: F401 from .tensor.random import bernoulli # noqa: F401 diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 5f0fb4336e014..11e0248a56900 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -228,6 +228,7 @@ from .math import fmin # noqa: F401 from .math import inner # noqa: F401 from .math import outer # noqa: F401 +from .math import heaviside # noqa: F401 from .math import frac # noqa: F401 from .random import multinomial # noqa: F401 @@ -493,6 +494,7 @@ 'put_along_axis', 'put_along_axis_', 'exponential_', + 'heaviside', ] #this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index cfc9abb86984d..61b2256c3e37e 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4386,6 +4386,51 @@ def angle(x, name=None): helper.append_op(type=op_type, inputs=inputs, outputs=outputs) return out +def heaviside(x, y, name=None): + """ + Computes the Heaviside step function determined by y for each element in x. The equation is: + .. math:: + heaviside(x, y)= + \left\{ + \begin{array}{lcl} + 0,& &\text{if } \ x < 0, \\ + y,& &\text{if } \ x = 0, \\ + 1,& &\text{if } \ x > 0. + \end{array} + \right. + + Notes: + ``paddle.heaviside`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . + + Args: + x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. + + Examples: + .. code-block:: python + import paddle + x = paddle.to_tensor([-0.5, 0, 0.5]) + y = paddle.to_tensor([0.1]) + paddle.heaviside(x, y) + # [0. , 0.10000000, 1. ] + x = paddle.to_tensor([[-0.5, 0, 0.5], [-0.5, 0.5, 0]]) + y = paddle.to_tensor([0.1, 0.2, 0.3]) + paddle.heaviside(x, y) + # [[0. , 0.20000000, 1. ], + # [0. , 1. , 0.30000001]] + """ + op_type = 'elementwise_heaviside' + axis = -1 + act = None + if _non_static_mode(): + return _elementwise_op_in_dygraph( + x, y, axis=axis, act=act, op_name=op_type) + return _elementwise_op(LayerHelper(op_type, **locals())) + def frac(x, name=None): """ This API is used to return the fractional portion of each element in input. From 3216ded178ebe6e26fa58b1e06cbfcd470b55cef Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:32:42 +0800 Subject: [PATCH 05/16] add heaviside in white list --- .../fluid/tests/unittests/white_list/no_grad_set_white_list.py | 1 + tools/static_mode_white_list.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py index d5f4cef5b8759..fb1cd35c45380 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py @@ -37,6 +37,7 @@ 'dot', 'elementwise_add', 'elementwise_div', + 'elementwise_heaviside', 'elementwise_max', 'elementwise_min', 'elementwise_mul', diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index f907d51e4d038..39cc9d684584c 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -169,6 +169,7 @@ 'test_elementwise_div_op', 'test_elementwise_floordiv_op', 'test_elementwise_gradient_op', + 'test_elementwise_heaviside_op', 'test_elementwise_max_op', 'test_elementwise_min_op', 'test_elementwise_mod_op', From 24c2558dd553e2b590cb96010f6f97be83b554e5 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:33:55 +0800 Subject: [PATCH 06/16] =?UTF-8?q?=E5=A2=9E=E5=8A=A0heaviside=E7=9A=84?= =?UTF-8?q?=E7=AD=BE=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/ops/compat/elementwise_sig.cc | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index 0a58d86b05b06..76ab2d1e6415f 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -95,6 +95,15 @@ KernelSignature ElementwiseFloorDivOpArgumentMapping( return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"}); } +KernelSignature ElementwiseHeavisideOpArgumentMapping( + const ArgumentMappingContext& ctx) { + int axis = paddle::any_cast(ctx.Attr("axis")); + if (axis == -1) { + return KernelSignature("elementwise_heaviside", {"X", "Y"}, {}, {"Out"}); + } + return KernelSignature("elementwise_heaviside_raw", {"X", "Y"}, {"axis"}, {"Out"}); +} + KernelSignature ElementwisePowOpArgumentMapping( const ArgumentMappingContext& ctx) { int axis = paddle::any_cast(ctx.Attr("axis")); @@ -222,6 +231,15 @@ KernelSignature ElementwiseMinGradOpArgumentMapping( {"axis"}, {GradVarName("X"), GradVarName("Y")}); } + +KernelSignature ElementwiseHeavisideGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("elementwise_heaviside_grad", + {"X", "Y", GradVarName("Out")}, + {"axis"}, + {GradVarName("X"), GradVarName("Y")}); +} + KernelSignature ElementwisePowGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("elementwise_pow_grad", @@ -272,6 +290,8 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_mod, phi::ElementwiseModOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv, phi::ElementwiseFloorDivOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside, + phi::ElementwiseHeavisideOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_pow, phi::ElementwisePowOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad, @@ -306,5 +326,7 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad, phi::ElementwiseMaxGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad, phi::ElementwiseMinGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside_grad, + phi::ElementwiseHeavisideGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad, phi::ElementwisePowGradOpArgumentMapping); From d9db011b5e9e9e46c42e0f32564dd9e4e07e7725 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:39:10 +0800 Subject: [PATCH 07/16] =?UTF-8?q?=E5=A2=9E=E5=8A=A0heaviside=E7=9A=84?= =?UTF-8?q?=E6=A0=B8=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/elementwise_kernel.cc | 21 ++++++++++++++++ paddle/phi/kernels/elementwise_kernel.cc | 25 ++++++++++++++++++++ paddle/phi/kernels/elementwise_kernel.h | 24 +++++++++++++++++++ paddle/phi/kernels/kps/elementwise_kernel.cu | 10 ++++++++ 4 files changed, 80 insertions(+) diff --git a/paddle/phi/kernels/cpu/elementwise_kernel.cc b/paddle/phi/kernels/cpu/elementwise_kernel.cc index a91ca1ee3244b..0cd236c9a8f04 100644 --- a/paddle/phi/kernels/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_kernel.cc @@ -142,6 +142,19 @@ void ElementwisePowRawKernel(const Context& dev_ctx, funcs::ElementwiseCompute, T>( dev_ctx, x, y, axis, funcs::ElementwisePowFunctor(), out); } + +template +void ElementwiseHeavisideRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + // allocate memory for out + dev_ctx.template Alloc(out); + funcs::ElementwiseCompute, T>( + dev_ctx, x, y, axis, funcs::ElementwiseHeavisideFunctor(), out); +} + // Create the definition of Add DEFINE_CPU_ELEMENTWISE_OP(Add) @@ -250,3 +263,11 @@ PD_REGISTER_KERNEL(elementwise_pow_raw, double, int, int64_t) {} +PD_REGISTER_KERNEL(elementwise_heaviside_raw, + CPU, + ALL_LAYOUT, + phi::ElementwiseHeavisideRawKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc index 6cd602e47b8e6..8c7a8c88f4630 100644 --- a/paddle/phi/kernels/elementwise_kernel.cc +++ b/paddle/phi/kernels/elementwise_kernel.cc @@ -100,6 +100,15 @@ void ElementwisePowKernel(const Context& dev_ctx, ElementwisePowRawKernel(dev_ctx, x, y, axis, out); } +template +void ElementwiseHeavisideKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + ElementwiseHeavisideRawKernel(dev_ctx, x, y, axis, out); +} + } // namespace phi using complex64 = ::phi::dtype::complex; @@ -172,6 +181,14 @@ PD_REGISTER_KERNEL( modulo, CPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {} PD_REGISTER_KERNEL( floor_divide, CPU, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {} +PD_REGISTER_KERNEL(elementwise_heaviside, + CPU, + ALL_LAYOUT, + phi::ElementwiseHeavisideKernel, + float, + double, + int, + int64_t) {} PD_REGISTER_KERNEL(elementwise_pow, CPU, ALL_LAYOUT, @@ -258,6 +275,14 @@ PD_REGISTER_KERNEL( modulo, GPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {} PD_REGISTER_KERNEL( floor_divide, GPU, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {} +PD_REGISTER_KERNEL(elementwise_heaviside, + GPU, + ALL_LAYOUT, + phi::ElementwiseHeavisideKernel, + float, + double, + int, + int64_t) {} PD_REGISTER_KERNEL(elementwise_pow, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/elementwise_kernel.h b/paddle/phi/kernels/elementwise_kernel.h index 09b6b02e37257..da888a68b137c 100644 --- a/paddle/phi/kernels/elementwise_kernel.h +++ b/paddle/phi/kernels/elementwise_kernel.h @@ -150,6 +150,19 @@ void ElementwisePowKernel(const Context& dev_ctx, const DenseTensor& y, DenseTensor* out); +template +void ElementwiseHeavisideRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + +template +void ElementwiseHeavisideKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + template DenseTensor Add(const Context& dev_ctx, const DenseTensor& x, @@ -238,6 +251,17 @@ DenseTensor FloorDivide(const Context& dev_ctx, return dense_out; } +template +DenseTensor ElementwiseHeaviside(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + ElementwiseInferMeta(x, y, &meta_out); + ElementwiseHeavisideKernel(dev_ctx, x, y, &dense_out); + return dense_out; +} + template DenseTensor ElementwisePow(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu index 01a34c0f85eda..b282cb6bf70ed 100644 --- a/paddle/phi/kernels/kps/elementwise_kernel.cu +++ b/paddle/phi/kernels/kps/elementwise_kernel.cu @@ -58,6 +58,8 @@ DEFINE_CUDA_ELEMENTWISE_OP(Minimum) DEFINE_CUDA_ELEMENTWISE_OP(Modulo) // Create the definition of FloorDivide DEFINE_CUDA_ELEMENTWISE_OP(FloorDivide) +// Create the definition of Heaviside +DEFINE_CUDA_ELEMENTWISE_OP(ElementwiseHeaviside) // Create the definition of Pow DEFINE_CUDA_ELEMENTWISE_OP(ElementwisePow) @@ -174,6 +176,14 @@ PD_REGISTER_KERNEL(floor_divide_raw, phi::FloorDivideRawKernel, int, int64_t) {} +PD_REGISTER_KERNEL(elementwise_heaviside_raw, + KPS, + ALL_LAYOUT, + phi::ElementwiseHeavisideRawKernel, + float, + double, + int, + int64_t) {} PD_REGISTER_KERNEL(elementwise_pow_raw, KPS, ALL_LAYOUT, From e3dcba16ce3dd7742f0f41c40b206df0b5641a21 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:40:05 +0800 Subject: [PATCH 08/16] =?UTF-8?q?=E5=A2=9E=E5=8A=A0heaviside=E6=A2=AF?= =?UTF-8?q?=E5=BA=A6=E7=9A=84=E6=A0=B8=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/elementwise_grad_kernel.h | 9 +++++++ .../impl/elementwise_grad_kernel_impl.h | 27 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/paddle/phi/kernels/elementwise_grad_kernel.h b/paddle/phi/kernels/elementwise_grad_kernel.h index 0e730fbfbfa4d..a1c806beb9a40 100644 --- a/paddle/phi/kernels/elementwise_grad_kernel.h +++ b/paddle/phi/kernels/elementwise_grad_kernel.h @@ -160,6 +160,15 @@ void MinimumGradKernel(const Context& dev_ctx, DenseTensor* dx, DenseTensor* dy); +template +void ElementwiseHeavisideGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy); + template void ElementwisePowGradKernel(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index aba4a5f5fbd43..67813f754d33c 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -667,6 +667,33 @@ struct MinGradDy { } }; +template +struct HeavisideGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast(0); + } +}; + +template +struct HeavisideGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast(x == static_cast(0)); + } +}; + +template +void ElementwiseHeavisideGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + funcs::ElementwiseGradPreProcess(dout, dx); + phi::funcs::ElemwiseGradCompute, HeavisideGradDy>( + dev_ctx, x, y, dout, dout, axis, dx, dy, HeavisideGradDx(), HeavisideGradDy()); +} + template struct PowGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { From 80c44e7bbdddccd49b24c0d7475307e94483dcad Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 16 Apr 2022 10:41:14 +0800 Subject: [PATCH 09/16] =?UTF-8?q?=E5=A2=9E=E5=8A=A0heaviside=E6=A2=AF?= =?UTF-8?q?=E5=BA=A6=E7=9A=84=E6=B3=A8=E5=86=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/elementwise_grad_kernel.cc | 10 ++++++++++ paddle/phi/kernels/gpu/elementwise_grad_kernel.cu | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc index f452d9ffb7e89..804b6449876e5 100644 --- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc @@ -323,6 +323,16 @@ PD_REGISTER_KERNEL(minimum_grad, int, int64_t, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(elementwise_heaviside_grad, + CPU, + ALL_LAYOUT, + phi::ElementwiseHeavisideGradKernel, + float, + double, + int, + int64_t) {} + PD_REGISTER_KERNEL(elementwise_pow_grad, CPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index fae7978d3d2ea..a2b9f868578b8 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -382,6 +382,16 @@ PD_REGISTER_KERNEL(minimum_grad, int64_t, phi::dtype::float16, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(elementwise_heaviside_grad, + GPU, + ALL_LAYOUT, + phi::ElementwiseHeavisideGradKernel, + float, + double, + int, + int64_t) {} + PD_REGISTER_KERNEL(elementwise_pow_grad, GPU, ALL_LAYOUT, From 33fd790466e32d7a1f03cbe6352c4f2b4ef46562 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sun, 17 Apr 2022 18:23:27 +0800 Subject: [PATCH 10/16] =?UTF-8?q?=E8=B0=83=E6=95=B4=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kernels/impl/elementwise_grad_kernel_impl.h | 14 ++++++++++++-- paddle/phi/ops/compat/elementwise_sig.cc | 9 +++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index 67813f754d33c..6ec5f2fb0962d 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -690,8 +690,18 @@ void ElementwiseHeavisideGradKernel(const Context& dev_ctx, DenseTensor* dx, DenseTensor* dy) { funcs::ElementwiseGradPreProcess(dout, dx); - phi::funcs::ElemwiseGradCompute, HeavisideGradDy>( - dev_ctx, x, y, dout, dout, axis, dx, dy, HeavisideGradDx(), HeavisideGradDy()); + phi::funcs:: + ElemwiseGradCompute, HeavisideGradDy>( + dev_ctx, + x, + y, + dout, + dout, + axis, + dx, + dy, + HeavisideGradDx(), + HeavisideGradDy()); } template diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index 76ab2d1e6415f..4cf9cc014ef35 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -101,7 +101,8 @@ KernelSignature ElementwiseHeavisideOpArgumentMapping( if (axis == -1) { return KernelSignature("elementwise_heaviside", {"X", "Y"}, {}, {"Out"}); } - return KernelSignature("elementwise_heaviside_raw", {"X", "Y"}, {"axis"}, {"Out"}); + return KernelSignature( + "elementwise_heaviside_raw", {"X", "Y"}, {"axis"}, {"Out"}); } KernelSignature ElementwisePowOpArgumentMapping( @@ -235,9 +236,9 @@ KernelSignature ElementwiseMinGradOpArgumentMapping( KernelSignature ElementwiseHeavisideGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("elementwise_heaviside_grad", - {"X", "Y", GradVarName("Out")}, - {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + {"X", "Y", GradVarName("Out")}, + {"axis"}, + {GradVarName("X"), GradVarName("Y")}); } KernelSignature ElementwisePowGradOpArgumentMapping( From 8d4cd9b06d210ad6555a6594ababbba10ce87a32 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sun, 17 Apr 2022 21:07:20 +0800 Subject: [PATCH 11/16] Update elementwise_sig.cc --- paddle/phi/ops/compat/elementwise_sig.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index a2a61a611837a..6859a76d73f1a 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -222,9 +222,9 @@ KernelSignature ElementwiseMinGradOpArgumentMapping( KernelSignature ElementwiseHeavisideGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("elementwise_heaviside_grad", - {"X", "Y", GradVarName("Out")}, + {"X", "Y", "Out@GRAD"}, {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwisePowGradOpArgumentMapping( From 65a0fdef084b914009174e2556aba87209834223 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Mon, 18 Apr 2022 11:48:43 +0800 Subject: [PATCH 12/16] add heaviside in __all__ --- python/paddle/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 434f19bbce066..be94fd162bc01 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -634,4 +634,5 @@ 'renorm', 'take_along_axis', 'put_along_axis', + 'heaviside', ] From 9aeb5d7cc26c510c115586b603c8504c61885a24 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 21 Apr 2022 16:36:52 +0800 Subject: [PATCH 13/16] Update heaviside docs --- python/paddle/tensor/math.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index d96377cccf97a..e8ea51b4d8d80 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4414,7 +4414,7 @@ def angle(x, name=None): def heaviside(x, y, name=None): """ - Computes the Heaviside step function determined by y for each element in x. The equation is: + Computes the Heaviside step function determined by corresponding element in y for each element in x. The equation is: .. math:: heaviside(x, y)= \left\{ @@ -4426,18 +4426,20 @@ def heaviside(x, y, name=None): \right. Notes: - ``paddle.heaviside`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . + ``paddle.heaviside`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. Args: - x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. - y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + x (Tensor): The input tensor of Heaviside step function, it's data type should be float32, float64, int32 or int64. + y (Tensor): The tensor that determines a Heaviside step function, it's data type should be float32, float64, int32 or int64. + name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. + N-D Tensor. A location into which the result is stored. If x and y have different shapes and are broadcastable, the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. Examples: .. code-block:: python + :name: heaviside-example + import paddle x = paddle.to_tensor([-0.5, 0, 0.5]) y = paddle.to_tensor([0.1]) From c96fe8d466962a35941f8e92503865303966aad1 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 22 Apr 2022 15:27:32 +0800 Subject: [PATCH 14/16] Update math.py --- python/paddle/tensor/math.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e8ea51b4d8d80..e362cf5b6a381 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4415,6 +4415,7 @@ def angle(x, name=None): def heaviside(x, y, name=None): """ Computes the Heaviside step function determined by corresponding element in y for each element in x. The equation is: + .. math:: heaviside(x, y)= \left\{ From 879a591ea5479fb1040b7231013253b8ef5d65d5 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 22 Apr 2022 17:45:25 +0800 Subject: [PATCH 15/16] Update math.py --- python/paddle/tensor/math.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index e362cf5b6a381..4a2fec712cf3c 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4414,17 +4414,17 @@ def angle(x, name=None): def heaviside(x, y, name=None): """ - Computes the Heaviside step function determined by corresponding element in y for each element in x. The equation is: + Computes the Heaviside step function determined by corresponding element in y for each element in x. The equation is .. math:: heaviside(x, y)= \left\{ - \begin{array}{lcl} - 0,& &\text{if } \ x < 0, \\ - y,& &\text{if } \ x = 0, \\ - 1,& &\text{if } \ x > 0. + \\begin{array}{lcl} + 0,& &\\text{if} \ x < 0, \\ + y,& &\\text{if} \ x = 0, \\ + 1,& &\\text{if} \ x > 0. \end{array} - \right. + \\right. Notes: ``paddle.heaviside`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. From 12ebfdbef1add14449c02ffe9b53983095f97574 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 22 Apr 2022 22:07:42 +0800 Subject: [PATCH 16/16] Update math.py --- python/paddle/tensor/math.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 4a2fec712cf3c..527df4f5c8f62 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4420,14 +4420,14 @@ def heaviside(x, y, name=None): heaviside(x, y)= \left\{ \\begin{array}{lcl} - 0,& &\\text{if} \ x < 0, \\ - y,& &\\text{if} \ x = 0, \\ + 0,& &\\text{if} \ x < 0, \\\\ + y,& &\\text{if} \ x = 0, \\\\ 1,& &\\text{if} \ x > 0. \end{array} \\right. Notes: - ``paddle.heaviside`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. + ``paddle.heaviside`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`. Args: x (Tensor): The input tensor of Heaviside step function, it's data type should be float32, float64, int32 or int64.