Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 2】15 新增 API Nanmedian #42385

Merged
merged 31 commits into from May 30, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9fe3ee7
nanmedian op
thunder95 Apr 28, 2022
00f901c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 28, 2022
7eae9c2
修改cuda kernel的bug
thunder95 Apr 29, 2022
1f2a6e6
修复count_if在其他硬件平台不兼容
thunder95 Apr 29, 2022
7fb02ab
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 29, 2022
1cce8df
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 29, 2022
adaa2a1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 30, 2022
d5c35d8
修复某些cpu硬件不兼容
thunder95 Apr 30, 2022
4ec331b
修复某些cpu硬件不兼容
thunder95 Apr 30, 2022
24424a7
修复isnan判断
thunder95 Apr 30, 2022
a0e6c3c
兼容numpy低版本不支持全部nan的情况
thunder95 Apr 30, 2022
0dac2bd
兼容numpy低版本不支持全部nan的情况
thunder95 Apr 30, 2022
2a944f6
fix code example
thunder95 May 1, 2022
d7bdc21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 5, 2022
06af183
fix api comment error
thunder95 May 5, 2022
39f5eb9
修改反向传播逻辑以及c++处理逻辑
thunder95 May 10, 2022
eb8cb64
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 10, 2022
bcfb015
完成修改建议
thunder95 May 11, 2022
718fcdb
typo pre_dim
thunder95 May 12, 2022
8c158b5
update en docs, test=document_fix
thunder95 May 20, 2022
b0c9471
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 20, 2022
5f14183
Merge branch 'nanmedian' of https://github.com/thunder95/Paddle into …
thunder95 May 20, 2022
46dc918
remove numpy in en doc, test=document_fix
thunder95 May 23, 2022
21e131e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 23, 2022
117e102
add r,test=document_fix
thunder95 May 25, 2022
dc6b654
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 25, 2022
4e8cbc1
Merge branch 'nanmedian' of https://github.com/thunder95/Paddle into …
thunder95 May 25, 2022
a3b23f6
添加api到all
thunder95 May 26, 2022
be473ea
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 26, 2022
6744d0a
follow advice from chenwhql
thunder95 May 26, 2022
8021de0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
120 changes: 120 additions & 0 deletions paddle/fluid/operators/nanmedian_op.cc
@@ -0,0 +1,120 @@
/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

class NanmedianOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of NanmedianOp, dtype should be"
"int32, int64, float16, float32, float64.");
AddAttr<bool>(
"ignore_nan",
"(bool, default true) Set to true if nan values should be ignored. "
"Set to false when no nan value in x were considered. ")
.SetDefault(true);
AddOutput("Medians",
"The calculation differs in the odd or even of the valid "
"elements amount."
"Along the axis, two elements contributed to the median value in "
"each row."
"If the amount of valid elements were even, both were the same.")
.AsIntermediate()
.AsExtra();
AddOutput("Out",
"(Tensor, default Tensor<float>),"
" the output of NanmedianOp, whose dtype is the same as X");
AddComment(R"DOC(
Nanmedian operator

This operator is considered as an extention of median operation,
which supports specifically the case of nan values in the input.

If all the elements in input are NaN it will also return NaN.
If no elements in input are Nan, this op is identical to thie median op.

This operator can also supports multiple axis,
and could be switched to median operator when `ignore_nan` were set to False.
)DOC");
}
};

template <typename T>
class NanmedianGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> op) const override {
op->SetType("nanmedian_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Medians", this->Output("Medians"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};

class NanmedianGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "nanmedian");
OP_INOUT_CHECK(ctx->HasInput("Medians"), "Input", "Medians", "nanmedian");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "nanmedian");

auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(nanmedian, NanmedianInferShapeFunctor,
PD_INFER_META(phi::NanmedianInferMeta));

REGISTER_OPERATOR(nanmedian, ops::NanmedianOp, ops::NanmedianOpMaker,
ops::NanmedianGradMaker<paddle::framework::OpDesc>,
ops::NanmedianGradMaker<paddle::imperative::OpBase>,
NanmedianInferShapeFunctor);

REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp);
24 changes: 24 additions & 0 deletions paddle/phi/infermeta/unary.cc
Expand Up @@ -1243,6 +1243,30 @@ void MultinomialInferMeta(const MetaTensor& x,
out->set_dtype(DataType::INT64);
}

void NanmedianInferMeta(const MetaTensor& x,
bool ignore_nan,
MetaTensor* out,
MetaTensor* medians) {
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();

std::vector<int64_t> out_dims(x_rank);
std::vector<int64_t> median_dims(x_rank);
for (int64_t i = 0; i < x_rank - 1; i++) {
out_dims[i] = x_dim[i];
median_dims[i] = x_dim[i];
}

out_dims[x_rank - 1] = 1;
median_dims[x_rank - 1] = 2;

out->set_dims(make_ddim(out_dims));
out->set_dtype(x.dtype());

medians->set_dims(make_ddim(median_dims));
medians->set_dtype(x.dtype());
}

