New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 2】15 新增 API Nanmedian #42385
Changes from 13 commits
9fe3ee7
00f901c
7eae9c2
1f2a6e6
7fb02ab
1cce8df
adaa2a1
d5c35d8
4ec331b
24424a7
a0e6c3c
0dac2bd
2a944f6
d7bdc21
06af183
39f5eb9
eb8cb64
bcfb015
718fcdb
8c158b5
b0c9471
5f14183
46dc918
21e131e
117e102
dc6b654
4e8cbc1
a3b23f6
be473ea
6744d0a
8021de0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <memory> | ||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/unary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class NanmedianOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
// Declares the inputs, attributes, outputs and user-facing documentation of
// the `nanmedian` operator.
class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Adjacent string literals are concatenated verbatim, so every fragment
    // that continues a sentence must end with an explicit space.
    AddInput("X",
             "(Tensor, default Tensor<float>), "
             "the input feature data of NanmedianOp, dtype should be "
             "int32, int64, float16, float32, float64.");
    AddAttr<bool>(
        "ignore_nan",
        "(bool, default true) Set to true if nan values should be ignored. "
        "Set to false when no nan value in x were considered. ")
        .SetDefault(true);
    // "Medians" is an intermediate/extra output; the gradient op consumes it.
    AddOutput("Medians",
              "The calculation differs in the odd or even of the valid "
              "elements amount. "
              "Along the axis, two elements contributed to the median value in "
              "each row. "
              "If the amount of valid elements were even, both were the same.")
        .AsIntermediate()
        .AsExtra();
    AddOutput("Out",
              "(Tensor, default Tensor<float>),"
              " the output of NanmedianOp, whose dtype is the same as X");
    AddComment(R"DOC(
Nanmedian operator

This operator is considered as an extension of median operation,
which supports specifically the case of nan values in the input.

If all the elements in input are NaN it will also return NaN.
If no elements in input are NaN, this op is identical to the median op.

This operator also supports multiple axis,
and could be switched to median operator when `ignore_nan` is set to False.
)DOC");
  }
};
|
||
template <typename T> | ||
class NanmedianGradMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
void Apply(GradOpPtr<T> op) const override { | ||
op->SetType("nanmedian_grad"); | ||
op->SetInput("X", this->Input("X")); | ||
op->SetInput("Medians", this->Output("Medians")); | ||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
} | ||
}; | ||
|
||
class NanmedianGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "nanmedian"); | ||
OP_INOUT_CHECK(ctx->HasInput("Medians"), "Input", "Medians", "nanmedian"); | ||
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", | ||
framework::GradVarName("Out"), "nanmedian"); | ||
|
||
auto x_dims = ctx->GetInputDim("X"); | ||
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( | ||
ctx, framework::GradVarName("Out")), | ||
ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
// Short alias for the operator namespace used by the registrations below. | ||
namespace ops = paddle::operators; | ||
// Declares NanmedianInferShapeFunctor from phi::NanmedianInferMeta; the | ||
// functor is handed to the forward-op registration below. | ||
DECLARE_INFER_SHAPE_FUNCTOR(nanmedian, NanmedianInferShapeFunctor, | ||
PD_INFER_META(phi::NanmedianInferMeta)); | ||
|
||
// Registers the forward op `nanmedian` with its op class, its maker, the | ||
// grad maker instantiated for framework::OpDesc and imperative::OpBase, and | ||
// the infer-shape functor declared above. | ||
REGISTER_OPERATOR(nanmedian, ops::NanmedianOp, ops::NanmedianOpMaker, | ||
ops::NanmedianGradMaker<paddle::framework::OpDesc>, | ||
ops::NanmedianGradMaker<paddle::imperative::OpBase>, | ||
NanmedianInferShapeFunctor); | ||
|
||
// Registers the backward op `nanmedian_grad` with its op class only. | ||
REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/nanmedian_grad_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void NanmedianGradKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
const DenseTensor& medians, | ||
const DenseTensor& out_grad, | ||
DenseTensor* x_grad) { | ||
const T* x_ptr = x.data<T>(); | ||
const T* m_ptr = medians.data<T>(); | ||
const T* out_grad_ptr = out_grad.data<T>(); | ||
|
||
int64_t numel = x.numel(); | ||
auto x_dim = x.dims(); | ||
int64_t x_rank = x_dim.size(); | ||
int64_t stride = x_dim[x_rank - 1]; | ||
auto zero = static_cast<T>(0); | ||
|
||
if (x_grad) { | ||
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad); | ||
int64_t i = 0; | ||
for (i = 0; i < numel; i++) { | ||
if (std::isnan(static_cast<float>(x_ptr[i]))) { | ||
x_grad_ptr[i] = zero; | ||
continue; | ||
} | ||
|
||
int64_t row = static_cast<int64_t>(i / stride); | ||
int64_t m_row = 2 * row; | ||
if (std::isnan(static_cast<float>(m_ptr[m_row])) || | ||
(x_ptr[i] != m_ptr[m_row] && x_ptr[i] != m_ptr[m_row + 1])) { | ||
x_grad_ptr[i] = zero; | ||
continue; | ||
} | ||
|
||
x_grad_ptr[i] = out_grad_ptr[row]; | ||
} | ||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
// Registers phi::NanmedianGradKernel as the `nanmedian_grad` kernel for the | ||
// CPU backend and ALL_LAYOUT, for float, double, int and int64_t. | ||
PD_REGISTER_KERNEL(nanmedian_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::NanmedianGradKernel, | ||
float, | ||
double, | ||
int, | ||
int64_t) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/nanmedian_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void NanmedianKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
bool ignore_nan, | ||
DenseTensor* out, | ||
DenseTensor* medians) { | ||
const T* x_ptr = x.data<T>(); | ||
T* o_ptr = dev_ctx.template Alloc<T>(out); | ||
T* m_ptr = dev_ctx.template Alloc<T>(medians); | ||
|
||
int64_t numel = x.numel(); | ||
auto x_dim = x.dims(); | ||
int64_t x_rank = x_dim.size(); | ||
int64_t stride = x_dim[x_rank - 1]; | ||
int64_t pre_dim = numel / stride; | ||
int64_t i = 0; | ||
|
||
bool all_nan = true; | ||
for (i = 0; i < numel; i++) { | ||
if (!std::isnan(static_cast<float>(*(x_ptr + i)))) { | ||
all_nan = false; | ||
break; | ||
} | ||
} | ||
|
||
if (all_nan) { | ||
for (i = 0; i < pre_dim; i++) { | ||
o_ptr[i] = x_ptr[0]; | ||
m_ptr[2 * i] = x_ptr[0]; | ||
m_ptr[2 * i + 1] = x_ptr[0]; | ||
} | ||
return; | ||
} | ||
|
||
std::vector<T> col_vec; | ||
col_vec.reserve(stride); | ||
col_vec.resize(stride); | ||
for (i = 0; i < pre_dim; i++) { | ||
col_vec.clear(); | ||
col_vec.insert( | ||
col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride); | ||
|
||
int64_t num_nan = | ||
std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) { | ||
return std::isnan(static_cast<float>(val)); | ||
}); | ||
|
||
int64_t pos = (stride - num_nan - 1) / 2; | ||
std::nth_element(col_vec.begin(), | ||
col_vec.begin() + pos, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use col_vec.begin() + pos + 1 is better? no need to use std::nth_element again in if statement below? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Learned from https://en.cppreference.com/w/cpp/algorithm/nth_element, the order of elements before the nth element seems to be somehow unreliable, so I fetched the n-1th element again. If the index of elements are needed, this std:nth_element would not be the good choice. |
||
col_vec.end(), | ||
[](const T& l, const T& r) { | ||
return (!std::isnan(static_cast<float>(l)) && | ||
std::isnan(static_cast<float>(r))) || | ||
(l < r); | ||
}); | ||
|
||
m_ptr[2 * i] = col_vec[pos]; | ||
m_ptr[2 * i + 1] = col_vec[pos]; | ||
if ((stride - num_nan) % 2 == 0) { | ||
std::nth_element(col_vec.begin(), | ||
col_vec.begin() + pos + 1, | ||
col_vec.end(), | ||
[](const T& l, const T& r) { | ||
return (!std::isnan(static_cast<float>(l)) && | ||
std::isnan(static_cast<float>(r))) || | ||
(l < r); | ||
}); | ||
m_ptr[2 * i + 1] = col_vec[pos + 1]; | ||
} | ||
o_ptr[i] = static_cast<T>((m_ptr[2 * i] + m_ptr[2 * i + 1]) / 2.0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remind, how to calculate median in even numbers should be clearly written in the document. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
// Registers phi::NanmedianKernel as the `nanmedian` kernel for the CPU | ||
// backend and ALL_LAYOUT, for float, double, int and int64_t. | ||
PD_REGISTER_KERNEL(nanmedian, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::NanmedianKernel, | ||
float, | ||
double, | ||
int, | ||
int64_t) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic to get the gradient of nanmedian is not correct. Consider this data: x = [float('nan'), 2., 3., 2., 3., float('nan')]. After sorting, only the `2` at offset 3 (counting from 0) of x and the `3` at offset 2 of x are used to compute the median, so only these two will have a gradient. The `2` at offset 1 of x and the `3` at offset 4 of x do not have a gradient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done