diff --git a/paddle/fluid/operators/nanmedian_op.cc b/paddle/fluid/operators/nanmedian_op.cc new file mode 100644 index 0000000000000..23a497bdb1d3d --- /dev/null +++ b/paddle/fluid/operators/nanmedian_op.cc @@ -0,0 +1,125 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +class NanmedianOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), " + "the input feature data of NanmedianOp, dtype should be" + "int32, int64, float16, float32 or float64."); + AddOutput( + "MedianIndex", + "Store the index position of median values, The calculation differs " + "in the odd or even valid elements numbers." + "Along the axis, two elements contributed to the median value in " + "each row." + "If the amount of valid elements were even, both were the same.") + .AsIntermediate() + .AsExtra(); + AddOutput("Out", + "(Tensor)," + " the output of NanmedianOp, whose dtype is the same as X"); + AddAttr("keepdim", + "(bool, default true) " + "If true, retain the reduced axis with length 1.") + .SetDefault(true); + AddAttr>("axis", + "(std::vector). List of integers," + " indicating the dimensions to calculate medians") + .SetDefault({}); + AddComment(R"DOC( + Nanmedian operator + + This operator is considered as an extention of median operation, + which supports specifically the case of NaN values in the input. + + If all the elements in input are NaN it will also return NaN. + If no elements in input are Nan, this op is identical to thie median op. + + If the valid count of elements is a even number, the average value of + the elements in the middle is calculated as the median. + + This operator can also supports multiple axis. + )DOC"); + } +}; + +template +class NanmedianGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType("nanmedian_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("MedianIndex", this->Output("MedianIndex")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +class NanmedianGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(nanmedian, NanmedianInferShapeFunctor, + PD_INFER_META(phi::NanmedianInferMeta)); + +REGISTER_OPERATOR(nanmedian, ops::NanmedianOp, ops::NanmedianOpMaker, + ops::NanmedianGradMaker, + ops::NanmedianGradMaker, + NanmedianInferShapeFunctor); + +DECLARE_INFER_SHAPE_FUNCTOR(nanmedian_grad, NanmedianGradInferShapeFunctor, + PD_INFER_META(phi::NanmedianGradInferMeta)); + +REGISTER_OPERATOR(nanmedian_grad, ops::NanmedianGradOp, + NanmedianGradInferShapeFunctor); diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 78f8ff9e00ce5..521eb03fd770f 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -433,6 +433,17 @@ void MultiplexGradInferMeta(const MetaTensor& ids, } } +void NanmedianGradInferMeta(const MetaTensor& x, + const MetaTensor& median_index, + const MetaTensor& out_grad, + const IntArray& axes, + bool keep_dim, + MetaTensor* x_grad) { + auto x_dims = x.dims(); + x_grad->set_dims(x_dims); + x_grad->set_dtype(x.dtype()); +} + void NllLossGradInferMeta(const MetaTensor& x, const MetaTensor& label, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index b52734eb5b10c..93e2d4c43bc3f 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -191,6 +191,13 @@ void MultiplexGradInferMeta(const MetaTensor& ids, const MetaTensor& out_grad, std::vector ins_grad); +void NanmedianGradInferMeta(const MetaTensor& x, + const MetaTensor& median_index, + const MetaTensor& out_grad, + const IntArray& axes, + bool keep_dim, + MetaTensor* x_grad); + void NllLossGradInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 1ec804d1bf822..f736bf50162d8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1246,6 +1246,65 @@ void MultinomialInferMeta(const MetaTensor& x, out->set_dtype(DataType::INT64); } +void NanmedianInferMeta(const MetaTensor& x, + const IntArray& axes, + bool keep_dim, + MetaTensor* out, + MetaTensor* median_index) { + std::vector axis_list = axes.GetData(); + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + out->set_dtype(x.dtype()); + median_index->set_dtype(DataType::INT64); + median_index->set_dims(make_ddim({x.numel() * 2})); + + std::vector out_dim; + if (axis_list.empty()) { + if (keep_dim) { + for (int64_t i = 0; i < x_rank; i++) { + out_dim.push_back(1); + } + } else { + out_dim.push_back(1); + } + } else { + std::vector cleaned_axis; + for (auto& axis : axis_list) { + if (axis < 0) axis += x_rank; + + PADDLE_ENFORCE_LT( + axis, + x_rank, + errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is " + "the rank of Input(X). But received axis: %d, R: %d. " + "Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_dim)); + + PADDLE_ENFORCE_EQ( + std::find(cleaned_axis.begin(), cleaned_axis.end(), axis), + cleaned_axis.end(), + errors::InvalidArgument("Attr(axes) has duplicated elements: %d.", + static_cast(axis))); + + cleaned_axis.push_back(axis); + } + + for (int64_t i = 0; i < x_rank; i++) { + if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) == + cleaned_axis.end()) { + out_dim.push_back(x_dim[i]); + } else if (keep_dim) { + out_dim.push_back(1); + } + } + } + + out->set_dims(make_ddim(out_dim)); +} + void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 25ea003f58fd9..c21ef0e2d1103 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -178,6 +178,13 @@ void MultinomialInferMeta(const MetaTensor& x, int num_samples, bool replacement, MetaTensor* out); + +void NanmedianInferMeta(const MetaTensor& x, + const IntArray& axes, + bool keep_dim, + MetaTensor* out, + MetaTensor* median_index); + void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc new file mode 100644 index 0000000000000..156124c214895 --- /dev/null +++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nanmedian_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void CalcMedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + DenseTensor* x_grad, + T* x_grad_ptr) { + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, x_grad, static_cast(0)); + if (!x_grad_ptr) return; + + const int64_t* m_ptr = median_index.data(); + const T* out_grad_ptr = out_grad.data(); + int64_t numel = x.numel(); + auto x_dim = x.dims(); + int64_t rank = x_dim.size(); + int64_t stride = x_dim[rank - 1]; + + int64_t pre_dim = numel / stride; + int64_t i = 0; + int64_t offset = 0; + T div_factor = static_cast(2.0); + for (i = 0; i < pre_dim; i++) { + if (m_ptr[2 * i] >= 0) { + if (m_ptr[2 * i] == m_ptr[2 * i + 1]) { + x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i]; + } else { + x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor; + x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor; + } + } + offset += stride; + } +} + +template +void BaseMedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + DenseTensor* x_grad) { + auto rank = x.dims().size(); + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + if (axes.size() && (rank > 1)) { + DenseTensor tmp_x_grad(*x_grad); + CalcMedianGradKernel( + dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr); + PostprocessMedianGradKernel(dev_ctx, &tmp_x_grad, axes, x_grad); + } else { + CalcMedianGradKernel( + dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr); + } +} + +template +void NanmedianGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + bool keep_dim, + DenseTensor* x_grad) { + BaseMedianGradKernel( + dev_ctx, input, median_index, out_grad, axes, x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(nanmedian_grad, + CPU, + ALL_LAYOUT, + phi::NanmedianGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc new file mode 100644 index 0000000000000..ed38405c9179f --- /dev/null +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -0,0 +1,208 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nanmedian_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/top_k_kernel.h" + +namespace phi { + +template +void CalcMedianFunc(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& nan_counts, + bool ignore_nan, + int64_t sort_k, + int64_t stride, + int64_t pre_dim, + T* o_ptr, + int64_t* m_ptr) { + bool should_ignore_nan = ignore_nan; + DenseTensor sort_out; + DenseTensor sort_indices; + auto sort_dim = x.dims(); + int64_t rank = sort_dim.size(); + sort_dim[rank - 1] = sort_k; + sort_out.Resize(sort_dim); + sort_indices.Resize(sort_dim); + + dev_ctx.template Alloc(&sort_out); + T* sort_out_ptr = sort_out.data(); + dev_ctx.template Alloc(&sort_indices); + int64_t* sort_indices_ptr = sort_indices.data(); + + TopkKernel( + dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices); + + T div_factor = static_cast(2.0); + int64_t offset = 0; + int64_t i = 0; + bool is_ori_odd = stride & 1; + if (should_ignore_nan) { + for (i = 0; i < pre_dim; i++) { + offset = i * sort_k; + if (nan_counts[i] == stride) { + m_ptr[i * 2] = -1; + m_ptr[i * 2 + 1] = -1; + o_ptr[i] = sort_out_ptr[offset]; + } else { + int64_t nan_k = nan_counts[i] > 0 + ? static_cast(stride - nan_counts[i]) + : sort_k; + int64_t row_pos = static_cast(nan_k >> 1); + int64_t pos = offset + row_pos; + if (nan_k & 1) { + m_ptr[2 * i] = sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + o_ptr[i] = sort_out_ptr[pos]; + } else { + m_ptr[2 * i] = + row_pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + T m_val_left = + row_pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T m_val_right = sort_out_ptr[pos]; + o_ptr[i] = (m_val_left + m_val_right) / div_factor; + } + } + } + } else { + if (is_ori_odd) { + for (i = 0; i < pre_dim; i++) { + offset = i * sort_k; + int64_t pos = offset + sort_k - 1; + o_ptr[i] = sort_out_ptr[pos]; + m_ptr[2 * i] = sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + } + } else { + for (i = 0; i < pre_dim; i++) { + offset = i * sort_k; + int64_t pos = offset + sort_k - 1; + m_ptr[2 * i] = + sort_k > 1 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + T m_val_left = sort_k > 1 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T m_val_right = sort_out_ptr[pos]; + o_ptr[i] = (m_val_left + m_val_right) / div_factor; + } + } + } +} + +template +void ProcessMedianKernel(const Context& dev_ctx, + const DenseTensor& x, + T* o_ptr, + int64_t* m_ptr, + bool ignore_nan) { + bool should_ignore_nan = ignore_nan; + const T* x_ptr = x.data(); + + int64_t numel = x.numel(); + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + int64_t stride = x_dim[x_rank - 1]; + int64_t pre_dim = numel / stride; + int64_t i = 0; + + int64_t max_valid_num = 0; + std::vector nan_counts; + if (should_ignore_nan) { + int64_t total_nan_num = 0; + std::vector col_vec; + col_vec.reserve(stride); + col_vec.resize(stride); + nan_counts.clear(); + nan_counts.reserve(pre_dim); + nan_counts.resize(pre_dim); + for (int64_t i = 0; i < pre_dim; i++) { + col_vec.clear(); + col_vec.insert( + col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride); + nan_counts[i] = + std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) { + return std::isnan(static_cast(val)); + }); + total_nan_num += nan_counts[i]; + if (stride - nan_counts[i] > max_valid_num) + max_valid_num = stride - nan_counts[i]; + } + // all elems are nan + if (total_nan_num == numel) { + for (i = 0; i < pre_dim; i++) { + o_ptr[i] = x_ptr[0]; + m_ptr[2 * i] = -1; + m_ptr[2 * i + 1] = -1; + } + return; + } + should_ignore_nan = total_nan_num > 0; + } + + int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1); + CalcMedianFunc(dev_ctx, + x, + nan_counts, + should_ignore_nan, + sort_k, + stride, + pre_dim, + o_ptr, + m_ptr); +} + +template +void BaseMedianKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& axes, + DenseTensor* out, + DenseTensor* median_index, + bool ignore_nan) { + DenseTensor x; + auto rank = input.dims().size(); + if ((axes.size() == 0) || rank <= 1) { + x = input; + x.Resize({input.numel()}); + } else { + PreprocessMedianKernel(dev_ctx, input, axes, &x); + } + + T* o_ptr = dev_ctx.template Alloc(out); + int64_t* m_ptr = dev_ctx.template Alloc(median_index); + ProcessMedianKernel(dev_ctx, x, o_ptr, m_ptr, ignore_nan); + out->Resize(out->dims()); +} + +template +void NanmedianKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + bool keepdim, + DenseTensor* out, + DenseTensor* median_index) { + BaseMedianKernel(dev_ctx, x, axes, out, median_index, true); +} + +} // namespace phi + +PD_REGISTER_KERNEL(nanmedian, + CPU, + ALL_LAYOUT, + phi::NanmedianKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu new file mode 100644 index 0000000000000..a7cd49c0e53f3 --- /dev/null +++ b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu @@ -0,0 +1,122 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/nanmedian_grad_kernel.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; +inline int GET_BLOCKS(const int N) { + return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; +} + +template +__global__ void KernelNanmedianGrad(const T* x_ptr, + const int64_t* medians_ptr, + const T* out_grad_ptr, + T* x_grad_ptr, + int64_t stride, + int64_t pre_dim, + T div_factor) { + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t offset = index * stride; + if (medians_ptr[2 * index] >= 0) { + if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) { + x_grad_ptr[offset + medians_ptr[2 * index]] = out_grad_ptr[index]; + } else { + x_grad_ptr[offset + medians_ptr[2 * index]] = + out_grad_ptr[index] / div_factor; + x_grad_ptr[offset + medians_ptr[2 * index + 1]] = + out_grad_ptr[index] / div_factor; + } + } + } +} + +template +void CalcMedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + DenseTensor* x_grad, + T* x_grad_ptr) { + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, x_grad, static_cast(0)); + + auto stream = dev_ctx.stream(); + const T* x_ptr = x.data(); + const int64_t* m_ptr = median_index.data(); + const T* out_grad_ptr = out_grad.data(); + + int64_t numel = x.numel(); + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + int64_t stride = x_dim[x_rank - 1]; + int64_t pre_dim = numel / stride; + + T div_factor = static_cast(2.0); + KernelNanmedianGrad< + T><<>>( + x_ptr, m_ptr, out_grad_ptr, x_grad_ptr, stride, pre_dim, div_factor); +} + +template +void BaseMedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + DenseTensor* x_grad) { + auto rank = x.dims().size(); + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + if (axes.size() && (rank > 1)) { + DenseTensor tmp_x_grad(*x_grad); + CalcMedianGradKernel( + dev_ctx, x, median_index, out_grad, &tmp_x_grad, x_grad_ptr); + PostprocessMedianGradKernel(dev_ctx, &tmp_x_grad, axes, x_grad); + } else { + CalcMedianGradKernel( + dev_ctx, x, median_index, out_grad, x_grad, x_grad_ptr); + } +} + +template +void NanmedianGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + bool keep_dim, + DenseTensor* x_grad) { + BaseMedianGradKernel( + dev_ctx, input, median_index, out_grad, axes, x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(nanmedian_grad, + GPU, + ALL_LAYOUT, + phi::NanmedianGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu new file mode 100644 index 0000000000000..5975e2748997e --- /dev/null +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -0,0 +1,287 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/nanmedian_kernel.h" +#include "paddle/phi/kernels/top_k_kernel.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +inline int GET_BLOCKS(const int N) { + return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; +} + +template +__global__ void KernelNanCounts(const T* input, + const int numel, + const int64_t pre_dim, + const int64_t stride, + T min_val, + int64_t* nan_total, + int64_t* nan_counts) { + extern __shared__ int64_t buf[]; + for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) { + buf[i] = 0; + nan_counts[i] = 0; + } + + if (threadIdx.x == 0) { + nan_total[0] = 0; + nan_total[1] = 0; + } + + __syncthreads(); + + CUDA_KERNEL_LOOP(index, numel) { + const T x = input[index]; + if (isnan(static_cast(x))) { + auto bin = static_cast(index / stride); + paddle::platform::CudaAtomicAdd(&buf[bin], 1); + } + } + __syncthreads(); + + for (int i = threadIdx.x; i < pre_dim; i += blockDim.x) { + paddle::platform::CudaAtomicAdd(&nan_counts[i], buf[i]); + paddle::platform::CudaAtomicAdd(&nan_total[0], buf[i]); + paddle::platform::CudaAtomicMax(&nan_total[1], stride - buf[i]); + } +} + +template +__global__ void CalcMedianKernel(const T* sort_out_ptr, + const int64_t* sort_indices_ptr, + int64_t* median_val, + T* output, + T div_factor, + const bool is_odd, + const int64_t pre_dim, + const int64_t stride) { + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t pos = static_cast((index + 1) * stride) - 1; + if (is_odd) { + median_val[index * 2] = sort_indices_ptr[pos]; + median_val[index * 2 + 1] = sort_indices_ptr[pos]; + output[index] = sort_out_ptr[pos]; + } else { + median_val[index * 2] = + pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + median_val[index * 2 + 1] = sort_indices_ptr[pos]; + T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T median_val_right = sort_out_ptr[pos]; + output[index] = (median_val_left + median_val_right) / div_factor; + } + } +} + +template +__global__ void CalcNanmedianKernel(const T* sort_out_ptr, + const int64_t* sort_indices_ptr, + int64_t* nan_counts, + int64_t* median_val, + T* output, + const bool is_odd, + const int64_t pre_dim, + const int64_t max_valid_num, + const int64_t stride, + const T div_factor, + const T nan_val) { + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t pos = static_cast(index * max_valid_num); + int64_t nan_cnt = nan_counts[index]; + if (nan_cnt == stride) { + median_val[index * 2] = -1; + median_val[index * 2 + 1] = -1; + output[index] = nan_val; + } else { + int64_t nan_k = + nan_cnt > 0 ? static_cast(stride - nan_cnt) : max_valid_num; + int64_t row_pos = static_cast(nan_k >> 1); + pos += row_pos; + + if (nan_k & 1) { + median_val[index * 2] = sort_indices_ptr[pos]; + median_val[index * 2 + 1] = sort_indices_ptr[pos]; + output[index] = sort_out_ptr[pos]; + } else { + median_val[index * 2] = + pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + median_val[index * 2 + 1] = sort_indices_ptr[pos]; + T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T median_val_right = sort_out_ptr[pos]; + output[index] = (median_val_left + median_val_right) / div_factor; + } + } + } +} + +template +void ProcessMedianKernel(const Context& dev_ctx, + const DenseTensor& x, + bool ignore_nan, + DenseTensor* out, + int64_t* m_ptr) { + bool should_ignore_nan = ignore_nan; + auto stream = dev_ctx.stream(); + + const T* x_ptr = x.data(); + T* o_ptr = dev_ctx.template Alloc(out); + int64_t numel = x.numel(); + auto x_dim = x.dims(); + int64_t x_rank = x_dim.size(); + int64_t stride = x_dim[x_rank - 1]; + int64_t pre_dim = numel / stride; + int64_t i = 0; + + DenseTensor nan_counts, nan_stat; + int64_t* nan_counts_ptr; + int64_t max_valid_num = 0; + if (should_ignore_nan) { + nan_counts.Resize(phi::make_ddim({pre_dim})); + dev_ctx.template Alloc(&nan_counts); + nan_counts_ptr = nan_counts.data(); + nan_stat.Resize(phi::make_ddim({2})); + int64_t* nan_stat_mem = dev_ctx.template Alloc(&nan_stat); + int64_t* nan_stat_ptr = nan_stat.data(); + + KernelNanCounts<<>>(x_ptr, + numel, + pre_dim, + stride, + std::numeric_limits::min(), + nan_stat_ptr, + nan_counts_ptr); + + auto nan_stat_mem_cpu = + paddle::memory::Alloc(phi::CPUPlace(), sizeof(int64_t) * 2); + int64_t* nan_stat_cpu_ptr = + reinterpret_cast(nan_stat_mem_cpu->ptr()); + paddle::memory::Copy(phi::CPUPlace(), + nan_stat_cpu_ptr, + dev_ctx.GetPlace(), + nan_stat_mem, + sizeof(int64_t) * 2, + stream); + + // all elements are nan values + T nan_val = std::numeric_limits::quiet_NaN(); + if (nan_stat_cpu_ptr[0] == numel) { + FullLikeKernel(dev_ctx, x, nan_val, x.dtype(), out); + return; + } + + should_ignore_nan = nan_stat_cpu_ptr[0] > 0; + max_valid_num = nan_stat_cpu_ptr[1]; + } + + int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1); + bool is_ori_odd = stride & 1; + + DenseTensor sort_out, sort_indices; + auto sort_dim = x.dims(); + int64_t rank = sort_dim.size(); + sort_dim[rank - 1] = sort_k; + sort_out.Resize(sort_dim); + sort_indices.Resize(sort_dim); + + dev_ctx.template Alloc(&sort_out); + T* sort_out_ptr = sort_out.data(); + dev_ctx.template Alloc(&sort_indices); + int64_t* sort_indices_ptr = sort_indices.data(); + + TopkKernel( + dev_ctx, x, Scalar(sort_k), -1, false, true, &sort_out, &sort_indices); + + T div_factor = static_cast(2.0); + T nan_val = std::numeric_limits::quiet_NaN(); + if (should_ignore_nan) { + CalcNanmedianKernel< + T><<>>( + sort_out_ptr, + sort_indices_ptr, + nan_counts_ptr, + m_ptr, + o_ptr, + is_ori_odd, + pre_dim, + max_valid_num, + stride, + div_factor, + nan_val); + } else { + CalcMedianKernel< + T><<>>( + sort_out_ptr, + sort_indices_ptr, + m_ptr, + o_ptr, + div_factor, + is_ori_odd, + pre_dim, + sort_k); + } +} + +template +void BaseMedianKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& axes, + bool ignore_nan, + DenseTensor* out, + DenseTensor* median_index) { + DenseTensor x; + auto rank = input.dims().size(); + if ((axes.size() == 0) || rank <= 1) { + x = input; + x.Resize({input.numel()}); + } else { + PreprocessMedianKernel(dev_ctx, input, axes, &x); + } + + int64_t* m_ptr = dev_ctx.template Alloc(median_index); + ProcessMedianKernel(dev_ctx, x, ignore_nan, out, m_ptr); + out->Resize(out->dims()); +} + +template +void NanmedianKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + bool keepdim, + DenseTensor* out, + DenseTensor* median_index) { + BaseMedianKernel(dev_ctx, x, axes, true, out, median_index); +} + +} // namespace phi + +PD_REGISTER_KERNEL(nanmedian, + GPU, + ALL_LAYOUT, + phi::NanmedianKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/nanmedian_grad_kernel.h b/paddle/phi/kernels/nanmedian_grad_kernel.h new file mode 100644 index 0000000000000..dc7321c1aa751 --- /dev/null +++ b/paddle/phi/kernels/nanmedian_grad_kernel.h @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PostprocessMedianGradKernel(const Context& dev_ctx, + DenseTensor* input, + const IntArray& raw_axes, + DenseTensor* x) { + auto input_dim = input->dims(); + auto rank = input_dim.size(); + + std::vector axes = raw_axes.GetData(); + int64_t axes_size = static_cast(axes.size()); + for (int64_t i = 0; i < axes_size; i++) { + if (axes[i] < 0) { + axes[i] += rank; + } + } + + std::vector trans_back; + std::vector reshape_back; + trans_back.reserve(rank); + trans_back.resize(rank); + + int offset = 0; + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) == axes.end()) { + reshape_back.push_back(input_dim[i]); + trans_back[i] = offset; + offset += 1; + } + } + + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) { + trans_back[i] = offset; + reshape_back.push_back(input_dim[i]); + offset += 1; + } + } + + input->Resize(make_ddim(reshape_back)); + funcs::TransCompute( + static_cast(trans_back.size()), dev_ctx, *input, x, trans_back); +} + +template +void NanmedianGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& median_index, + const DenseTensor& out_grad, + const IntArray& axes, + bool keep_dim, + DenseTensor* x_grad); +} // namespace phi diff --git a/paddle/phi/kernels/nanmedian_kernel.h b/paddle/phi/kernels/nanmedian_kernel.h new file mode 100644 index 0000000000000..374f420381bdc --- /dev/null +++ b/paddle/phi/kernels/nanmedian_kernel.h @@ -0,0 +1,75 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PreprocessMedianKernel(const Context& dev_ctx, + const DenseTensor& input, + const IntArray& raw_axes, + DenseTensor* x) { + auto input_dim = input.dims(); + auto rank = input_dim.size(); + std::vector perm; + std::vector reshape; + + std::vector axes = raw_axes.GetData(); + int64_t axes_size = static_cast(axes.size()); + for (int64_t i = 0; i < axes_size; i++) { + if (axes[i] < 0) { + axes[i] += rank; + } + } + + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) == axes.end()) { + perm.push_back(i); + reshape.push_back(input_dim[i]); + } + } + + int64_t post_numel = 1; + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) { + perm.push_back(i); + post_numel *= input_dim[i]; + } + } + reshape.push_back(post_numel); + + DDim trans_dim(input_dim); + int ndims = perm.size(); + for (int i = 0; i < ndims; i++) { + trans_dim[i] = input_dim[perm[i]]; + } + x->Resize(trans_dim); + dev_ctx.template Alloc(x); + funcs::TransCompute(ndims, dev_ctx, input, x, perm); + + x->Resize(make_ddim(reshape)); +} + +template +void NanmedianKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + bool keep_dim, + DenseTensor* out, + DenseTensor* medians); +} // namespace phi diff --git a/paddle/phi/ops/compat/nanmedian_sig.cc b/paddle/phi/ops/compat/nanmedian_sig.cc new file mode 100644 index 0000000000000..5ca0d450e3b41 --- /dev/null +++ b/paddle/phi/ops/compat/nanmedian_sig.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature NanmedianOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "nanmedian", {"X"}, {"axis", "keepdim"}, {"Out", "MedianIndex"}); +} + +KernelSignature NanmedianGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("nanmedian_grad", + {"X", "MedianIndex", "Out@GRAD"}, + {"axis", "keepdim"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(nanmedian, phi::NanmedianOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(nanmedian_grad, phi::NanmedianGradOpArgumentMapping); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 132105fb2b689..930918e967eed 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -331,6 +331,7 @@ from .tensor.stat import var # noqa: F401 from .tensor.stat import numel # noqa: F401 from .tensor.stat import median # noqa: F401 +from .tensor.stat import nanmedian # noqa: F401 from .tensor.stat import quantile # noqa: F401 from .tensor.stat import nanquantile # noqa: F401 from .device import get_cudnn_version # noqa: F401 @@ -498,6 +499,7 @@ 'load', 'numel', 'median', + 'nanmedian', 'quantile', 'nanquantile', 'no_grad', diff --git a/python/paddle/fluid/tests/unittests/test_nanmedian.py b/python/paddle/fluid/tests/unittests/test_nanmedian.py new file mode 100644 index 0000000000000..2e1f13a8c7d9f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nanmedian.py @@ -0,0 +1,196 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core + +np.random.seed(102) + + +class TestNanmedian(unittest.TestCase): + def setUp(self): + single_axis_shape = (120) + multi_axis_shape = (2, 3, 4, 5) + + self.fake_data = { + "single_axis_normal": + np.random.uniform(-1, 1, single_axis_shape).astype(np.float32), + "multi_axis_normal": + np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32), + "single_axis_all_nan": np.full(single_axis_shape, np.nan), + "multi_axis_all_nan": np.full(multi_axis_shape, np.nan), + } + + single_partial_nan = self.fake_data["single_axis_normal"].copy() + single_partial_nan[single_partial_nan > 0] = np.nan + multi_partial_nan = self.fake_data["multi_axis_normal"].copy() + multi_partial_nan[multi_partial_nan > 0] = np.nan + self.fake_data["single_axis_partial_nan"] = single_partial_nan + self.fake_data["multi_axis_partial_nan"] = multi_partial_nan + + row_data = np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32) + row_data[:, :, :, 0] = np.nan + row_data[:, :, :2, 1] = np.nan + row_data[:, :, 2:, 2] = np.nan + self.fake_data["row_nan_even"] = row_data + self.fake_data["row_nan_float64"] = row_data.astype(np.float64) + self.fake_data["row_nan_int64"] = row_data.astype(np.int64) + self.fake_data["row_nan_int32"] = row_data.astype(np.int32) + + col_data = np.random.uniform(-1, 1, multi_axis_shape).astype(np.float32) + col_data[:, :, 0, :] = np.nan + col_data[:, :, 1, :3] = np.nan + col_data[:, :, 2, 3:] = np.nan + self.fake_data["col_nan_odd"] = col_data + + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + self.axis_candiate_list = [ + None, 0, 2, -1, -2, (1, 2), [0, -1], [0, 1, 3], (1, 2, 3), + [0, 2, 1, 3] + ] + + def test_api_static(self): + data = self.fake_data["col_nan_odd"] + paddle.enable_static() + np_res = np.nanmedian(data, keepdims=True) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', data.shape) + out1 = paddle.nanmedian(x, keepdim=True) + out2 = paddle.tensor.nanmedian(x, keepdim=True) + out3 = paddle.tensor.stat.nanmedian(x, keepdim=True) + axis = np.arange(len(data.shape)).tolist() + out4 = paddle.nanmedian(x, axis=axis, keepdim=True) + out5 = paddle.nanmedian(x, axis=tuple(axis), keepdim=True) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': data}, + fetch_list=[out1, out2, out3, out4, out5]) + + for out in res: + self.assertTrue(np.allclose(np_res, out, equal_nan=True)) + + def test_api_dygraph(self): + paddle.disable_static(self.place) + + def clean_axis_numpy(axis, shape_len): + if isinstance(axis, tuple): + axis = list(axis) + if isinstance(axis, list): + for k in range(len(axis)): + if axis[k] < 0: + axis[k] += shape_len + axis = set(axis) + return axis + + def test_data_case(data): + for keep_dim in [False, True]: + if np.isnan(data).all() and keep_dim: + np_ver = np.version.version.split('.') + if int(np_ver[0]) < 1 or int(np_ver[1]) <= 20: + print( + "This numpy version does not support all nan elements when keepdim is True" + ) + continue + + np_res = np.nanmedian(data, keepdims=keep_dim) + pd_res = paddle.nanmedian( + paddle.to_tensor(data), keepdim=keep_dim) + self.assertTrue( + np.allclose( + np_res, pd_res.numpy(), equal_nan=True)) + + def test_axis_case(data, axis): + pd_res = paddle.nanmedian( + paddle.to_tensor(data), axis=axis, keepdim=False) + axis = clean_axis_numpy(axis, len(data.shape)) + np_res = np.nanmedian(data, axis=axis, keepdims=False) + self.assertTrue(np.allclose(np_res, pd_res.numpy(), equal_nan=True)) + + for name, data in self.fake_data.items(): + test_data_case(data) + + for axis in self.axis_candiate_list: + test_axis_case(self.fake_data["row_nan_even"], axis) + test_axis_case(self.fake_data["col_nan_odd"], axis) + + paddle.enable_static() + + def test_errors(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data("X", [10, 12]) + + def test_dtype(): + x2 = paddle.fluid.data('X2', [10, 12], 'bool') + paddle.nanmedian(x2) + + def test_empty_axis(): + paddle.nanmedian(x, axis=[], keepdim=True) + + def test_axis_not_in_range(): + paddle.nanmedian(x, axis=3, keepdim=True) + + def test_duplicated_axis(): + paddle.nanmedian(x, axis=[1, -1], keepdim=True) + + self.assertRaises(TypeError, test_dtype) + self.assertRaises(ValueError, test_empty_axis) + self.assertRaises(ValueError, test_axis_not_in_range) + self.assertRaises(ValueError, test_duplicated_axis) + + def test_dygraph(self): + paddle.disable_static(place=self.place) + with paddle.fluid.dygraph.guard(): + data = self.fake_data["col_nan_odd"] + out = paddle.nanmedian(paddle.to_tensor(data), keepdim=True) + np_res = np.nanmedian(data, keepdims=True) + self.assertTrue(np.allclose(np_res, out, equal_nan=True)) + paddle.enable_static() + + def test_check_grad(self): + paddle.disable_static(place=self.place) + shape = (4, 5) + x_np = np.random.uniform(-1, 1, shape).astype(np.float64) + x_np[0, :] = np.nan + x_np[1, :3] = np.nan + x_np[2, 3:] = np.nan + x_np_sorted = np.sort(x_np) + nan_counts = np.count_nonzero(np.isnan(x_np).astype(np.int32), axis=1) + np_grad = np.zeros((shape)) + for i in range(shape[0]): + valid_cnts = shape[1] - nan_counts[i] + if valid_cnts == 0: + continue + + mid = int(valid_cnts / 2) + targets = [x_np_sorted[i, mid]] + is_odd = valid_cnts % 2 + if not is_odd and mid > 0: + targets.append(x_np_sorted[i, mid - 1]) + for j in range(shape[1]): + if x_np[i, j] in targets: + np_grad[i, j] = 1 if is_odd else 0.5 + + x_tensor = paddle.to_tensor(x_np, stop_gradient=False) + y = paddle.nanmedian(x_tensor, axis=1, keepdim=True) + dx = paddle.grad(y, x_tensor)[0].numpy() + self.assertTrue(np.allclose(np_grad, dx, equal_nan=True)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 283bce1cc817f..478f4b6351fbf 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -263,6 +263,7 @@ from .stat import var # noqa: F401 from .stat import numel # noqa: F401 from .stat import median # noqa: F401 +from .stat import nanmedian # noqa: F401 from .stat import quantile # noqa: F401 from .stat import nanquantile # noqa: F401 @@ -448,6 +449,7 @@ 'var', 'numel', 'median', + 'nanmedian', 'quantile', 'nanquantile', 'is_complex', diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 52ccc60100996..372454b97a6be 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -241,6 +241,103 @@ def numel(x, name=None): return out +def nanmedian(x, axis=None, keepdim=True, name=None): + r""" + Compute the median along the specified axis, while ignoring NaNs. + + If the valid count of elements is a even number, + the average value of both elements in the middle is calculated as the median. + + Args: + x (Tensor): The input Tensor, it's data type can be int32, int64, float16, float32, float64. + axis (None|int|list|tuple, optional): + The axis along which to perform median calculations ``axis`` should be int or list of int. + ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . + If ``axis`` is less than 0, it works the same way as :math:`axis + D`. + If ``axis`` is None, median is calculated over all elements of ``x``. Default is None. + keepdim (bool, optional): Whether to reserve the reduced dimension(s) + in the output Tensor. If ``keepdim`` is True, the dimensions of + the output Tensor is the same as ``x`` except in the reduced + dimensions(it is of size 1 in this case). Otherwise, the shape of + the output Tensor is squeezed in ``axis`` . Default is True. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, results of median along ``axis`` of ``x``. The output dtype is the same as `x`. + + Examples: + .. code-block:: python + :name: nanmedian-example + + import paddle + x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]]) + + y1 = x.nanmedian() + # y1 is [[2.]] + + y2 = x.nanmedian(0) + # y2 is [[0., 1.5, 2.5]] + + y3 = x.nanmedian(0, keepdim=False) + # y3 is [0., 1.5, 2.5] + + y4 = x.nanmedian((0, 1)) + # y4 is [[2.]] + """ + if not isinstance(x, Variable): + raise TypeError("In median, the input x should be a Tensor.") + + if isinstance(axis, (list, tuple)) and len(axis) == 0: + raise ValueError("Axis list should not be empty.") + + dims = len(x.shape) + if axis is None: + axis = [] + elif isinstance(axis, tuple): + axis = list(axis) + elif isinstance(axis, int): + axis = [axis] + + if not isinstance(axis, list): + raise ValueError( + "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." + ) + + for i in range(len(axis)): + if not isinstance(axis[i], int) or not (axis[i] < dims and + axis[i] >= -dims): + raise ValueError( + "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." + ) + if axis[i] < 0: + axis[i] += dims + + if len(axis) != len(set(axis)): + raise ValueError("Axis has duplicated elements.") + + if _in_legacy_dygraph(): + median_index, out = _C_ops.nanmedian(x, 'axis', axis, 'keepdim', + keepdim) + return out + + check_variable_and_dtype( + x, 'X', ['int32', 'int64', 'float16', 'float32', 'float64'], + 'nanmedian') + + helper = LayerHelper('nanmedian', **locals()) + attrs = {'axis': axis, 'keepdim': keepdim} + out = helper.create_variable_for_type_inference(x.dtype) + medians = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='nanmedian', + inputs={'X': x}, + outputs={'Out': out, + 'MedianIndex': medians}, + attrs=attrs) + return out + + def median(x, axis=None, keepdim=False, name=None): """ Compute the median along the specified axis. diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index 5088ad3457fb9..7702e8be9c958 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -824,7 +824,7 @@ 'test_mean_op', 'test_is_tensor', 'test_run_program_op', 'test_cuda_random_seed', 'test_linear_interp_op', 'test_fuse_all_reduce_pass', 'tensor_util_test', 'test_median', - 'test_linear', 'test_imperative_qat_amp', + 'test_nanmedian', 'test_linear', 'test_imperative_qat_amp', 'test_truncated_gaussian_random_op', 'test_lstm_cudnn_op', 'copy_same_tensor_test', 'test_squeeze2_op', 'naive_best_fit_allocator_test', 'test_model', 'test_py_reader_combination',