void NormInferMeta(const MetaTensor& x,
int axis,
float epsilon,
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
Expand Up @@ -177,6 +177,12 @@ void MultinomialInferMeta(const MetaTensor& x,
int num_samples,
bool replacement,
MetaTensor* out);

void NanmedianInferMeta(const MetaTensor& x,
bool ignore_nan,
MetaTensor* out,
MetaTensor* medians);

void NormInferMeta(const MetaTensor& x,
int axis,
float epsilon,
Expand Down
69 changes: 69 additions & 0 deletions paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
@@ -0,0 +1,69 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/nanmedian_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& medians,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
const T* x_ptr = x.data<T>();
const T* m_ptr = medians.data<T>();
const T* out_grad_ptr = out_grad.data<T>();

int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();
int64_t stride = x_dim[x_rank - 1];
auto zero = static_cast<T>(0);

if (x_grad) {
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
int64_t i = 0;
for (i = 0; i < numel; i++) {
if (std::isnan(static_cast<float>(x_ptr[i]))) {
x_grad_ptr[i] = zero;
continue;
}

int64_t row = static_cast<int64_t>(i / stride);
int64_t m_row = 2 * row;
if (std::isnan(static_cast<float>(m_ptr[m_row])) ||
(x_ptr[i] != m_ptr[m_row] && x_ptr[i] != m_ptr[m_row + 1])) {
x_grad_ptr[i] = zero;
continue;
}

x_grad_ptr[i] = out_grad_ptr[row];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this logic to get the gradient of nanmedian is not correct, consider this data: x = [float('nan'), 2., 3., 2., 3., float('nan')], after sort, only 2 in offset 3(start from 0) of x and 3 in offset 2 of x are used to compute median, so only this two will have gradient. 2 in offset 1 of x and 3 in offset 4 of x do not have gradient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}
}

} // namespace phi

PD_REGISTER_KERNEL(nanmedian_grad,
CPU,
ALL_LAYOUT,
phi::NanmedianGradKernel,
float,
double,
int,
int64_t) {}
105 changes: 105 additions & 0 deletions paddle/phi/kernels/cpu/nanmedian_kernel.cc
@@ -0,0 +1,105 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/nanmedian_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void NanmedianKernel(const Context& dev_ctx,
const DenseTensor& x,
bool ignore_nan,
DenseTensor* out,
DenseTensor* medians) {
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
T* m_ptr = dev_ctx.template Alloc<T>(medians);

int64_t numel = x.numel();
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();
int64_t stride = x_dim[x_rank - 1];
int64_t pre_dim = numel / stride;
int64_t i = 0;

bool all_nan = true;
for (i = 0; i < numel; i++) {
if (!std::isnan(static_cast<float>(*(x_ptr + i)))) {
all_nan = false;
break;
}
}

if (all_nan) {
for (i = 0; i < pre_dim; i++) {
o_ptr[i] = x_ptr[0];
m_ptr[2 * i] = x_ptr[0];
m_ptr[2 * i + 1] = x_ptr[0];
}
return;
}

std::vector<T> col_vec;
col_vec.reserve(stride);
col_vec.resize(stride);
for (i = 0; i < pre_dim; i++) {
col_vec.clear();
col_vec.insert(
col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride);

int64_t num_nan =
std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) {
return std::isnan(static_cast<float>(val));
});

int64_t pos = (stride - num_nan - 1) / 2;
std::nth_element(col_vec.begin(),
col_vec.begin() + pos,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use col_vec.begin() + pos + 1 is better? no need to use std::nth_element again in if statement below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Learned from https://en.cppreference.com/w/cpp/algorithm/nth_element, the order of elements before the nth element seems to be somehow unreliable, so I fetched the n-1th element again. If the index of elements are needed, this std:nth_element would not be the good choice.

col_vec.end(),
[](const T& l, const T& r) {
return (!std::isnan(static_cast<float>(l)) &&
std::isnan(static_cast<float>(r))) ||
(l < r);
});

m_ptr[2 * i] = col_vec[pos];
m_ptr[2 * i + 1] = col_vec[pos];
if ((stride - num_nan) % 2 == 0) {
std::nth_element(col_vec.begin(),
col_vec.begin() + pos + 1,
col_vec.end(),
[](const T& l, const T& r) {
return (!std::isnan(static_cast<float>(l)) &&
std::isnan(static_cast<float>(r))) ||
(l < r);
});
m_ptr[2 * i + 1] = col_vec[pos + 1];
}
o_ptr[i] = static_cast<T>((m_ptr[2 * i] + m_ptr[2 * i + 1]) / 2.0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remind, how to calculate median in even numbers should be clearly written in the document.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}

} // namespace phi

PD_REGISTER_KERNEL(nanmedian,
CPU,
ALL_LAYOUT,
phi::NanmedianKernel,
float,
double,
int,
int64_t) {}