Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
【PaddlePaddle Hackathon 2】15 新增 API Nanmedian (#42385)
* nanmedian op * 修改cuda kernel的bug * 修复count_if在其他硬件平台不兼容 * 修复某些cpu硬件不兼容 * 修复某些cpu硬件不兼容 * 修复isnan判断 * 兼容numpy低版本不支持全部nan的情况 * 兼容numpy低版本不支持全部nan的情况 * fix code example * fix api comment error * 修改反向传播逻辑以及c++处理逻辑 * 完成修改建议 * typo pre_dim * update en docs, test=document_fix * remove numpy in en doc, test=document_fix * add r,test=document_fix * 添加api到all * follow advice from chenwhql
- Loading branch information
Showing
17 changed files
with
1,406 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <memory> | ||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/backward.h" | ||
#include "paddle/phi/infermeta/unary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class NanmedianOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
// Declares the inputs, outputs, attributes and user-facing description of the
// nanmedian operator.
class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Adjacent string literals are concatenated verbatim, so every fragment
    // that continues a sentence must carry its own trailing space.
    AddInput("X",
             "(Tensor), "
             "the input feature data of NanmedianOp, dtype should be "
             "int32, int64, float16, float32 or float64.");
    // MedianIndex holds two indices per row. The gradient kernel forwards the
    // whole gradient to a single element when the two indices are equal, and
    // half to each otherwise, so equal indices correspond to an odd count.
    AddOutput(
        "MedianIndex",
        "Store the index position of median values. The calculation differs "
        "depending on whether the number of valid elements is odd or even. "
        "Along the axis, two elements contribute to the median value in "
        "each row. "
        "If the number of valid elements is odd, both indices are the same.")
        .AsIntermediate()
        .AsExtra();
    AddOutput("Out",
              "(Tensor),"
              " the output of NanmedianOp, whose dtype is the same as X");
    AddAttr<bool>("keepdim",
                  "(bool, default true) "
                  "If true, retain the reduced axis with length 1.")
        .SetDefault(true);
    AddAttr<std::vector<int>>("axis",
                              "(std::vector<int>). List of integers,"
                              " indicating the dimensions to calculate medians")
        .SetDefault({});
    AddComment(R"DOC(
Nanmedian operator

This operator is considered as an extension of the median operation,
which specifically supports the case of NaN values in the input.

If all the elements in input are NaN it will also return NaN.
If no elements in input are NaN, this op is identical to the median op.

If the valid count of elements is an even number, the average value of
the elements in the middle is calculated as the median.

This operator also supports multiple axes.
)DOC");
  }
};
|
||
template <typename T> | ||
class NanmedianGradMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
void Apply(GradOpPtr<T> op) const override { | ||
op->SetType("nanmedian_grad"); | ||
op->SetInput("X", this->Input("X")); | ||
op->SetInput("MedianIndex", this->Output("MedianIndex")); | ||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
op->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
class NanmedianGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( | ||
ctx, framework::GradVarName("Out")), | ||
ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
DECLARE_INFER_SHAPE_FUNCTOR(nanmedian, NanmedianInferShapeFunctor, | ||
PD_INFER_META(phi::NanmedianInferMeta)); | ||
|
||
REGISTER_OPERATOR(nanmedian, ops::NanmedianOp, ops::NanmedianOpMaker, | ||
ops::NanmedianGradMaker<paddle::framework::OpDesc>, | ||
ops::NanmedianGradMaker<paddle::imperative::OpBase>, | ||
NanmedianInferShapeFunctor); | ||
|
||
DECLARE_INFER_SHAPE_FUNCTOR(nanmedian_grad, NanmedianGradInferShapeFunctor, | ||
PD_INFER_META(phi::NanmedianGradInferMeta)); | ||
|
||
REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp, | ||
NanmedianGradInferShapeFunctor); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/nanmedian_grad_kernel.h" | ||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void CalcMedianGradKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
const DenseTensor& median_index, | ||
const DenseTensor& out_grad, | ||
const IntArray& axes, | ||
DenseTensor* x_grad, | ||
T* x_grad_ptr) { | ||
phi::funcs::SetConstant<Context, T> set_zero; | ||
set_zero(dev_ctx, x_grad, static_cast<T>(0)); | ||
if (!x_grad_ptr) return; | ||
|
||
const int64_t* m_ptr = median_index.data<int64_t>(); | ||
const T* out_grad_ptr = out_grad.data<T>(); | ||
int64_t numel = x.numel(); | ||
auto x_dim = x.dims(); | ||
int64_t rank = x_dim.size(); | ||
int64_t stride = x_dim[rank - 1]; | ||
|
||
int64_t pre_dim = numel / stride; | ||
int64_t i = 0; | ||
int64_t offset = 0; | ||
T div_factor = static_cast<T>(2.0); | ||
for (i = 0; i < pre_dim; i++) { | ||
if (m_ptr[2 * i] >= 0) { | ||
if (m_ptr[2 * i] == m_ptr[2 * i + 1]) { | ||
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i]; | ||
} else { | ||
x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor; | ||
x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor; | ||
} | ||
} | ||
offset += stride; | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void BaseMedianGradKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
const DenseTensor& median_index, | ||
const DenseTensor& out_grad, | ||
const IntArray& axes, | ||
DenseTensor* x_grad) { | ||
auto rank = x.dims().size(); | ||
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad); | ||
if (axes.size() && (rank > 1)) { | ||
DenseTensor tmp_x_grad(*x_grad); | ||
CalcMedianGradKernel<T, Context>( | ||
dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr); | ||
PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad); | ||
} else { | ||
CalcMedianGradKernel<T, Context>( | ||
dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr); | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void NanmedianGradKernel(const Context& dev_ctx, | ||
const DenseTensor& input, | ||
const DenseTensor& median_index, | ||
const DenseTensor& out_grad, | ||
const IntArray& axes, | ||
bool keep_dim, | ||
DenseTensor* x_grad) { | ||
BaseMedianGradKernel<T, Context>( | ||
dev_ctx, input, median_index, out_grad, axes, x_grad); | ||
} | ||
|
||
} // namespace phi | ||
|
||
// Registers the CPU nanmedian_grad kernel for float, double, int and int64_t
// under all layouts. The trailing braces are an empty registration body.
PD_REGISTER_KERNEL(nanmedian_grad, CPU, ALL_LAYOUT, phi::NanmedianGradKernel,
                   float, double, int, int64_t) {}
Oops, something went wrong.