Skip to content

Commit

Permalink
【PaddlePaddle Hackathon 2】15 新增 API Nanmedian (PaddlePaddle#42385)
Browse files Browse the repository at this point in the history
* nanmedian op

* 修改cuda kernel的bug

* 修复count_if在其他硬件平台不兼容

* 修复某些cpu硬件不兼容

* 修复某些cpu硬件不兼容

* 修复isnan判断

* 兼容numpy低版本不支持全部nan的情况

* 兼容numpy低版本不支持全部nan的情况

* fix code example

* fix api comment error

* 修改反向传播逻辑以及c++处理逻辑

* 完成修改建议

* typo pre_dim

* update en docs, test=document_fix

* remove numpy in en doc, test=document_fix

* add r,test=document_fix

* 添加api到all

* follow advice from chenwhql
  • Loading branch information
thunder95 authored and fuyou765 committed Jun 7, 2022
1 parent 0ff9c50 commit 9f65623
Show file tree
Hide file tree
Showing 17 changed files with 1,406 additions and 1 deletion.
125 changes: 125 additions & 0 deletions paddle/fluid/operators/nanmedian_op.cc
@@ -0,0 +1,125 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

// Forward operator for nanmedian: median of the input with NaN values
// ignored (see NanmedianOpMaker for the full I/O contract).
class NanmedianOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Run the kernel with the dtype of input "X" on the current place.
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

// Declares the inputs, outputs, attributes and documentation of nanmedian.
class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // NOTE: adjacent string literals concatenate; keep a trailing space at
    // each line break so the rendered docs do not run words together.
    AddInput("X",
             "(Tensor), "
             "the input feature data of NanmedianOp, dtype should be "
             "int32, int64, float16, float32 or float64.");
    AddOutput(
        "MedianIndex",
        "Store the index position of median values, The calculation differs "
        "in the odd or even valid elements numbers. "
        "Along the axis, two elements contributed to the median value in "
        "each row. "
        "If the amount of valid elements were even, both were the same.")
        .AsIntermediate()
        .AsExtra();
    AddOutput("Out",
              "(Tensor),"
              " the output of NanmedianOp, whose dtype is the same as X");
    AddAttr<bool>("keepdim",
                  "(bool, default true) "
                  "If true, retain the reduced axis with length 1.")
        .SetDefault(true);
    AddAttr<std::vector<int>>("axis",
                              "(std::vector<int>). List of integers,"
                              " indicating the dimensions to calculate medians")
        .SetDefault({});
    AddComment(R"DOC(
Nanmedian operator
This operator is considered as an extension of median operation,
which supports specifically the case of NaN values in the input.
If all the elements in input are NaN it will also return NaN.
If no elements in input are NaN, this op is identical to the median op.
If the valid count of elements is an even number, the average value of
the elements in the middle is calculated as the median.
This operator can also support multiple axis.
)DOC");
  }
};

// Builds the nanmedian_grad op from the forward op. The backward pass needs
// X (for the gradient's shape), MedianIndex (which input elements produced
// each median value), and the upstream gradient of Out.
template <typename T>
class NanmedianGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType("nanmedian_grad");
    op->SetInput("X", this->Input("X"));
    // MedianIndex is an intermediate output of the forward op, forwarded so
    // the backward kernel can scatter gradients without recomputing medians.
    op->SetInput("MedianIndex", this->Output("MedianIndex"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    // Forward all attributes (axis, keepdim) unchanged to the grad op.
    op->SetAttrMap(this->Attrs());
  }
};

// Backward operator for nanmedian; shape inference is provided by
// phi::NanmedianGradInferMeta via the registered infer-shape functor.
class NanmedianGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // The grad kernel's dtype follows the incoming gradient of Out.
    auto grad_data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return framework::OpKernelType(grad_data_type, ctx.GetPlace());
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

// Bind compile-time shape inference of both ops to the shared phi
// InferMeta functions, so static-graph and phi kernels agree on shapes.
DECLARE_INFER_SHAPE_FUNCTOR(nanmedian, NanmedianInferShapeFunctor,
                            PD_INFER_META(phi::NanmedianInferMeta));

