New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 2】18、为 Paddle 新增 paddle.heaviside 和 paddle.Tensor.heaviside API #40934
Closed
BrilliantYuKaimin
wants to merge
25
commits into
PaddlePaddle:develop
from
BrilliantYuKaimin:heaviside_2
Closed
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
299b726
增加heaviside的Python接口
BrilliantYuKaimin bfde6ff
增加heaviside算子的描述
BrilliantYuKaimin 4c91cfa
增加heaviside的相关函子
BrilliantYuKaimin ea60774
增加heaviside及其梯度的实现
BrilliantYuKaimin a439d57
增加heaviside及其梯度的核函数的注册
BrilliantYuKaimin d1322bc
增加heaviside的签名
BrilliantYuKaimin 99a8433
增加heaviside单测
BrilliantYuKaimin 2106d2c
Update elementwise_heaviside_op.cc
BrilliantYuKaimin 53cdf22
调整代码格式
BrilliantYuKaimin 1d07d18
Update elementwise_functor.h
BrilliantYuKaimin 36448fe
调整代码格式
BrilliantYuKaimin ea9ba8d
完善对输入的描述
BrilliantYuKaimin 5397eb8
Update test_elementwise_heaviside_op.py
BrilliantYuKaimin 0cf7567
Merge branch 'develop' into heaviside_2
BrilliantYuKaimin 0ecc4fc
Update test_elementwise_heaviside_op.py
BrilliantYuKaimin a3b84e1
add elementwise_heaviside
BrilliantYuKaimin b97fcb9
Update test_elementwise_heaviside_op.py
BrilliantYuKaimin 0739b3d
Merge branch 'PaddlePaddle:develop' into heaviside_2
BrilliantYuKaimin e3f0831
Update elementwise_sig.cc
BrilliantYuKaimin 570f4c4
add elementwise_heaviside
BrilliantYuKaimin d7609ee
Update test_elementwise_heaviside_op.py
BrilliantYuKaimin 9ca329e
Update math.py
BrilliantYuKaimin 5bc256d
Update test_elementwise_heaviside_op.py
BrilliantYuKaimin 07d5607
Merge branch 'develop' into heaviside_2
BrilliantYuKaimin 30cf8c6
Update test_elementwise_heaviside_op.py
BrilliantYuKaimin File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
63 changes: 63 additions & 0 deletions
63
paddle/fluid/operators/elementwise/elementwise_heaviside_op.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <string> | ||
#include "paddle/fluid/operators/elementwise/elementwise_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class ElementwiseHeavisideOpMaker : public ElementwiseOpMaker { | ||
protected: | ||
std::string GetName() const override { return "Heaviside"; } | ||
std::string GetEquation() const override { return "Out = Heaviside(X, Y)"; } | ||
|
||
void AddInputX() override { | ||
AddInput("X", "The input tensor of Heaviside step function."); | ||
} | ||
|
||
void AddInputY() override { | ||
AddInput("Y", "The tensor determining a Heaviside step function."); | ||
} | ||
|
||
std::string GetOpFuntionality() const override { | ||
return "Computes the Heaviside step function determined by Y " | ||
"for each element in X."; | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class ElementwiseHeavisideGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> op) const override { | ||
op->SetType("elementwise_heaviside_grad"); | ||
op->SetInput("X", this->Input("X")); | ||
op->SetInput("Y", this->Input("Y")); | ||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); | ||
op->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OPERATOR( | ||
elementwise_heaviside, ops::ElementwiseOp, ops::ElementwiseHeavisideOpMaker, | ||
ops::ElementwiseHeavisideGradOpMaker<paddle::framework::OpDesc>, | ||
ops::ElementwiseHeavisideGradOpMaker<paddle::imperative::OpBase>); | ||
|
||
REGISTER_OPERATOR(elementwise_heaviside_grad, ops::ElementwiseOpGrad); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
112 changes: 112 additions & 0 deletions
112
python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 单测中需要对照设计文档中补充测试case,目前仍然缺少不同dtype/设备/动态图、错误检查等测试内容。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经补充了测试内容,OPTEST减少到了3个。 |
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import unittest | ||
import numpy as np | ||
from op_test import OpTest | ||
import paddle | ||
|
||
paddle.enable_static() | ||
|
||
|
||
class TestElementwiseOp(OpTest): | ||
def setUp(self): | ||
self.op_type = "elementwise_heaviside" | ||
x = np.random.random((13, 17)).astype("float64") | ||
y = np.random.random((13, 17)).astype("float64") | ||
self.inputs = {'X': x, 'Y': y} | ||
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
def test_check_grad_normal(self): | ||
self.check_grad(['X', 'Y'], 'Out') | ||
|
||
|
||
class TestElementwiseHeavisideOp_Vector(TestElementwiseOp): | ||
def setUp(self): | ||
self.op_type = "elementwise_heaviside" | ||
x = np.random.random((100, )).astype("float64") | ||
y = np.random.random((100, )).astype("float64") | ||
self.inputs = {'X': x, 'Y': y} | ||
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])} | ||
|
||
|
||
class TestElementwiseHeavisideOp_broadcast_0(TestElementwiseOp): | ||
def setUp(self): | ||
self.op_type = "elementwise_heaviside" | ||
x = np.random.random((100, 5, 2)).astype(np.float64) | ||
y = np.random.random((100, 1, 1)).astype(np.float64) | ||
self.inputs = {'X': x, 'Y': y} | ||
|
||
self.attrs = {'axis': 0} | ||
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])} | ||
|
||
|
||
class TestElementwiseHeavisideOp_broadcast_1(TestElementwiseOp): | ||
def setUp(self): | ||
self.op_type = "elementwise_heaviside" | ||
x = np.random.random((2, 100, 3)).astype(np.float64) | ||
y = np.random.random((100, )).astype(np.float64) | ||
self.inputs = {'X': x, 'Y': y} | ||
|
||
self.attrs = {'axis': 1} | ||
self.outputs = { | ||
'Out': np.heaviside(self.inputs['X'], | ||
self.inputs['Y'].reshape(1, 100, 1)) | ||
} | ||
|
||
|
||
class TestElementwiseHeavisideOp_broadcast_2(TestElementwiseOp): | ||
def setUp(self): | ||
self.op_type = "elementwise_heaviside" | ||
x = np.random.random((1, 3, 100)).astype(np.float64) | ||
y = np.random.random((100, )).astype(np.float64) | ||
self.inputs = {'X': x, 'Y': y} | ||
|
||
self.outputs = { | ||
'Out': | ||
np.heaviside(self.inputs['X'], self.inputs['Y'].reshape(1, 1, 100)) | ||
} | ||
|
||
|
||
class TestElementwiseHeavisideOp_broadcast_3(TestElementwiseOp): | ||
def setUp(self): | ||
self.op_type = "elementwise_heaviside" | ||
x = np.random.random((2, 50, 2, 1)).astype(np.float64) | ||
y = np.random.random((50, 2)).astype(np.float64) | ||
self.inputs = {'X': x, 'Y': y} | ||
|
||
self.attrs = {'axis': 1} | ||
self.outputs = { | ||
'Out': np.heaviside(self.inputs['X'], | ||
self.inputs['Y'].reshape(1, 50, 2, 1)) | ||
} | ||
|
||
|
||
class TestElementwiseHeavisideOp_broadcast_4(TestElementwiseOp): | ||
def setUp(self): | ||
self.op_type = "elementwise_heaviside" | ||
x = np.random.random((2, 3, 4, 50)).astype(np.float64) | ||
y = np.random.random((2, 3, 1, 50)).astype(np.float64) | ||
self.inputs = {'X': x, 'Y': y} | ||
|
||
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])} | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里对X,Y的说明还需要细化一下,可以参考https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc#L42
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
完成