From 995bdad7e835e610aafaf2b4c8430b4ef0fb9526 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 19 Mar 2022 15:47:58 +0800 Subject: [PATCH 01/33] =?UTF-8?q?=E5=A2=9E=E5=8A=A0PixelUnshuffle=E7=9A=84?= =?UTF-8?q?=E5=BD=A2=E7=8A=B6=E6=8E=A8=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/infermeta/unary.cc | 50 +++++++++++++++++++++++++++++++++++ paddle/phi/infermeta/unary.h | 5 ++++ 2 files changed, 55 insertions(+) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 03029550c2afa..6c3bbf1b1afe0 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -874,6 +874,56 @@ void PixelShuffleInferMeta(const MetaTensor& x, out->set_dims(output_dims); } +void PixelUnshuffleInferMeta(const MetaTensor& x, + int downscale_factor, + const std::string& data_format, + MetaTensor* out) { + auto input_dims = x.dims(); + PADDLE_ENFORCE_EQ(input_dims.size(), + 4, + phi::errors::InvalidArgument( + "Input should be a 4-D tensor of format [N, C, H, W] " + "or [N, H, W, C], but got %u.", + input_dims.size())); + + const bool channel_last = (data_format == "NHWC"); + + if (!channel_last) { + PADDLE_ENFORCE_EQ((input_dims[2] % downscale_factor) == 0 && + (input_dims[3] % downscale_factor) == 0, + true, + phi::errors::InvalidArgument( + "Downscale factor[%u] should divide both " + "height[%u] and width[%u]", + downscale_factor, + input_dims[2], + input_dims[3])); + } else { + PADDLE_ENFORCE_EQ((input_dims[1] % downscale_factor) == 0 && + (input_dims[2] % downscale_factor) == 0, + true, + phi::errors::InvalidArgument( + "Downscale factor[%u] should divide both " + "height[%u] and width[%u]", + downscale_factor, + input_dims[1], + input_dims[2])); + } + auto output_dims = input_dims; + output_dims[0] = input_dims[0]; + if (!channel_last) { + output_dims[1] = input_dims[1] * (downscale_factor * downscale_factor); + output_dims[2] = input_dims[2] / downscale_factor; + output_dims[3] = input_dims[3] / downscale_factor; + } else { + output_dims[1] = input_dims[1] / downscale_factor; + output_dims[2] = input_dims[2] / downscale_factor; + output_dims[3] = input_dims[3] * (downscale_factor * downscale_factor); + } + out->set_dtype(x.dtype()); + out->set_dims(output_dims); +} + void PoolInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 00026f8598b07..ff439d549fd97 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -142,6 +142,11 @@ void PixelShuffleInferMeta(const MetaTensor& x, const std::string& data_format, MetaTensor* out); +void PixelUnshuffleInferMeta(const MetaTensor& x, + int downscale_factor, + const std::string& data_format, + MetaTensor* out); + void PoolInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, From 8c401fb817d367ac88db179838c916a97f2feb00 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 19 Mar 2022 15:50:28 +0800 Subject: [PATCH 02/33] =?UTF-8?q?=E5=A2=9E=E5=8A=A0PixelUnshuffle=E7=9A=84?= =?UTF-8?q?=E7=AE=97=E5=AD=90=E6=B3=A8=E5=86=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cpu/pixel_unshuffle_grad_kernel.cc | 26 +++++++++++++++++++ .../phi/kernels/cpu/pixel_unshuffle_kernel.cc | 26 +++++++++++++++++++ .../gpu/pixel_unshuffle_grad_kernel.cu | 26 +++++++++++++++++++ .../phi/kernels/gpu/pixel_unshuffle_kernel.cu | 26 +++++++++++++++++++ 4 files changed, 104 insertions(+) create mode 100644 paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc create mode 100644 paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc new file mode 100644 index 0000000000000..ef61fca35957e --- /dev/null +++ b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" +#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc new file mode 100644 index 0000000000000..9f4bc747f3209 --- /dev/null +++ b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" +#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu new file mode 100644 index 0000000000000..9cbbc5072aa25 --- /dev/null +++ b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu new file mode 100644 index 0000000000000..ca2e520ffde10 --- /dev/null +++ b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} From 7a62b6e48aa967da871ba2ff1778803fd0680061 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 19 Mar 2022 15:52:59 +0800 Subject: [PATCH 03/33] =?UTF-8?q?=E5=A2=9E=E5=8A=A0PixelUnshuffle=E5=8F=8A?= =?UTF-8?q?=E5=85=B6=E6=A2=AF=E5=BA=A6=E7=9A=84=E6=A0=B8=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../impl/pixel_unshuffle_grad_kernel_impl.h | 56 +++++++++++++++++++ .../impl/pixel_unshuffle_kernel_impl.h | 55 ++++++++++++++++++ .../phi/kernels/pixel_unshuffle_grad_kernel.h | 29 ++++++++++ paddle/phi/kernels/pixel_unshuffle_kernel.h | 29 ++++++++++ 4 files changed, 169 insertions(+) create mode 100644 paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h create mode 100644 paddle/phi/kernels/pixel_unshuffle_grad_kernel.h create mode 100644 paddle/phi/kernels/pixel_unshuffle_kernel.h diff --git a/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h b/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h new file mode 100644 index 0000000000000..a418f184b7826 --- /dev/null +++ b/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PixelUnshuffleGradKernel(const Context& ctx, + const DenseTensor& out_grad, + int downscale_factor, + const std::string& data_format, + DenseTensor* x_grad) { + auto* dout = &out_grad; + auto* dx = x_grad; + ctx.template Alloc(dx); + int factor = downscale_factor; + bool channel_last = (data_format == "NHWC"); + auto do_dims = dout->dims(); + auto dx_dims = dx->dims(); + + DenseTensor t(*dout); + if (!channel_last) { + t.Resize({do_dims[0], dx_dims[1], factor, factor, do_dims[2], do_dims[3]}); + } else { + t.Resize({do_dims[0], do_dims[1], do_dims[2], dx_dims[3], factor, factor}); + } + std::vector axis = {0, 1, 4, 2, 5, 3}; + + DenseTensor o(*dx); + if (!channel_last) { + o.Resize({do_dims[0], dx_dims[1], do_dims[2], factor, do_dims[3], factor}); + } else { + o.Resize({do_dims[0], do_dims[1], factor, do_dims[2], factor, dx_dims[3]}); + } + phi::funcs::Transpose trans; + trans(ctx, t, &o, axis); + dx->Resize(dx_dims); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h b/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h new file mode 100644 index 0000000000000..ac93e00a56576 --- /dev/null +++ b/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PixelUnshuffleKernel(const Context& ctx, + const DenseTensor& x, + int downscale_factor, + const std::string& data_format, + DenseTensor* out) { + auto* in = &x; + ctx.template Alloc(out); + int factor = downscale_factor; + bool channel_last = (data_format == "NHWC"); + auto in_dims = in->dims(); + auto o_dims = out->dims(); + + DenseTensor t(*in); + if (!channel_last) { + t.Resize({in_dims[0], in_dims[1], o_dims[2], factor, o_dims[3], factor}); + } else { + t.Resize({in_dims[0], o_dims[1], factor, o_dims[2], factor, in_dims[3]}); + } + std::vector axis = {0, 1, 3, 5, 2, 4}; + + DenseTensor o(*out); + if (!channel_last) { + o.Resize({in_dims[0], in_dims[1], factor, factor, o_dims[2], o_dims[3]}); + } else { + o.Resize({in_dims[0], o_dims[1], o_dims[2], in_dims[3], factor, factor}); + } + phi::funcs::Transpose trans; + trans(ctx, t, &o, axis); + out->Resize(o_dims); +} + +} // namespace phi diff --git a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h new file mode 100644 index 0000000000000..f62f1f5b4c7b7 --- /dev/null +++ b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void PixelUnshuffleGradKernel(const Context& ctx, + const DenseTensor& out_grad, + int downscale_factor, + const std::string& data_format, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/pixel_unshuffle_kernel.h b/paddle/phi/kernels/pixel_unshuffle_kernel.h new file mode 100644 index 0000000000000..a631223034e96 --- /dev/null +++ b/paddle/phi/kernels/pixel_unshuffle_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void PixelUnshuffleKernel(const Context& ctx, + const DenseTensor& x, + int downscale_factor, + const std::string& data_format, + DenseTensor* out); + +} // namespace phi From 9591a487cbe33bbac13bf8c2baf49a726ee167cb Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 19 Mar 2022 15:54:28 +0800 Subject: [PATCH 04/33] =?UTF-8?q?=E5=A2=9E=E5=8A=A0PixelUnshuffle=E7=AE=97?= =?UTF-8?q?=E5=AD=90=E7=9A=84=E6=8F=8F=E8=BF=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/pixel_unshuffle_op.cc | 134 +++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 paddle/fluid/operators/pixel_unshuffle_op.cc diff --git a/paddle/fluid/operators/pixel_unshuffle_op.cc b/paddle/fluid/operators/pixel_unshuffle_op.cc new file mode 100644 index 0000000000000..ca883b4e6d97f --- /dev/null +++ b/paddle/fluid/operators/pixel_unshuffle_op.cc @@ -0,0 +1,134 @@ +/*Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +class PixelUnshuffleOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +class PixelUnshuffleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), " + "the input feature data of PixelUnshuffleOp, the layout is " + "[N, C, H, W] or [N, H, W, C]."); + AddOutput("Out", + "(Tensor, default Tensor), the output of " + "PixelUnshuffleOp. The layout is [N, C*factor^2, H/factor, " + "W/factor] or [N, H/factor, W/factor, C*factor^2]."); + AddAttr("downscale_factor", + "the factor to decrease spatial resolution by.") + .SetDefault(1) + .AddCustomChecker([](const int& downscale_factor) { + PADDLE_ENFORCE_GE(downscale_factor, 1, + platform::errors::InvalidArgument( + "downscale_factor should be larger than 0.")); + }); + AddAttr( + "data_format", + "An optional string from: \"NHWC\", \"NCHW\". " + "Defaults to \"NHWC\", Specify the data format of the input data.") + .SetDefault("NCHW"); + + AddComment(R"DOC( + Pixel Unshuffle operator + This operator rearranges elements in a tensor of shape :math:`(*, C, H, W)` + to a tensor of shape :math:`(*, C\times r^2, H / r, W / r)`. + + This operation is the reversion of PixelShuffle operation. + + Please refer to the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient + Sub-Pixel Convolutional Neural Network `_ + by Shi et. al (2016) for more details. + + )DOC"); + } +}; + +template +class PixelUnshuffleGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType("pixel_unshuffle_grad"); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +class PixelUnshuffleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::NotFound("Input(Out@Grad) should not be null")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput(framework::GradVarName("X")), true, + platform::errors::NotFound("Output(X@Grad) should not be null")); + + auto do_dims = ctx->GetInputDim(framework::GradVarName("Out")); + PADDLE_ENFORCE_EQ(do_dims.size(), 4, + platform::errors::InvalidArgument( + "Input should be a 4-D tensor of format [N, C, H, W] " + "or [N, H, W, C], but got %u.", + do_dims.size())); + + auto downscale_factor = ctx->Attrs().Get("downscale_factor"); + + const std::string data_format = + ctx->Attrs().Get("data_format"); + const bool channel_last = (data_format == "NHWC"); + + auto dx_dims = do_dims; + dx_dims[0] = do_dims[0]; + + if (!channel_last) { + dx_dims[1] = do_dims[1] / (downscale_factor * downscale_factor); + dx_dims[2] = do_dims[2] * downscale_factor; + dx_dims[3] = do_dims[3] * downscale_factor; + } else { + dx_dims[1] = do_dims[1] * downscale_factor; + dx_dims[2] = do_dims[2] * downscale_factor; + dx_dims[3] = do_dims[3] / (downscale_factor * downscale_factor); + } + ctx->SetOutputDim(framework::GradVarName("X"), dx_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle, PixelUnshuffleInferShapeFunctor, + PD_INFER_META(phi::PixelUnshuffleInferMeta)); + +REGISTER_OPERATOR(pixel_unshuffle, ops::PixelUnshuffleOp, ops::PixelUnshuffleOpMaker, + ops::PixelUnshuffleGradOpMaker, + ops::PixelUnshuffleGradOpMaker, + PixelUnshuffleInferShapeFunctor); + +REGISTER_OPERATOR(pixel_unshuffle_grad, ops::PixelUnshuffleGradOp); From f6ad3657a16da9d4ce2783a5965782e4c0b8027c Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 19 Mar 2022 15:54:58 +0800 Subject: [PATCH 05/33] =?UTF-8?q?=E5=A2=9E=E5=8A=A0PixelUnshuffle=E7=AE=97?= =?UTF-8?q?=E5=AD=90=E7=9A=84=E7=AD=BE=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/ops/compat/pixel_unshuffle_sig.cc | 38 ++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 paddle/phi/ops/compat/pixel_unshuffle_sig.cc diff --git a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc new file mode 100644 index 0000000000000..dba5023a2ff91 --- /dev/null +++ b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature PixelUnshuffleOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "pixel_unshuffle", {"X"}, {"downscale_factor", "data_format"}, {"Out"}); +} + +KernelSignature PixelUnshuffleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("pixel_unshuffle_grad", + {GradVarName("Out")}, + {"downscale_factor", "data_format"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle, + phi::PixelUnshuffleOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle_grad, + phi::PixelUnshuffleGradOpArgumentMapping); From 73aed0263f8b2c1198c43285dc52798d6def6048 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 19 Mar 2022 16:00:26 +0800 Subject: [PATCH 06/33] =?UTF-8?q?=E5=9C=A8Python=E5=B1=82=E9=9D=A2?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0PixelUnshuffle?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/vision.py | 45 +++++++++++++++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/vision.py | 64 +++++++++++++++++++++++++ 5 files changed, 114 insertions(+) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index c0820e140268b..3c63145290db1 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -137,6 +137,7 @@ from .layer.distance import PairwiseDistance # noqa: F401 from .layer.vision import PixelShuffle # noqa: F401 +from .layer.vision import PixelUnshuffle # noqa: F401 from .layer.container import LayerDict # noqa: F401 from .utils.spectral_norm_hook import spectral_norm @@ -298,6 +299,7 @@ def weight_norm(*args): 'Swish', 'Mish', 'PixelShuffle', + 'PixelUnshuffle', 'ELU', 'ReLU6', 'LayerDict', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a24afc45a5995..cab133b92fa6c 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -114,6 +114,7 @@ from .vision import affine_grid # noqa: F401 from .vision import grid_sample # noqa: F401 from .vision import pixel_shuffle # noqa: F401 +from .vision import pixel_unshuffle # noqa: F401 from .input import one_hot # noqa: F401 from .input import embedding # noqa: F401 from ...fluid.layers import gather_tree # noqa: F401 @@ -213,6 +214,7 @@ 'grid_sample', 'local_response_norm', 'pixel_shuffle', + 'pixel_unshuffle', 'embedding', 'gather_tree', 'one_hot', diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 43c7757a8777b..3c10602198562 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -344,3 +344,48 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW", name=None): attrs={"upscale_factor": upscale_factor, "data_format": data_format}) return out + +def pixel_unshuffle(x, downscale_factor, data_format="NCHW", name=None): + """ + This API implements pixel unshuffle operation. + See more details in :ref:`api_nn_vision_PixelUnshuffle` . + Parameters: + x(Tensor): 4-D tensor, the data type should be float32 or float64. + downscale_factor(int): factor to decrease spatial resolution. + data_format (str): The data format of the input and output data. An optional string from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in the order of: [batch_size, input_channels, input_height, input_width]. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + Returns: + Out(tensor): Reshaped tensor according to the new dimension. + Raises: + ValueError: If downscale_factor cannot divide both the height and width of input. + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + x = paddle.randn([2, 1, 12, 12]) + out = F.pixel_unshuffle(x, 3) + # out.shape = [2, 9, 4, 4] + """ + if not isinstance(downscale_factor, int): + raise TypeError("downscale factor must be int type") + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'." + "But recevie Attr(data_format): {} ".format( + data_format)) + + if in_dynamic_mode(): + return _C_ops.pixel_unshuffle(x, "downscale_factor", downscale_factor, + "data_format", data_format) + + helper = LayerHelper("pixel_unshuffle", **locals()) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'pixel_unshuffle') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type="pixel_unshuffle", + inputs={"X": x}, + outputs={"Out": out}, + attrs={"downscale_factor": downscale_factor, + "data_format": data_format}) + return out diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 2b50508065605..b0f6f7690bce2 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -87,6 +87,7 @@ from .norm import LocalResponseNorm # noqa: F401 from .vision import PixelShuffle # noqa: F401 +from .vision import PixelUnshuffle # noqa: F401 from .distance import PairwiseDistance # noqa: F401 from .container import LayerDict # noqa: F401 diff --git a/python/paddle/nn/layer/vision.py b/python/paddle/nn/layer/vision.py index 0531afb4eeeeb..e13e825cef60a 100644 --- a/python/paddle/nn/layer/vision.py +++ b/python/paddle/nn/layer/vision.py @@ -87,3 +87,67 @@ def extra_repr(self): if self._name is not None: main_str += ', name={}'.format(self._name) return main_str + +class PixelUnshuffle(Layer): + """ + + PixelUnshuffle Layer + + This operator rearranges elements in a tensor of shape [N, C, H, W] + to a tensor of shape [N, C*downscale_factor**2, H/downscale_factor, W/downscale_factor], + or from shape [N, H, W, C] to [N, H/downscale_factor, W/downscale_factor, C*downscale_factor**2]. + This operation is the reversion of PixelShuffle operation. + Please refer to the paper: `Real-Time Single Image and Video Super-Resolution + Using an Efficient Sub-Pixel Convolutional Neural Network `_ . + by Shi et. al (2016) for more details. + + Parameters: + + downscale_factor(int): factor to decrease spatial resolution. + data_format (str): The data format of the input and output data. An optional string from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in the order of: [batch_size, input_channels, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - x: 4-D tensor with shape: (N, C, H, W) or (N, H, W, C). + - out: 4-D tensor with shape: (N, C*downscale_factor**2, H/downscale_factor, W/downscale_factor) or (N, H/downscale_factor, W/downscale_factor, C*downscale_factor**2). + + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + + x = paddle.randn([2, 1, 12, 12]) + pixel_unshuffle = nn.PixelUnshuffle(3) + out = pixel_unshuffle(x) + # out.shape = [2, 9, 4, 4] + + """ + + + def __init__(self, downscale_factor, data_format="NCHW", name=None): + super(PixelUnshuffle, self).__init__() + + if not isinstance(downscale_factor, int): + raise TypeError("downscale factor must be int type") + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError("Data format should be 'NCHW' or 'NHWC'." + "But recevie data format: {}".format(data_format)) + + self._downscale_factor = downscale_factor + self._data_format = data_format + self._name = name + + def forward(self, x): + return functional.pixel_unshuffle(x, self._downscale_factor, + self._data_format, self._name) + + def extra_repr(self): + main_str = 'downscale_factor={}'.format(self._downscale_factor) + if self._data_format != 'NCHW': + main_str += ', data_format={}'.format(self._data_format) + if self._name is not None: + main_str += ', name={}'.format(self._name) + return main_str From 8a259c035d02534cfe3bfc4a2e1c64ca6032e2fe Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 19 Mar 2022 16:14:59 +0800 Subject: [PATCH 07/33] =?UTF-8?q?=E5=A2=9E=E5=8A=A0PixelUnshuffle=E7=9A=84?= =?UTF-8?q?=E5=8D=95=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tests/unittests/test_pixel_unshuffle.py | 222 ++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py new file mode 100644 index 0000000000000..dd1ba296f3300 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -0,0 +1,222 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +from op_test import OpTest +import paddle +import paddle.nn.functional as F +import paddle.fluid.core as core +import paddle.fluid as fluid + + +def pixel_unshuffle_np(x, down_factor, data_format="NCHW"): + if data_format == "NCHW": + n, c, h, w = x.shape + new_shape = (n, c, h / down_factor, down_factor, + w / down_factor, down_factor) + npresult = np.reshape(x, new_shape) + npresult = npresult.transpose(0, 1, 3, 5, 2, 4) + oshape = [n, c * down_factor * down_factor, h / down_factor, + w / down_factor] + npresult = np.reshape(npresult, oshape) + return npresult + else: + n, h, w, c = x.shape + new_shape = (n, h / down_factor, down_factor, + w / down_factor, down_factor, c) + npresult = np.reshape(x, new_shape) + npresult = npresult.transpose(0, 1, 3, 5, 2, 4) + oshape = [n, h / down_factor, + w / down_factor, c * down_factor * down_factor] + npresult = np.reshape(npresult, oshape) + return npresult + + +class TestPixelUnshuffleOp(OpTest): + def setUp(self): + self.op_type = "pixel_unshuffle" + self.init_data_format() + n, c, h, w = 2, 1, 12, 12 + + if self.format == "NCHW": + shape = [n, c, h, w] + if self.format == "NHWC": + shape = [n, h, w, c] + + down_factor = 3 + + x = np.random.random(shape).astype("float64") + npresult = pixel_unshuffle_np(x, down_factor, self.format) + + self.inputs = {"X": x} + self.outputs = {"Out": npresult} + self.attrs = {"downscale_factor": down_factor, + "data_format": self.format} + + def init_data_format(self): + self.format = "NCHW" + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestChannelLast(TestPixelUnshuffleOp): + def init_data_format(self): + self.format = "NHWC" + + +class TestPixelUnshuffleAPI(unittest.TestCase): + def setUp(self): + self.x_1_np = np.random.random([2, 1, 12, 12]).astype("float64") + self.x_2_np = np.random.random([2, 12, 12, 1]).astype("float64") + self.out_1_np = pixel_unshuffle_np(self.x_1_np, 3) + self.out_2_np = pixel_unshuffle_np(self.x_2_np, 3, "NHWC") + + def test_static_graph_functional(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=[2, 1, 12, 12], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 12, 12, 1], dtype="float64") + out_1 = F.pixel_unshuffle(x_1, 3) + out_2 = F.pixel_unshuffle(x_2, 3, "NHWC") + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_1_np}, + fetch_list=out_1, + use_prune=True) + + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_2_np}, + fetch_list=out_2, + use_prune=True) + + assert np.allclose(res_1, self.out_1_np) + assert np.allclose(res_2, self.out_2_np) + + # same test between layer and functional in this op. + def test_static_graph_layer(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=[2, 1, 12, 12], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 12, 12, 1], dtype="float64") + # init instance + ps_1 = paddle.nn.PixelUnshuffle(3) + ps_2 = paddle.nn.PixelUnshuffle(3, "NHWC") + out_1 = ps_1(x_1) + out_2 = ps_2(x_2) + out_1_np = pixel_unshuffle_np(self.x_1_np, 3) + out_2_np = pixel_unshuffle_np(self.x_2_np, 3, "NHWC") + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_1_np}, + fetch_list=out_1, + use_prune=True) + + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_2_np}, + fetch_list=out_2, + use_prune=True) + + assert np.allclose(res_1, out_1_np) + assert np.allclose(res_2, out_2_np) + + def run_dygraph(self, down_factor, data_format): + + n, c, h, w = 2, 1, 12, 12 + + if data_format == "NCHW": + shape = [n, c, h, w] + if data_format == "NHWC": + shape = [n, h, w, c] + + x = np.random.random(shape).astype("float64") + + npresult = pixel_unshuffle_np(x, down_factor, data_format) + + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.disable_static(place=place) + + pixel_unshuffle = paddle.nn.PixelUnshuffle( + down_factor, data_format=data_format) + result = pixel_unshuffle(paddle.to_tensor(x)) + + self.assertTrue(np.allclose(result.numpy(), npresult)) + + result_functional = F.pixel_unshuffle( + paddle.to_tensor(x), 3, data_format) + self.assertTrue(np.allclose(result_functional.numpy(), npresult)) + + def test_dygraph1(self): + self.run_dygraph(3, "NCHW") + + def test_dygraph2(self): + self.run_dygraph(3, "NHWC") + + +class TestPixelUnshuffleError(unittest.TestCase): + def test_error_functional(self): + def error_downscale_factor(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 3.33) + + self.assertRaises(TypeError, error_downscale_factor) + + def error_data_format(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 3, "WOW") + + self.assertRaises(ValueError, error_data_format) + + def test_error_layer(self): + def error_downscale_factor_layer(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(3.33) + + self.assertRaises(TypeError, error_downscale_factor_layer) + + def error_data_format_layer(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(3, "MEOW") + + self.assertRaises(ValueError, error_data_format_layer) + + +if __name__ == "__main__": + unittest.main() From b28157ee06ecce6aabab29dd9b82956e3d8a7092 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sat, 19 Mar 2022 17:39:38 +0800 Subject: [PATCH 08/33] Update test_pixel_unshuffle.py --- .../tests/unittests/test_pixel_unshuffle.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py index dd1ba296f3300..928754151578e 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -23,26 +23,27 @@ import paddle.fluid.core as core import paddle.fluid as fluid +paddle.enable_static() def pixel_unshuffle_np(x, down_factor, data_format="NCHW"): if data_format == "NCHW": n, c, h, w = x.shape - new_shape = (n, c, h / down_factor, down_factor, - w / down_factor, down_factor) + new_shape = (n, c, h // down_factor, down_factor, + w // down_factor, down_factor) npresult = np.reshape(x, new_shape) npresult = npresult.transpose(0, 1, 3, 5, 2, 4) - oshape = [n, c * down_factor * down_factor, h / down_factor, - w / down_factor] + oshape = [n, c * down_factor * down_factor, h // down_factor, + w // down_factor] npresult = np.reshape(npresult, oshape) return npresult else: n, h, w, c = x.shape - new_shape = (n, h / down_factor, down_factor, - w / down_factor, down_factor, c) + new_shape = (n, h // down_factor, down_factor, + w // down_factor, down_factor, c) npresult = np.reshape(x, new_shape) npresult = npresult.transpose(0, 1, 3, 5, 2, 4) - oshape = [n, h / down_factor, - w / down_factor, c * down_factor * down_factor] + oshape = [n, h // down_factor, + w // down_factor, c * down_factor * down_factor] npresult = np.reshape(npresult, oshape) return npresult From d16b545b40af953b0d5b83de3d2446c9ff7bd8cb Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sun, 20 Mar 2022 11:58:27 +0800 Subject: [PATCH 09/33] test=document_fix --- paddle/fluid/operators/pixel_unshuffle_op.cc | 3 +- paddle/phi/infermeta/unary.cc | 36 +++++------ paddle/phi/ops/compat/pixel_unshuffle_sig.cc | 2 +- .../tests/unittests/test_pixel_unshuffle.py | 62 +++++++++++++++---- python/paddle/nn/functional/vision.py | 6 +- python/paddle/nn/layer/vision.py | 1 - 6 files changed, 76 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/operators/pixel_unshuffle_op.cc b/paddle/fluid/operators/pixel_unshuffle_op.cc index ca883b4e6d97f..1f46e815ced1b 100644 --- a/paddle/fluid/operators/pixel_unshuffle_op.cc +++ b/paddle/fluid/operators/pixel_unshuffle_op.cc @@ -126,7 +126,8 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle, PixelUnshuffleInferShapeFunctor, PD_INFER_META(phi::PixelUnshuffleInferMeta)); -REGISTER_OPERATOR(pixel_unshuffle, ops::PixelUnshuffleOp, ops::PixelUnshuffleOpMaker, +REGISTER_OPERATOR(pixel_unshuffle, ops::PixelUnshuffleOp, + ops::PixelUnshuffleOpMaker, ops::PixelUnshuffleGradOpMaker, ops::PixelUnshuffleGradOpMaker, PixelUnshuffleInferShapeFunctor); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 6c3bbf1b1afe0..f93b0279adfef 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -889,25 +889,25 @@ void PixelUnshuffleInferMeta(const MetaTensor& x, const bool channel_last = (data_format == "NHWC"); if (!channel_last) { - PADDLE_ENFORCE_EQ((input_dims[2] % downscale_factor) == 0 && - (input_dims[3] % downscale_factor) == 0, - true, - phi::errors::InvalidArgument( - "Downscale factor[%u] should divide both " - "height[%u] and width[%u]", - downscale_factor, - input_dims[2], - input_dims[3])); + PADDLE_ENFORCE_EQ( + (input_dims[2] % downscale_factor) == 0 && + (input_dims[3] % downscale_factor) == 0, + true, + phi::errors::InvalidArgument("Downscale factor[%u] should divide both " + "height[%u] and width[%u]", + downscale_factor, + input_dims[2], + input_dims[3])); } else { - PADDLE_ENFORCE_EQ((input_dims[1] % downscale_factor) == 0 && - (input_dims[2] % downscale_factor) == 0, - true, - phi::errors::InvalidArgument( - "Downscale factor[%u] should divide both " - "height[%u] and width[%u]", - downscale_factor, - input_dims[1], - input_dims[2])); + PADDLE_ENFORCE_EQ( + (input_dims[1] % downscale_factor) == 0 && + (input_dims[2] % downscale_factor) == 0, + true, + phi::errors::InvalidArgument("Downscale factor[%u] should divide both " + "height[%u] and width[%u]", + downscale_factor, + input_dims[1], + input_dims[2])); } auto output_dims = input_dims; output_dims[0] = input_dims[0]; diff --git a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc index dba5023a2ff91..e78b676dd629e 100644 --- a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc +++ b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc @@ -32,7 +32,7 @@ KernelSignature PixelUnshuffleGradOpArgumentMapping( } // namespace phi -PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle, +PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle, phi::PixelUnshuffleOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle_grad, phi::PixelUnshuffleGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py index 928754151578e..53eec4f7f08ce 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -26,30 +26,38 @@ paddle.enable_static() def pixel_unshuffle_np(x, down_factor, data_format="NCHW"): + '''Numpy implementation of pixel unshuffle''' + if data_format == "NCHW": n, c, h, w = x.shape - new_shape = (n, c, h // down_factor, down_factor, - w // down_factor, down_factor) + new_shape = (n, c, h // down_factor, down_factor, w // down_factor, + down_factor) npresult = np.reshape(x, new_shape) npresult = npresult.transpose(0, 1, 3, 5, 2, 4) - oshape = [n, c * down_factor * down_factor, h // down_factor, - w // down_factor] + oshape = [ + n, c * down_factor * down_factor, h // down_factor, w // down_factor + ] npresult = np.reshape(npresult, oshape) return npresult else: n, h, w, c = x.shape - new_shape = (n, h // down_factor, down_factor, - w // down_factor, down_factor, c) + new_shape = (n, h // down_factor, down_factor, w // down_factor, + down_factor, c) npresult = np.reshape(x, new_shape) npresult = npresult.transpose(0, 1, 3, 5, 2, 4) - oshape = [n, h // down_factor, - w // down_factor, c * down_factor * down_factor] + oshape = [ + n, h // down_factor, w // down_factor, c * down_factor * down_factor + ] npresult = np.reshape(npresult, oshape) return npresult class TestPixelUnshuffleOp(OpTest): + '''TestPixelUnshuffleOp''' + def setUp(self): + '''setUp''' + self.op_type = "pixel_unshuffle" self.init_data_format() n, c, h, w = 2, 1, 12, 12 @@ -66,32 +74,50 @@ def setUp(self): self.inputs = {"X": x} self.outputs = {"Out": npresult} - self.attrs = {"downscale_factor": down_factor, - "data_format": self.format} + self.attrs = { + "downscale_factor": down_factor, + "data_format": self.format + } def init_data_format(self): + '''init_data_format''' + self.format = "NCHW" def test_check_output(self): + '''test_check_output''' + self.check_output() def test_check_grad(self): + '''test_check_grad''' + self.check_grad(["X"], "Out") class TestChannelLast(TestPixelUnshuffleOp): + '''TestChannelLast''' + def init_data_format(self): + '''init_data_format''' + self.format = "NHWC" class TestPixelUnshuffleAPI(unittest.TestCase): + '''TestPixelUnshuffleAPI''' + def setUp(self): + '''setUp''' + self.x_1_np = np.random.random([2, 1, 12, 12]).astype("float64") self.x_2_np = np.random.random([2, 12, 12, 1]).astype("float64") self.out_1_np = pixel_unshuffle_np(self.x_1_np, 3) self.out_2_np = pixel_unshuffle_np(self.x_2_np, 3, "NHWC") def test_static_graph_functional(self): + '''test_static_graph_functional''' + for use_cuda in ([False, True] if core.is_compiled_with_cuda() else [False]): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() @@ -120,6 +146,8 @@ def test_static_graph_functional(self): # same test between layer and functional in this op. def test_static_graph_layer(self): + '''test_static_graph_layer''' + for use_cuda in ([False, True] if core.is_compiled_with_cuda() else [False]): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() @@ -152,6 +180,7 @@ def test_static_graph_layer(self): assert np.allclose(res_2, out_2_np) def run_dygraph(self, down_factor, data_format): + '''run_dygraph''' n, c, h, w = 2, 1, 12, 12 @@ -181,14 +210,22 @@ def run_dygraph(self, down_factor, data_format): self.assertTrue(np.allclose(result_functional.numpy(), npresult)) def test_dygraph1(self): + '''test_dygraph1''' + self.run_dygraph(3, "NCHW") def test_dygraph2(self): + '''test_dygraph2''' + self.run_dygraph(3, "NHWC") class TestPixelUnshuffleError(unittest.TestCase): + '''TestPixelUnshuffleError''' + def test_error_functional(self): + '''test_error_functional''' + def error_downscale_factor(): with paddle.fluid.dygraph.guard(): x = np.random.random([2, 1, 12, 12]).astype("float64") @@ -199,11 +236,14 @@ def error_downscale_factor(): def error_data_format(): with paddle.fluid.dygraph.guard(): x = np.random.random([2, 1, 12, 12]).astype("float64") - pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 3, "WOW") + pixel_unshuffle = F.pixel_unshuffle( + paddle.to_tensor(x), 3, "WOW") self.assertRaises(ValueError, error_data_format) def test_error_layer(self): + '''test_error_layer''' + def error_downscale_factor_layer(): with paddle.fluid.dygraph.guard(): x = np.random.random([2, 1, 12, 12]).astype("float64") diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 3c10602198562..7574f9b3c0ab7 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -386,6 +386,8 @@ def pixel_unshuffle(x, downscale_factor, data_format="NCHW", name=None): type="pixel_unshuffle", inputs={"X": x}, outputs={"Out": out}, - attrs={"downscale_factor": downscale_factor, - "data_format": data_format}) + attrs={ + "downscale_factor": downscale_factor, + "data_format": data_format + }) return out diff --git a/python/paddle/nn/layer/vision.py b/python/paddle/nn/layer/vision.py index e13e825cef60a..10ecf2bb0a853 100644 --- a/python/paddle/nn/layer/vision.py +++ b/python/paddle/nn/layer/vision.py @@ -125,7 +125,6 @@ class PixelUnshuffle(Layer): """ - def __init__(self, downscale_factor, data_format="NCHW", name=None): super(PixelUnshuffle, self).__init__() From 89e36a0bc8e9094fb8734f441104d7c11c524d8b Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sun, 20 Mar 2022 12:40:42 +0800 Subject: [PATCH 10/33] Update test_pixel_unshuffle.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加对extra_repr的测试 --- python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py index 53eec4f7f08ce..ed342de589188 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -209,6 +209,11 @@ def run_dygraph(self, down_factor, data_format): paddle.to_tensor(x), 3, data_format) self.assertTrue(np.allclose(result_functional.numpy(), npresult)) + pixel_unshuffle_str = 'downscale_factor={}, data_format={}'.format( + pixel_unshuffle._downscale_factor, + pixel_unshuffle._data_format) + self.assertTrue(pixel_unshuffle.extra_repr(), pixel_unshuffle_str) + def test_dygraph1(self): '''test_dygraph1''' From 37925736916b26ff6d3d8609a1470e217cd111ee Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sun, 20 Mar 2022 17:00:54 +0800 Subject: [PATCH 11/33] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/infermeta/unary.cc | 4 ++-- python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py | 6 +++--- python/paddle/nn/functional/vision.py | 1 + python/paddle/nn/layer/vision.py | 1 + 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index f93b0279adfef..128a7d0182f32 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -890,7 +890,7 @@ void PixelUnshuffleInferMeta(const MetaTensor& x, if (!channel_last) { PADDLE_ENFORCE_EQ( - (input_dims[2] % downscale_factor) == 0 && + (input_dims[2] % downscale_factor) == 0 && (input_dims[3] % downscale_factor) == 0, true, phi::errors::InvalidArgument("Downscale factor[%u] should divide both " @@ -900,7 +900,7 @@ void PixelUnshuffleInferMeta(const MetaTensor& x, input_dims[3])); } else { PADDLE_ENFORCE_EQ( - (input_dims[1] % downscale_factor) == 0 && + (input_dims[1] % downscale_factor) == 0 && (input_dims[2] % downscale_factor) == 0, true, phi::errors::InvalidArgument("Downscale factor[%u] should divide both " diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py index ed342de589188..56aa7858c129b 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -25,6 +25,7 @@ paddle.enable_static() + def pixel_unshuffle_np(x, down_factor, data_format="NCHW"): '''Numpy implementation of pixel unshuffle''' @@ -75,7 +76,7 @@ def setUp(self): self.inputs = {"X": x} self.outputs = {"Out": npresult} self.attrs = { - "downscale_factor": down_factor, + "downscale_factor": down_factor, "data_format": self.format } @@ -210,8 +211,7 @@ def run_dygraph(self, down_factor, data_format): self.assertTrue(np.allclose(result_functional.numpy(), npresult)) pixel_unshuffle_str = 'downscale_factor={}, data_format={}'.format( - pixel_unshuffle._downscale_factor, - pixel_unshuffle._data_format) + pixel_unshuffle._downscale_factor, pixel_unshuffle._data_format) self.assertTrue(pixel_unshuffle.extra_repr(), pixel_unshuffle_str) def test_dygraph1(self): diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 7574f9b3c0ab7..c9d946b17b638 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -345,6 +345,7 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW", name=None): "data_format": data_format}) return out + def pixel_unshuffle(x, downscale_factor, data_format="NCHW", name=None): """ This API implements pixel unshuffle operation. diff --git a/python/paddle/nn/layer/vision.py b/python/paddle/nn/layer/vision.py index 10ecf2bb0a853..a6f498a37048e 100644 --- a/python/paddle/nn/layer/vision.py +++ b/python/paddle/nn/layer/vision.py @@ -88,6 +88,7 @@ def extra_repr(self): main_str += ', name={}'.format(self._name) return main_str + class PixelUnshuffle(Layer): """ From 3388c97acbf1798ad3b00ce3c6e20bc8cbc890e2 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Sun, 20 Mar 2022 22:47:56 +0800 Subject: [PATCH 12/33] Update test_pixel_unshuffle.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修正对extra_repr的测试 --- .../paddle/fluid/tests/unittests/test_pixel_unshuffle.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py index 56aa7858c129b..c6dab4d816b38 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -210,9 +210,10 @@ def run_dygraph(self, down_factor, data_format): paddle.to_tensor(x), 3, data_format) self.assertTrue(np.allclose(result_functional.numpy(), npresult)) - pixel_unshuffle_str = 'downscale_factor={}, data_format={}'.format( - pixel_unshuffle._downscale_factor, pixel_unshuffle._data_format) - self.assertTrue(pixel_unshuffle.extra_repr(), pixel_unshuffle_str) + pixel_unshuffle_str = 'downscale_factor={}'.format(down_factor) + if data_format != 'NCHW': + pixel_unshuffle_str += ', data_format={}'.format(data_format) + self.assertEqual(pixel_unshuffle.extra_repr(), pixel_unshuffle_str) def test_dygraph1(self): '''test_dygraph1''' From 9e28fef4b5bc2764ed9927cfa9f5332e08655349 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Tue, 22 Mar 2022 21:05:05 +0800 Subject: [PATCH 13/33] =?UTF-8?q?=E4=BF=AE=E6=94=B9pixel=5Funshuffle?= =?UTF-8?q?=E6=A0=B8=E5=87=BD=E6=95=B0=E7=9A=84=E5=AE=9E=E7=8E=B0=E4=BD=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cpu/pixel_unshuffle_grad_kernel.cc | 26 ------------------- .../phi/kernels/cpu/pixel_unshuffle_kernel.cc | 26 ------------------- .../gpu/pixel_unshuffle_grad_kernel.cu | 26 ------------------- .../phi/kernels/gpu/pixel_unshuffle_kernel.cu | 26 ------------------- ..._impl.h => pixel_unshuffle_grad_kernel.cc} | 20 +++++++++++++- ...ernel_impl.h => pixel_unshuffle_kernel.cc} | 20 +++++++++++++- 6 files changed, 38 insertions(+), 106 deletions(-) delete mode 100644 paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc delete mode 100644 paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc delete mode 100644 paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu delete mode 100644 paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu rename paddle/phi/kernels/{impl/pixel_unshuffle_grad_kernel_impl.h => pixel_unshuffle_grad_kernel.cc} (75%) rename paddle/phi/kernels/{impl/pixel_unshuffle_kernel_impl.h => pixel_unshuffle_kernel.cc} (75%) diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc deleted file mode 100644 index ef61fca35957e..0000000000000 --- a/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" -#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -PD_REGISTER_KERNEL(pixel_unshuffle_grad, - CPU, - ALL_LAYOUT, - phi::PixelUnshuffleGradKernel, - float, - double) {} diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc deleted file mode 100644 index 9f4bc747f3209..0000000000000 --- a/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" -#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -PD_REGISTER_KERNEL(pixel_unshuffle, - CPU, - ALL_LAYOUT, - phi::PixelUnshuffleKernel, - float, - double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu deleted file mode 100644 index 9cbbc5072aa25..0000000000000 --- a/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" -#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" - -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -PD_REGISTER_KERNEL(pixel_unshuffle_grad, - GPU, - ALL_LAYOUT, - phi::PixelUnshuffleGradKernel, - float, - double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu deleted file mode 100644 index ca2e520ffde10..0000000000000 --- a/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" -#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" - -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -PD_REGISTER_KERNEL(pixel_unshuffle, - GPU, - ALL_LAYOUT, - phi::PixelUnshuffleKernel, - float, - double) {} diff --git a/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc similarity index 75% rename from paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h rename to paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc index a418f184b7826..1d6787c6d85d4 100644 --- a/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h +++ b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once #include #include #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" namespace phi { @@ -54,3 +56,19 @@ void PixelUnshuffleGradKernel(const Context& ctx, } } // namespace phi + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} +#endif diff --git a/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h b/paddle/phi/kernels/pixel_unshuffle_kernel.cc similarity index 75% rename from paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h rename to paddle/phi/kernels/pixel_unshuffle_kernel.cc index ac93e00a56576..d0c3c4f715f12 100644 --- a/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h +++ b/paddle/phi/kernels/pixel_unshuffle_kernel.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once #include #include #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" +#include "paddle/phi/core/kernel_registry.h" namespace phi { @@ -53,3 +55,19 @@ void PixelUnshuffleKernel(const Context& ctx, } } // namespace phi + +PD_REGISTER_KERNEL(pixel_unshuffle, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(pixel_unshuffle, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} +#endif From ef6f8ea7fd177417874b60f3eba206d296d3a559 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Wed, 23 Mar 2022 11:23:46 +0800 Subject: [PATCH 14/33] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kernels/pixel_unshuffle_grad_kernel.cc | 14 ++++++------ paddle/phi/kernels/pixel_unshuffle_kernel.cc | 22 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc index 1d6787c6d85d4..fe6db6620cb6e 100644 --- a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc +++ b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" #include #include -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/backends/all_context.h" -#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -65,10 +65,10 @@ PD_REGISTER_KERNEL(pixel_unshuffle_grad, double) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(pixel_unshuffle_grad, +PD_REGISTER_KERNEL(pixel_unshuffle_grad, GPU, - ALL_LAYOUT, - phi::PixelUnshuffleGradKernel, - float, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, double) {} #endif diff --git a/paddle/phi/kernels/pixel_unshuffle_kernel.cc b/paddle/phi/kernels/pixel_unshuffle_kernel.cc index d0c3c4f715f12..df52bec681de9 100644 --- a/paddle/phi/kernels/pixel_unshuffle_kernel.cc +++ b/paddle/phi/kernels/pixel_unshuffle_kernel.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" #include #include -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/backends/all_context.h" -#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -56,18 +56,18 @@ void PixelUnshuffleKernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL(pixel_unshuffle, +PD_REGISTER_KERNEL(pixel_unshuffle, CPU, - ALL_LAYOUT, - phi::PixelUnshuffleKernel, - float, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, double) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(pixel_unshuffle, +PD_REGISTER_KERNEL(pixel_unshuffle, GPU, - ALL_LAYOUT, - phi::PixelUnshuffleKernel, - float, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, double) {} #endif From 51bb6f8eb0549f6f27e5b6baa76ac28bebfadbdd Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Tue, 29 Mar 2022 22:16:55 +0800 Subject: [PATCH 15/33] =?UTF-8?q?=E5=AE=8C=E5=96=84=E5=AF=B9=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E7=9A=84=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/nn/functional/vision.py | 10 +++++++++- python/paddle/nn/layer/vision.py | 5 ++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index c9d946b17b638..5cc449587b508 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -368,8 +368,16 @@ def pixel_unshuffle(x, downscale_factor, data_format="NCHW", name=None): out = F.pixel_unshuffle(x, 3) # out.shape = [2, 9, 4, 4] """ + if len(x.shape) != 4: + raise ValueError( + "Input x should be 4D tensor, but received x with the shape of {}". + format(x.shape)) + if not isinstance(downscale_factor, int): - raise TypeError("downscale factor must be int type") + raise TypeError("Downscale factor must be int type") + + if downscale_factor <= 0: + raise ValueError("Downscale factor must be positive") if data_format not in ["NCHW", "NHWC"]: raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'." diff --git a/python/paddle/nn/layer/vision.py b/python/paddle/nn/layer/vision.py index a6f498a37048e..4ab288a0c56ec 100644 --- a/python/paddle/nn/layer/vision.py +++ b/python/paddle/nn/layer/vision.py @@ -130,7 +130,10 @@ def __init__(self, downscale_factor, data_format="NCHW", name=None): super(PixelUnshuffle, self).__init__() if not isinstance(downscale_factor, int): - raise TypeError("downscale factor must be int type") + raise TypeError("Downscale factor must be int type") + + if downscale_factor <= 0: + raise ValueError("Downscale factor must be positive") if data_format not in ["NCHW", "NHWC"]: raise ValueError("Data format should be 'NCHW' or 'NHWC'." From cf80ace19ab1a24759d7332d489b8e26acb4386d Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Tue, 29 Mar 2022 22:17:06 +0800 Subject: [PATCH 16/33] Update test_pixel_unshuffle.py --- .../tests/unittests/test_pixel_unshuffle.py | 45 +++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py index c6dab4d816b38..19ef33b2e2959 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import unittest import numpy as np @@ -23,8 +21,6 @@ import paddle.fluid.core as core import paddle.fluid as fluid -paddle.enable_static() - def pixel_unshuffle_np(x, down_factor, data_format="NCHW"): '''Numpy implementation of pixel unshuffle''' @@ -88,12 +84,16 @@ def init_data_format(self): def test_check_output(self): '''test_check_output''' + paddle.enable_static() self.check_output() + paddle.disable_static() def test_check_grad(self): '''test_check_grad''' + paddle.enable_static() self.check_grad(["X"], "Out") + paddle.disable_static() class TestChannelLast(TestPixelUnshuffleOp): @@ -232,12 +232,26 @@ class TestPixelUnshuffleError(unittest.TestCase): def test_error_functional(self): '''test_error_functional''' - def error_downscale_factor(): + def error_input(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([4, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 2) + + self.assertRaises(ValueError, error_input) + + def error_downscale_factor_1(): with paddle.fluid.dygraph.guard(): x = np.random.random([2, 1, 12, 12]).astype("float64") pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 3.33) - self.assertRaises(TypeError, error_downscale_factor) + self.assertRaises(TypeError, error_downscale_factor_1) + + def error_downscale_factor_2(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), -1) + + self.assertRaises(ValueError, error_downscale_factor_2) def error_data_format(): with paddle.fluid.dygraph.guard(): @@ -250,12 +264,27 @@ def error_data_format(): def test_error_layer(self): '''test_error_layer''' - def error_downscale_factor_layer(): + def error_input_layer(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([4, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(2) + ps(paddle.to_tensor(x)) + + self.assertRaises(ValueError, error_input_layer) + + def error_downscale_factor_layer_1(): with paddle.fluid.dygraph.guard(): x = np.random.random([2, 1, 12, 12]).astype("float64") ps = paddle.nn.PixelUnshuffle(3.33) - self.assertRaises(TypeError, error_downscale_factor_layer) + self.assertRaises(TypeError, error_downscale_factor_layer_1) + + def error_downscale_factor_layer_2(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(-1) + + self.assertRaises(ValueError, error_downscale_factor_layer_2) def error_data_format_layer(): with paddle.fluid.dygraph.guard(): From 4ca1ab4dc57621ebc58adb632ae59c0e770e53dd Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Tue, 29 Mar 2022 22:18:01 +0800 Subject: [PATCH 17/33] =?UTF-8?q?=E5=AE=8C=E5=96=84pixel=5Funshuffle?= =?UTF-8?q?=E7=9A=84=E8=BE=93=E5=85=A5=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/infermeta/unary.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 128a7d0182f32..7ed9aef1bcada 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -885,6 +885,15 @@ void PixelUnshuffleInferMeta(const MetaTensor& x, "Input should be a 4-D tensor of format [N, C, H, W] " "or [N, H, W, C], but got %u.", input_dims.size())); + PADDLE_ENFORCE_GE(downscale_factor, 1, + platform::errors::InvalidArgument( + "downscale_factor should be larger than 0.")) + PADDLE_ENFORCE_EQ(data_format == "NCHW" || data_format == "NHWC", + true, + phi::errors::InvalidArgument( + "data_format must be one of " + "NCHW and NHWC. But recevied data_format: %s", + data_format)); const bool channel_last = (data_format == "NHWC"); From ea07a1726c888c24840a732ef2842f3631e4f167 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Tue, 29 Mar 2022 22:18:12 +0800 Subject: [PATCH 18/33] Update pixel_unshuffle_op.cc --- paddle/fluid/operators/pixel_unshuffle_op.cc | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/pixel_unshuffle_op.cc b/paddle/fluid/operators/pixel_unshuffle_op.cc index 1f46e815ced1b..452a9235f8974 100644 --- a/paddle/fluid/operators/pixel_unshuffle_op.cc +++ b/paddle/fluid/operators/pixel_unshuffle_op.cc @@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -37,12 +36,7 @@ class PixelUnshuffleOpMaker : public framework::OpProtoAndCheckerMaker { "W/factor] or [N, H/factor, W/factor, C*factor^2]."); AddAttr("downscale_factor", "the factor to decrease spatial resolution by.") - .SetDefault(1) - .AddCustomChecker([](const int& downscale_factor) { - PADDLE_ENFORCE_GE(downscale_factor, 1, - platform::errors::InvalidArgument( - "downscale_factor should be larger than 0.")); - }); + .SetDefault(1); AddAttr( "data_format", "An optional string from: \"NHWC\", \"NCHW\". " @@ -70,6 +64,7 @@ class PixelUnshuffleGradOpMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; + protected: void Apply(GradOpPtr op) const override { op->SetType("pixel_unshuffle_grad"); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); From b0cc19aac7cf814126c18c12e8b3e8bd0c7d64a6 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Wed, 30 Mar 2022 09:52:16 +0800 Subject: [PATCH 19/33] Update unary.cc --- paddle/phi/infermeta/unary.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d30a52fc08c8b..46e8149c20ce8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1216,8 +1216,8 @@ void PixelUnshuffleInferMeta(const MetaTensor& x, "or [N, H, W, C], but got %u.", input_dims.size())); PADDLE_ENFORCE_GE(downscale_factor, 1, - platform::errors::InvalidArgument( - "downscale_factor should be larger than 0.")) + phi::errors::InvalidArgument( + "downscale_factor should be larger than 0.")); PADDLE_ENFORCE_EQ(data_format == "NCHW" || data_format == "NHWC", true, phi::errors::InvalidArgument( From e96d98a3c8eefff354ea317bec4a499095e92066 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 10:41:20 +0800 Subject: [PATCH 20/33] add pixel_unshuffle --- tools/static_mode_white_list.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 365047f7e8382..e9ab186cac918 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -384,6 +384,7 @@ 'test_partial_sum_op', 'test_pass_builder', 'test_pixel_shuffle', + 'test_pixel_unshuffle', 'test_polygon_box_transform', 'test_pool1d_api', 'test_pool2d_api', From fc1ff5389f0318976c88205f0689625897e82b14 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 10:41:27 +0800 Subject: [PATCH 21/33] Update test_pixel_unshuffle.py --- python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py index 19ef33b2e2959..768a9e307c91e 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -84,16 +84,12 @@ def init_data_format(self): def test_check_output(self): '''test_check_output''' - paddle.enable_static() self.check_output() - paddle.disable_static() def test_check_grad(self): '''test_check_grad''' - paddle.enable_static() self.check_grad(["X"], "Out") - paddle.disable_static() class TestChannelLast(TestPixelUnshuffleOp): From b3c084be32a5d23acdd446f37798e06769ca6e40 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 11:23:16 +0800 Subject: [PATCH 22/33] Update vision.py --- python/paddle/nn/functional/vision.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 5cc449587b508..26504c6784779 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -17,6 +17,7 @@ from ...fluid.layer_helper import LayerHelper from ...fluid.data_feeder import check_variable_and_dtype from ...fluid import dygraph_utils +from ...fluid.framework import _non_static_mode import numpy as np from paddle import _C_ops from ...device import is_compiled_with_rocm @@ -384,7 +385,7 @@ def pixel_unshuffle(x, downscale_factor, data_format="NCHW", name=None): "But recevie Attr(data_format): {} ".format( data_format)) - if in_dynamic_mode(): + if _non_static_mode(): return _C_ops.pixel_unshuffle(x, "downscale_factor", downscale_factor, "data_format", data_format) From bcb06dd0cd2cfede7ad9f62ca24bbd1b2cd98b69 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Fri, 1 Apr 2022 16:22:09 +0800 Subject: [PATCH 23/33] =?UTF-8?q?=E8=B0=83=E6=95=B4=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/infermeta/unary.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 46e8149c20ce8..a5b5b1068a67d 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1215,7 +1215,8 @@ void PixelUnshuffleInferMeta(const MetaTensor& x, "Input should be a 4-D tensor of format [N, C, H, W] " "or [N, H, W, C], but got %u.", input_dims.size())); - PADDLE_ENFORCE_GE(downscale_factor, 1, + PADDLE_ENFORCE_GE(downscale_factor, + 1, phi::errors::InvalidArgument( "downscale_factor should be larger than 0.")); PADDLE_ENFORCE_EQ(data_format == "NCHW" || data_format == "NHWC", From b3126ede16ad0a2c134e1d01ada68ea70cc04861 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 7 Apr 2022 12:57:53 +0800 Subject: [PATCH 24/33] Update vision.py --- python/paddle/nn/functional/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 26504c6784779..617fddd57c88f 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -17,11 +17,11 @@ from ...fluid.layer_helper import LayerHelper from ...fluid.data_feeder import check_variable_and_dtype from ...fluid import dygraph_utils -from ...fluid.framework import _non_static_mode import numpy as np from paddle import _C_ops from ...device import is_compiled_with_rocm from paddle import in_dynamic_mode +from paddle.framework import _non_static_mode __all__ = [] From 86535198d093e0b8ad79d658bc447a023cd057cf Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 14 Apr 2022 15:32:10 +0800 Subject: [PATCH 25/33] Delete extra spaces --- paddle/phi/infermeta/unary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index b67261600d193..38e5144b8c821 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -209,7 +209,7 @@ void PixelUnshuffleInferMeta(const MetaTensor& x, int downscale_factor, const std::string& data_format, MetaTensor* out); - + void PNormInferMeta(const MetaTensor& x, float porder, int axis, From c3fbce61008f326d2d3043acf289e7310ef93034 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Mon, 18 Apr 2022 14:34:18 +0800 Subject: [PATCH 26/33] Update pixel_unshuffle_sig.cc --- paddle/phi/ops/compat/pixel_unshuffle_sig.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc index e78b676dd629e..ce2939d13a042 100644 --- a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc +++ b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc @@ -25,9 +25,9 @@ KernelSignature PixelUnshuffleOpArgumentMapping( KernelSignature PixelUnshuffleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("pixel_unshuffle_grad", - {GradVarName("Out")}, + {"Out@GRAD"}, {"downscale_factor", "data_format"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi From 2310fc8ce4ce0c3591b832012796e37082544517 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 21 Apr 2022 09:51:51 +0800 Subject: [PATCH 27/33] Update vision.py --- python/paddle/nn/functional/vision.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 617fddd57c88f..58ced222aa040 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -351,17 +351,19 @@ def pixel_unshuffle(x, downscale_factor, data_format="NCHW", name=None): """ This API implements pixel unshuffle operation. See more details in :ref:`api_nn_vision_PixelUnshuffle` . + Parameters: - x(Tensor): 4-D tensor, the data type should be float32 or float64. - downscale_factor(int): factor to decrease spatial resolution. - data_format (str): The data format of the input and output data. An optional string from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in the order of: [batch_size, input_channels, input_height, input_width]. - name (str, optional): The default value is None. Normally there is no need for user to set this property. + x (Tensor): 4-D tensor, the data type should be float32 or float64. + downscale_factor (int): Factor to decrease spatial resolution. + data_format (str): The data format of the input and output data. An optional string of NCHW or NHWC. The default is NCHW. When it is NCHW, the data is stored in the order of [batch_size, input_channels, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + Returns: - Out(tensor): Reshaped tensor according to the new dimension. - Raises: - ValueError: If downscale_factor cannot divide both the height and width of input. + Out (Tensor): Reshaped tensor according to the new dimension. + Examples: .. code-block:: python + :name: pixel_unshuffle-example import paddle import paddle.nn.functional as F From 777c9f87cc6dbe69b9f5e15f0d0ab0a57ad819b7 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 21 Apr 2022 09:51:54 +0800 Subject: [PATCH 28/33] Update vision.py --- python/paddle/nn/layer/vision.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/python/paddle/nn/layer/vision.py b/python/paddle/nn/layer/vision.py index 4ab288a0c56ec..f8385c1eefa7c 100644 --- a/python/paddle/nn/layer/vision.py +++ b/python/paddle/nn/layer/vision.py @@ -91,31 +91,27 @@ def extra_repr(self): class PixelUnshuffle(Layer): """ - - PixelUnshuffle Layer - - This operator rearranges elements in a tensor of shape [N, C, H, W] - to a tensor of shape [N, C*downscale_factor**2, H/downscale_factor, W/downscale_factor], - or from shape [N, H, W, C] to [N, H/downscale_factor, W/downscale_factor, C*downscale_factor**2]. - This operation is the reversion of PixelShuffle operation. + This operator rearranges elements in a tensor of shape :math:`[N, C, H, W]` + to a tensor of shape :math:`[N, r^2C, H/r, W/r]`, or from shape + :math:`[N, H, W, C]` to :math:`[N, H/r, W/r, r^2C]`, where :math:`r` is the + downscale factor. This operation is the reversion of PixelShuffle operation. Please refer to the paper: `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network `_ . by Shi et. al (2016) for more details. Parameters: - - downscale_factor(int): factor to decrease spatial resolution. - data_format (str): The data format of the input and output data. An optional string from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in the order of: [batch_size, input_channels, input_height, input_width]. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + downscale_factor (int): Factor to decrease spatial resolution. + data_format (str): The data format of the input and output data. An optional string of NCHW or NHWC. The default is NCHW. When it is NCHW, the data is stored in the order of [batch_size, input_channels, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Shape: - - x: 4-D tensor with shape: (N, C, H, W) or (N, H, W, C). - - out: 4-D tensor with shape: (N, C*downscale_factor**2, H/downscale_factor, W/downscale_factor) or (N, H/downscale_factor, W/downscale_factor, C*downscale_factor**2). - + - **x**: 4-D tensor with shape of :math:`[N, C, H, W]` or :math:`[N, C, H, W]`. + - **out**: 4-D tensor with shape of :math:`[N, r^2C, H/r, W/r]` or :math:`[N, H/r, W/r, r^2C]`, where :math:`r` is :attr:`downscale_factor`. Examples: .. code-block:: python - + :name: PixelUnshuffle-example + import paddle import paddle.nn as nn From b937d0e0d7a1b2fb82bac8887d0db492fe69ab35 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 21 Apr 2022 09:52:19 +0800 Subject: [PATCH 29/33] add PixelUnshuffleGradInferMeta --- paddle/phi/infermeta/backward.cc | 30 ++++++++++++++++++++++++++++++ paddle/phi/infermeta/backward.h | 5 +++++ 2 files changed, 35 insertions(+) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 567f39a915c02..e9bc08a5e39a9 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -427,6 +427,36 @@ void NllLossGradInferMeta(const MetaTensor& x, } } +void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad, + int downscale_factor, + const std::string& data_format, + MetaTensor* x_grad) { + auto do_dims = out_grad.dims(); + PADDLE_ENFORCE_EQ(do_dims.size(), + 4, + phi::errors::InvalidArgument( + "Input should be a 4-D tensor of format [N, C, H, W] " + "or [N, H, W, C], but got %u.", + do_dims.size())); + + const bool channel_last = (data_format == "NHWC"); + + auto dx_dims = do_dims; + dx_dims[0] = do_dims[0]; + + if (!channel_last) { + dx_dims[1] = do_dims[1] / (downscale_factor * downscale_factor); + dx_dims[2] = do_dims[2] * downscale_factor; + dx_dims[3] = do_dims[3] * downscale_factor; + } else { + dx_dims[1] = do_dims[1] * downscale_factor; + dx_dims[2] = do_dims[2] * downscale_factor; + dx_dims[3] = do_dims[3] / (downscale_factor * downscale_factor); + } + x_grad->set_dims(dx_dims); + x_grad->set_dtype(out_grad.dtype()); +} + void PoolGradInferMeta(const MetaTensor& x, const MetaTensor& out, const MetaTensor& dout, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 6807438ebbb75..ca178edc52415 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -173,6 +173,11 @@ void NllLossGradInferMeta(const MetaTensor& input, MetaTensor* intput_grad, MetaConfig config = MetaConfig()); +void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad, + int downscale_factor, + const std::string& data_format, + MetaTensor* x_grad); + void PsroiPoolGradInferMeta(const MetaTensor& x, const MetaTensor& rois, paddle::optional rois_num, From 73fabcb0f2861d8ed3b6d038228829f02829c039 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 21 Apr 2022 09:52:49 +0800 Subject: [PATCH 30/33] remove PixelUnshuffleOpArgumentMapping --- paddle/phi/ops/compat/pixel_unshuffle_sig.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc index ce2939d13a042..817dc1a228877 100644 --- a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc +++ b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc @@ -16,12 +16,6 @@ namespace phi { -KernelSignature PixelUnshuffleOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "pixel_unshuffle", {"X"}, {"downscale_factor", "data_format"}, {"Out"}); -} - KernelSignature PixelUnshuffleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("pixel_unshuffle_grad", @@ -32,7 +26,5 @@ KernelSignature PixelUnshuffleGradOpArgumentMapping( } // namespace phi -PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle, - phi::PixelUnshuffleOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle_grad, phi::PixelUnshuffleGradOpArgumentMapping); From d5f68748c4a2828ed5d5f0fea83b5425dc6f6513 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 21 Apr 2022 09:53:06 +0800 Subject: [PATCH 31/33] Update pixel_unshuffle_op.cc --- paddle/fluid/operators/pixel_unshuffle_op.cc | 68 ++++++-------------- 1 file changed, 21 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/operators/pixel_unshuffle_op.cc b/paddle/fluid/operators/pixel_unshuffle_op.cc index 452a9235f8974..ae5c1db50296e 100644 --- a/paddle/fluid/operators/pixel_unshuffle_op.cc +++ b/paddle/fluid/operators/pixel_unshuffle_op.cc @@ -1,19 +1,23 @@ -/*Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/infermeta/backward.h" namespace paddle { namespace operators { @@ -76,42 +80,6 @@ class PixelUnshuffleGradOpMaker : public framework::SingleGradOpMaker { class PixelUnshuffleGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput(framework::GradVarName("Out")), true, - platform::errors::NotFound("Input(Out@Grad) should not be null")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput(framework::GradVarName("X")), true, - platform::errors::NotFound("Output(X@Grad) should not be null")); - - auto do_dims = ctx->GetInputDim(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ(do_dims.size(), 4, - platform::errors::InvalidArgument( - "Input should be a 4-D tensor of format [N, C, H, W] " - "or [N, H, W, C], but got %u.", - do_dims.size())); - - auto downscale_factor = ctx->Attrs().Get("downscale_factor"); - - const std::string data_format = - ctx->Attrs().Get("data_format"); - const bool channel_last = (data_format == "NHWC"); - - auto dx_dims = do_dims; - dx_dims[0] = do_dims[0]; - - if (!channel_last) { - dx_dims[1] = do_dims[1] / (downscale_factor * downscale_factor); - dx_dims[2] = do_dims[2] * downscale_factor; - dx_dims[3] = do_dims[3] * downscale_factor; - } else { - dx_dims[1] = do_dims[1] * downscale_factor; - dx_dims[2] = do_dims[2] * downscale_factor; - dx_dims[3] = do_dims[3] / (downscale_factor * downscale_factor); - } - ctx->SetOutputDim(framework::GradVarName("X"), dx_dims); - } }; } // namespace operators @@ -127,4 +95,10 @@ REGISTER_OPERATOR(pixel_unshuffle, ops::PixelUnshuffleOp, ops::PixelUnshuffleGradOpMaker, PixelUnshuffleInferShapeFunctor); -REGISTER_OPERATOR(pixel_unshuffle_grad, ops::PixelUnshuffleGradOp); +DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle_grad, + PixelUnshuffleGradInferShapeFunctor, + PD_INFER_META(phi::PixelUnshuffleGradInferMeta)); + +REGISTER_OPERATOR(pixel_unshuffle_grad, + ops::PixelUnshuffleGradOp, + PixelUnshuffleGradInferShapeFunctor); From e871227e7059b0cce18f587fdc04349da08edb6b Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 21 Apr 2022 09:54:39 +0800 Subject: [PATCH 32/33] =?UTF-8?q?=E8=B0=83=E6=95=B4pixel=5Funshuffle?= =?UTF-8?q?=E5=8F=8A=E5=85=B6=E6=A2=AF=E5=BA=A6=E7=9A=84=E6=A0=B8=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E7=9A=84=E5=AE=9E=E7=8E=B0=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cpu/pixel_unshuffle_grad_kernel.cc | 26 +++++++++++++++++ .../phi/kernels/cpu/pixel_unshuffle_kernel.cc | 26 +++++++++++++++++ .../gpu/pixel_unshuffle_grad_kernel.cu | 26 +++++++++++++++++ .../phi/kernels/gpu/pixel_unshuffle_kernel.cu | 26 +++++++++++++++++ .../pixel_unshuffle_grad_kernel_impl.h} | 28 ++++--------------- .../pixel_unshuffle_kernel_impl.h} | 28 ++++--------------- .../phi/kernels/pixel_unshuffle_grad_kernel.h | 2 +- paddle/phi/kernels/pixel_unshuffle_kernel.h | 2 +- 8 files changed, 118 insertions(+), 46 deletions(-) create mode 100644 paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc create mode 100644 paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu rename paddle/phi/kernels/{pixel_unshuffle_grad_kernel.cc => impl/pixel_unshuffle_grad_kernel_impl.h} (71%) rename paddle/phi/kernels/{pixel_unshuffle_kernel.cc => impl/pixel_unshuffle_kernel_impl.h} (71%) diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc new file mode 100644 index 0000000000000..ef61fca35957e --- /dev/null +++ b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" +#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc new file mode 100644 index 0000000000000..9f4bc747f3209 --- /dev/null +++ b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" +#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu new file mode 100644 index 0000000000000..9cbbc5072aa25 --- /dev/null +++ b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu new file mode 100644 index 0000000000000..ca2e520ffde10 --- /dev/null +++ b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} diff --git a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc b/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h similarity index 71% rename from paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc rename to paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h index fe6db6620cb6e..cb02539f2e890 100644 --- a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.cc +++ b/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h @@ -12,25 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" +#pragma once + #include #include -#include "paddle/phi/backends/all_context.h" + #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { template -void PixelUnshuffleGradKernel(const Context& ctx, +void PixelUnshuffleGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, int downscale_factor, const std::string& data_format, DenseTensor* x_grad) { auto* dout = &out_grad; auto* dx = x_grad; - ctx.template Alloc(dx); + dev_ctx.template Alloc(dx); int factor = downscale_factor; bool channel_last = (data_format == "NHWC"); auto do_dims = dout->dims(); @@ -51,24 +51,8 @@ void PixelUnshuffleGradKernel(const Context& ctx, o.Resize({do_dims[0], do_dims[1], factor, do_dims[2], factor, dx_dims[3]}); } phi::funcs::Transpose trans; - trans(ctx, t, &o, axis); + trans(dev_ctx, t, &o, axis); dx->Resize(dx_dims); } } // namespace phi - -PD_REGISTER_KERNEL(pixel_unshuffle_grad, - CPU, - ALL_LAYOUT, - phi::PixelUnshuffleGradKernel, - float, - double) {} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(pixel_unshuffle_grad, - GPU, - ALL_LAYOUT, - phi::PixelUnshuffleGradKernel, - float, - double) {} -#endif diff --git a/paddle/phi/kernels/pixel_unshuffle_kernel.cc b/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h similarity index 71% rename from paddle/phi/kernels/pixel_unshuffle_kernel.cc rename to paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h index df52bec681de9..0a140b270ba1b 100644 --- a/paddle/phi/kernels/pixel_unshuffle_kernel.cc +++ b/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h @@ -12,24 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" +#pragma once + #include #include -#include "paddle/phi/backends/all_context.h" + #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { template -void PixelUnshuffleKernel(const Context& ctx, +void PixelUnshuffleKernel(const Context& dev_ctx, const DenseTensor& x, int downscale_factor, const std::string& data_format, DenseTensor* out) { auto* in = &x; - ctx.template Alloc(out); + dev_ctx.template Alloc(out); int factor = downscale_factor; bool channel_last = (data_format == "NHWC"); auto in_dims = in->dims(); @@ -50,24 +50,8 @@ void PixelUnshuffleKernel(const Context& ctx, o.Resize({in_dims[0], o_dims[1], o_dims[2], in_dims[3], factor, factor}); } phi::funcs::Transpose trans; - trans(ctx, t, &o, axis); + trans(dev_ctx, t, &o, axis); out->Resize(o_dims); } } // namespace phi - -PD_REGISTER_KERNEL(pixel_unshuffle, - CPU, - ALL_LAYOUT, - phi::PixelUnshuffleKernel, - float, - double) {} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(pixel_unshuffle, - GPU, - ALL_LAYOUT, - phi::PixelUnshuffleKernel, - float, - double) {} -#endif diff --git a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h index f62f1f5b4c7b7..868633e56be50 100644 --- a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h +++ b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h @@ -20,7 +20,7 @@ namespace phi { template -void PixelUnshuffleGradKernel(const Context& ctx, +void PixelUnshuffleGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, int downscale_factor, const std::string& data_format, diff --git a/paddle/phi/kernels/pixel_unshuffle_kernel.h b/paddle/phi/kernels/pixel_unshuffle_kernel.h index a631223034e96..179e2b6639f9e 100644 --- a/paddle/phi/kernels/pixel_unshuffle_kernel.h +++ b/paddle/phi/kernels/pixel_unshuffle_kernel.h @@ -20,7 +20,7 @@ namespace phi { template -void PixelUnshuffleKernel(const Context& ctx, +void PixelUnshuffleKernel(const Context& dev_ctx, const DenseTensor& x, int downscale_factor, const std::string& data_format, From 948f32b658dab17ea01f825e1484bf7152774357 Mon Sep 17 00:00:00 2001 From: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com> Date: Thu, 21 Apr 2022 12:03:03 +0800 Subject: [PATCH 33/33] Update pixel_unshuffle_op.cc --- paddle/fluid/operators/pixel_unshuffle_op.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/pixel_unshuffle_op.cc b/paddle/fluid/operators/pixel_unshuffle_op.cc index ae5c1db50296e..8d16e02c04c83 100644 --- a/paddle/fluid/operators/pixel_unshuffle_op.cc +++ b/paddle/fluid/operators/pixel_unshuffle_op.cc @@ -16,8 +16,8 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -99,6 +99,5 @@ DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle_grad, PixelUnshuffleGradInferShapeFunctor, PD_INFER_META(phi::PixelUnshuffleGradInferMeta)); -REGISTER_OPERATOR(pixel_unshuffle_grad, - ops::PixelUnshuffleGradOp, +REGISTER_OPERATOR(pixel_unshuffle_grad, ops::PixelUnshuffleGradOp, PixelUnshuffleGradInferShapeFunctor);