// Register the forward op with grad-op makers for both the static graph
// (OpDesc) and dygraph (OpBase) modes.
REGISTER_OPERATOR(nanmedian, ops::NanmedianOp, ops::NanmedianOpMaker,
                  ops::NanmedianGradMaker<paddle::framework::OpDesc>,
                  ops::NanmedianGradMaker<paddle::imperative::OpBase>,
                  NanmedianInferShapeFunctor);

DECLARE_INFER_SHAPE_FUNCTOR(nanmedian_grad, NanmedianGradInferShapeFunctor,
                            PD_INFER_META(phi::NanmedianGradInferMeta));

REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp,
                  NanmedianGradInferShapeFunctor);
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/backward.cc
Expand Up @@ -433,6 +433,17 @@ void MultiplexGradInferMeta(const MetaTensor& ids,
}
}

void NanmedianGradInferMeta(const MetaTensor& x,
                            const MetaTensor& median_index,
                            const MetaTensor& out_grad,
                            const IntArray& axes,
                            bool keep_dim,
                            MetaTensor* x_grad) {
  // dX always matches the forward input X in both shape and dtype,
  // regardless of the reduction axes or keep_dim.
  x_grad->set_dims(x.dims());
  x_grad->set_dtype(x.dtype());
}

void NllLossGradInferMeta(const MetaTensor& x,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/backward.h
Expand Up @@ -191,6 +191,13 @@ void MultiplexGradInferMeta(const MetaTensor& ids,
const MetaTensor& out_grad,
std::vector<MetaTensor*> ins_grad);

// Shape/dtype inference for nanmedian_grad: dX mirrors the forward input X.
void NanmedianGradInferMeta(const MetaTensor& x,
                            const MetaTensor& median_index,
                            const MetaTensor& out_grad,
                            const IntArray& axes,
                            bool keep_dim,
                            MetaTensor* x_grad);

void NllLossGradInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down
59 changes: 59 additions & 0 deletions paddle/phi/infermeta/unary.cc
Expand Up @@ -1246,6 +1246,65 @@ void MultinomialInferMeta(const MetaTensor& x,
out->set_dtype(DataType::INT64);
}

void NanmedianInferMeta(const MetaTensor& x,
const IntArray& axes,
bool keep_dim,
MetaTensor* out,
MetaTensor* median_index) {
std::vector<int64_t> axis_list = axes.GetData();
auto x_dim = x.dims();
int64_t x_rank = x_dim.size();
out->set_dtype(x.dtype());
median_index->set_dtype(DataType::INT64);
median_index->set_dims(make_ddim({x.numel() * 2}));

std::vector<int32_t> out_dim;
if (axis_list.empty()) {
if (keep_dim) {
for (int64_t i = 0; i < x_rank; i++) {
out_dim.push_back(1);
}
} else {
out_dim.push_back(1);
}
} else {
std::vector<int64_t> cleaned_axis;
for (auto& axis : axis_list) {
if (axis < 0) axis += x_rank;

PADDLE_ENFORCE_LT(
axis,
x_rank,
errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s].",
axis,
x_rank,
x_dim));

PADDLE_ENFORCE_EQ(
std::find(cleaned_axis.begin(), cleaned_axis.end(), axis),
cleaned_axis.end(),
errors::InvalidArgument("Attr(axes) has duplicated elements: %d.",
static_cast<int>(axis)));

cleaned_axis.push_back(axis);
}

for (int64_t i = 0; i < x_rank; i++) {
if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) ==
cleaned_axis.end()) {
out_dim.push_back(x_dim[i]);
} else if (keep_dim) {
out_dim.push_back(1);
}
}
}

out->set_dims(make_ddim(out_dim));
}

void NormInferMeta(const MetaTensor& x,
int axis,
float epsilon,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/unary.h
Expand Up @@ -178,6 +178,13 @@ void MultinomialInferMeta(const MetaTensor& x,
int num_samples,
bool replacement,
MetaTensor* out);

// Shape/dtype inference for nanmedian: reduce semantics over `axes` for
// `out`, plus an int64 index buffer `median_index` used by the backward op.
void NanmedianInferMeta(const MetaTensor& x,
                        const IntArray& axes,
                        bool keep_dim,
                        MetaTensor* out,
                        MetaTensor* median_index);

void NormInferMeta(const MetaTensor& x,
int axis,
float epsilon,
Expand Down
99 changes: 99 additions & 0 deletions paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
@@ -0,0 +1,99 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

// Scatters the upstream gradient back to the input positions that produced
// each median. median_index holds two int64 indices per output element
// (relative to its row along the last dimension); when they differ the
// median was an average of two elements and each receives half the grad.
template <typename T, typename Context>
void CalcMedianGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& median_index,
                          const DenseTensor& out_grad,
                          const IntArray& axes,
                          DenseTensor* x_grad,
                          T* x_grad_ptr) {
  // Start from an all-zero gradient; only median positions are written.
  phi::funcs::SetConstant<Context, T> set_zero;
  set_zero(dev_ctx, x_grad, static_cast<T>(0));
  if (!x_grad_ptr) return;

  const int64_t* m_ptr = median_index.data<int64_t>();
  const T* out_grad_ptr = out_grad.data<T>();
  auto x_dim = x.dims();
  int64_t stride = x_dim[x_dim.size() - 1];
  int64_t num_rows = x.numel() / stride;

  const T two = static_cast<T>(2.0);
  for (int64_t row = 0; row < num_rows; row++) {
    int64_t lo = m_ptr[2 * row];
    int64_t hi = m_ptr[2 * row + 1];
    // A negative index means this row produced no gradient targets.
    if (lo < 0) continue;
    T* row_grad = x_grad_ptr + row * stride;
    if (lo == hi) {
      row_grad[lo] = out_grad_ptr[row];
    } else {
      row_grad[lo] = out_grad_ptr[row] / two;
      row_grad[hi] = out_grad_ptr[row] / two;
    }
  }
}

// Dispatches the gradient computation. When reduction axes were given and
// rank > 1, the gradient is first computed in the reduced layout and then
// post-processed back into the layout of X via PostprocessMedianGradKernel
// (declared in nanmedian_grad_kernel.h).
template <typename T, typename Context>
void BaseMedianGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& median_index,
                          const DenseTensor& out_grad,
                          const IntArray& axes,
                          DenseTensor* x_grad) {
  auto rank = x.dims().size();
  T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
  if (axes.size() && (rank > 1)) {
    // NOTE(review): tmp_x_grad is a shallow copy of *x_grad (DenseTensor's
    // copy shares meta with the same underlying allocation) — presumably
    // PostprocessMedianGradKernel accounts for the aliasing; verify against
    // its implementation.
    DenseTensor tmp_x_grad(*x_grad);
    CalcMedianGradKernel<T, Context>(
        dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr);
    PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
  } else {
    // No axes (full reduction) or 1-D input: write directly into x_grad.
    CalcMedianGradKernel<T, Context>(
        dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr);
  }
}

// Public entry point for the CPU nanmedian backward kernel; keep_dim does
// not affect the gradient layout, so it is not forwarded.
template <typename T, typename Context>
void NanmedianGradKernel(const Context& dev_ctx,
                         const DenseTensor& input,
                         const DenseTensor& median_index,
                         const DenseTensor& out_grad,
                         const IntArray& axes,
                         bool keep_dim,
                         DenseTensor* x_grad) {
  BaseMedianGradKernel<T, Context>(
      dev_ctx, input, median_index, out_grad, axes, x_grad);
}

} // namespace phi

// Register the CPU backward kernel for the dtypes the forward op supports
// on CPU (float16 is GPU-only here).
PD_REGISTER_KERNEL(nanmedian_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::NanmedianGradKernel,
                   float,
                   double,
                   int,
                   int64_t) {}

0 comments on commit 9f65623

Please sign in to comment.