diff --git a/paddle/fluid/operators/pixel_unshuffle_op.cc b/paddle/fluid/operators/pixel_unshuffle_op.cc new file mode 100644 index 0000000000000..8d16e02c04c83 --- /dev/null +++ b/paddle/fluid/operators/pixel_unshuffle_op.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +class PixelUnshuffleOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +class PixelUnshuffleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), " + "the input feature data of PixelUnshuffleOp, the layout is " + "[N, C, H, W] or [N, H, W, C]."); + AddOutput("Out", + "(Tensor, default Tensor), the output of " + "PixelUnshuffleOp. The layout is [N, C*factor^2, H/factor, " + "W/factor] or [N, H/factor, W/factor, C*factor^2]."); + AddAttr("downscale_factor", + "the factor to decrease spatial resolution by.") + .SetDefault(1); + AddAttr( + "data_format", + "An optional string from: \"NHWC\", \"NCHW\". " + "Defaults to \"NHWC\", Specify the data format of the input data.") + .SetDefault("NCHW"); + + AddComment(R"DOC( + Pixel Unshuffle operator + This operator rearranges elements in a tensor of shape :math:`(*, C, H, W)` + to a tensor of shape :math:`(*, C\times r^2, H / r, W / r)`. + + This operation is the reversion of PixelShuffle operation. + + Please refer to the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient + Sub-Pixel Convolutional Neural Network `_ + by Shi et. al (2016) for more details. + + )DOC"); + } +}; + +template +class PixelUnshuffleGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("pixel_unshuffle_grad"); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +class PixelUnshuffleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle, PixelUnshuffleInferShapeFunctor, + PD_INFER_META(phi::PixelUnshuffleInferMeta)); + +REGISTER_OPERATOR(pixel_unshuffle, ops::PixelUnshuffleOp, + ops::PixelUnshuffleOpMaker, + ops::PixelUnshuffleGradOpMaker, + ops::PixelUnshuffleGradOpMaker, + PixelUnshuffleInferShapeFunctor); + +DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle_grad, + PixelUnshuffleGradInferShapeFunctor, + PD_INFER_META(phi::PixelUnshuffleGradInferMeta)); + +REGISTER_OPERATOR(pixel_unshuffle_grad, ops::PixelUnshuffleGradOp, + PixelUnshuffleGradInferShapeFunctor); diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 4a4585e00eed6..602942abf4d34 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -443,6 +443,36 @@ void NllLossGradInferMeta(const MetaTensor& x, } } +void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad, + int downscale_factor, + const std::string& data_format, + MetaTensor* x_grad) { + auto do_dims = out_grad.dims(); + PADDLE_ENFORCE_EQ(do_dims.size(), + 4, + phi::errors::InvalidArgument( + "Input should be a 4-D tensor of format [N, C, H, W] " + "or [N, H, W, C], but got %u.", + do_dims.size())); + + const bool channel_last = (data_format == "NHWC"); + + auto dx_dims = do_dims; + dx_dims[0] = do_dims[0]; + + if (!channel_last) { + dx_dims[1] = do_dims[1] / (downscale_factor * downscale_factor); + dx_dims[2] = do_dims[2] * downscale_factor; + dx_dims[3] = do_dims[3] * downscale_factor; + } else { + dx_dims[1] = do_dims[1] * downscale_factor; + dx_dims[2] = do_dims[2] * downscale_factor; + dx_dims[3] = do_dims[3] / (downscale_factor * downscale_factor); + } + x_grad->set_dims(dx_dims); + x_grad->set_dtype(out_grad.dtype()); +} + void PoolGradInferMeta(const MetaTensor& x, const MetaTensor& out, const MetaTensor& dout, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 9db958778d597..c35b58d0f56e4 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -178,6 +178,11 @@ void NllLossGradInferMeta(const MetaTensor& input, MetaTensor* intput_grad, MetaConfig config = MetaConfig()); +void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad, + int downscale_factor, + const std::string& data_format, + MetaTensor* x_grad); + void PsroiPoolGradInferMeta(const MetaTensor& x, const MetaTensor& rois, paddle::optional rois_num, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 400c56db3efc2..cff14308c7fe9 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1416,6 +1416,66 @@ void PixelShuffleGradInferMeta(const MetaTensor& out_grad, x_grad->set_dtype(out_grad.dtype()); } +void PixelUnshuffleInferMeta(const MetaTensor& x, + int downscale_factor, + const std::string& data_format, + MetaTensor* out) { + auto input_dims = x.dims(); + PADDLE_ENFORCE_EQ(input_dims.size(), + 4, + phi::errors::InvalidArgument( + "Input should be a 4-D tensor of format [N, C, H, W] " + "or [N, H, W, C], but got %u.", + input_dims.size())); + PADDLE_ENFORCE_GE(downscale_factor, + 1, + phi::errors::InvalidArgument( + "downscale_factor should be larger than 0.")); + PADDLE_ENFORCE_EQ(data_format == "NCHW" || data_format == "NHWC", + true, + phi::errors::InvalidArgument( + "data_format must be one of " + "NCHW and NHWC. But recevied data_format: %s", + data_format)); + + const bool channel_last = (data_format == "NHWC"); + + if (!channel_last) { + PADDLE_ENFORCE_EQ( + (input_dims[2] % downscale_factor) == 0 && + (input_dims[3] % downscale_factor) == 0, + true, + phi::errors::InvalidArgument("Downscale factor[%u] should divide both " + "height[%u] and width[%u]", + downscale_factor, + input_dims[2], + input_dims[3])); + } else { + PADDLE_ENFORCE_EQ( + (input_dims[1] % downscale_factor) == 0 && + (input_dims[2] % downscale_factor) == 0, + true, + phi::errors::InvalidArgument("Downscale factor[%u] should divide both " + "height[%u] and width[%u]", + downscale_factor, + input_dims[1], + input_dims[2])); + } + auto output_dims = input_dims; + output_dims[0] = input_dims[0]; + if (!channel_last) { + output_dims[1] = input_dims[1] * (downscale_factor * downscale_factor); + output_dims[2] = input_dims[2] / downscale_factor; + output_dims[3] = input_dims[3] / downscale_factor; + } else { + output_dims[1] = input_dims[1] / downscale_factor; + output_dims[2] = input_dims[2] / downscale_factor; + output_dims[3] = input_dims[3] * (downscale_factor * downscale_factor); + } + out->set_dtype(x.dtype()); + out->set_dims(output_dims); +} + void PNormInferMeta(const MetaTensor& x, float porder, int axis, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index c67eb2068d8bf..eef750b852f06 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -204,6 +204,11 @@ void PixelShuffleGradInferMeta(const MetaTensor& out_grad, const std::string& data_format, MetaTensor* x_grad); +void PixelUnshuffleInferMeta(const MetaTensor& x, + int downscale_factor, + const std::string& data_format, + MetaTensor* out); + void PNormInferMeta(const MetaTensor& x, float porder, int axis, diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc new file mode 100644 index 0000000000000..ef61fca35957e --- /dev/null +++ b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" +#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc new file mode 100644 index 0000000000000..9f4bc747f3209 --- /dev/null +++ b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" +#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu new file mode 100644 index 0000000000000..9cbbc5072aa25 --- /dev/null +++ b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu new file mode 100644 index 0000000000000..ca2e520ffde10 --- /dev/null +++ b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} diff --git a/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h b/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h new file mode 100644 index 0000000000000..cb02539f2e890 --- /dev/null +++ b/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h @@ -0,0 +1,58 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PixelUnshuffleGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + int downscale_factor, + const std::string& data_format, + DenseTensor* x_grad) { + auto* dout = &out_grad; + auto* dx = x_grad; + dev_ctx.template Alloc(dx); + int factor = downscale_factor; + bool channel_last = (data_format == "NHWC"); + auto do_dims = dout->dims(); + auto dx_dims = dx->dims(); + + DenseTensor t(*dout); + if (!channel_last) { + t.Resize({do_dims[0], dx_dims[1], factor, factor, do_dims[2], do_dims[3]}); + } else { + t.Resize({do_dims[0], do_dims[1], do_dims[2], dx_dims[3], factor, factor}); + } + std::vector axis = {0, 1, 4, 2, 5, 3}; + + DenseTensor o(*dx); + if (!channel_last) { + o.Resize({do_dims[0], dx_dims[1], do_dims[2], factor, do_dims[3], factor}); + } else { + o.Resize({do_dims[0], do_dims[1], factor, do_dims[2], factor, dx_dims[3]}); + } + phi::funcs::Transpose trans; + trans(dev_ctx, t, &o, axis); + dx->Resize(dx_dims); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h b/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h new file mode 100644 index 0000000000000..0a140b270ba1b --- /dev/null +++ b/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PixelUnshuffleKernel(const Context& dev_ctx, + const DenseTensor& x, + int downscale_factor, + const std::string& data_format, + DenseTensor* out) { + auto* in = &x; + dev_ctx.template Alloc(out); + int factor = downscale_factor; + bool channel_last = (data_format == "NHWC"); + auto in_dims = in->dims(); + auto o_dims = out->dims(); + + DenseTensor t(*in); + if (!channel_last) { + t.Resize({in_dims[0], in_dims[1], o_dims[2], factor, o_dims[3], factor}); + } else { + t.Resize({in_dims[0], o_dims[1], factor, o_dims[2], factor, in_dims[3]}); + } + std::vector axis = {0, 1, 3, 5, 2, 4}; + + DenseTensor o(*out); + if (!channel_last) { + o.Resize({in_dims[0], in_dims[1], factor, factor, o_dims[2], o_dims[3]}); + } else { + o.Resize({in_dims[0], o_dims[1], o_dims[2], in_dims[3], factor, factor}); + } + phi::funcs::Transpose trans; + trans(dev_ctx, t, &o, axis); + out->Resize(o_dims); +} + +} // namespace phi diff --git a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h new file mode 100644 index 0000000000000..868633e56be50 --- /dev/null +++ b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void PixelUnshuffleGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + int downscale_factor, + const std::string& data_format, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/pixel_unshuffle_kernel.h b/paddle/phi/kernels/pixel_unshuffle_kernel.h new file mode 100644 index 0000000000000..179e2b6639f9e --- /dev/null +++ b/paddle/phi/kernels/pixel_unshuffle_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void PixelUnshuffleKernel(const Context& dev_ctx, + const DenseTensor& x, + int downscale_factor, + const std::string& data_format, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc new file mode 100644 index 0000000000000..817dc1a228877 --- /dev/null +++ b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature PixelUnshuffleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("pixel_unshuffle_grad", + {"Out@GRAD"}, + {"downscale_factor", "data_format"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle_grad, + phi::PixelUnshuffleGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py new file mode 100644 index 0000000000000..768a9e307c91e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -0,0 +1,294 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest +import paddle +import paddle.nn.functional as F +import paddle.fluid.core as core +import paddle.fluid as fluid + + +def pixel_unshuffle_np(x, down_factor, data_format="NCHW"): + '''Numpy implementation of pixel unshuffle''' + + if data_format == "NCHW": + n, c, h, w = x.shape + new_shape = (n, c, h // down_factor, down_factor, w // down_factor, + down_factor) + npresult = np.reshape(x, new_shape) + npresult = npresult.transpose(0, 1, 3, 5, 2, 4) + oshape = [ + n, c * down_factor * down_factor, h // down_factor, w // down_factor + ] + npresult = np.reshape(npresult, oshape) + return npresult + else: + n, h, w, c = x.shape + new_shape = (n, h // down_factor, down_factor, w // down_factor, + down_factor, c) + npresult = np.reshape(x, new_shape) + npresult = npresult.transpose(0, 1, 3, 5, 2, 4) + oshape = [ + n, h // down_factor, w // down_factor, c * down_factor * down_factor + ] + npresult = np.reshape(npresult, oshape) + return npresult + + +class TestPixelUnshuffleOp(OpTest): + '''TestPixelUnshuffleOp''' + + def setUp(self): + '''setUp''' + + self.op_type = "pixel_unshuffle" + self.init_data_format() + n, c, h, w = 2, 1, 12, 12 + + if self.format == "NCHW": + shape = [n, c, h, w] + if self.format == "NHWC": + shape = [n, h, w, c] + + down_factor = 3 + + x = np.random.random(shape).astype("float64") + npresult = pixel_unshuffle_np(x, down_factor, self.format) + + self.inputs = {"X": x} + self.outputs = {"Out": npresult} + self.attrs = { + "downscale_factor": down_factor, + "data_format": self.format + } + + def init_data_format(self): + '''init_data_format''' + + self.format = "NCHW" + + def test_check_output(self): + '''test_check_output''' + + self.check_output() + + def test_check_grad(self): + '''test_check_grad''' + + self.check_grad(["X"], "Out") + + +class TestChannelLast(TestPixelUnshuffleOp): + '''TestChannelLast''' + + def init_data_format(self): + '''init_data_format''' + + self.format = "NHWC" + + +class TestPixelUnshuffleAPI(unittest.TestCase): + '''TestPixelUnshuffleAPI''' + + def setUp(self): + '''setUp''' + + self.x_1_np = np.random.random([2, 1, 12, 12]).astype("float64") + self.x_2_np = np.random.random([2, 12, 12, 1]).astype("float64") + self.out_1_np = pixel_unshuffle_np(self.x_1_np, 3) + self.out_2_np = pixel_unshuffle_np(self.x_2_np, 3, "NHWC") + + def test_static_graph_functional(self): + '''test_static_graph_functional''' + + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=[2, 1, 12, 12], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 12, 12, 1], dtype="float64") + out_1 = F.pixel_unshuffle(x_1, 3) + out_2 = F.pixel_unshuffle(x_2, 3, "NHWC") + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_1_np}, + fetch_list=out_1, + use_prune=True) + + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_2_np}, + fetch_list=out_2, + use_prune=True) + + assert np.allclose(res_1, self.out_1_np) + assert np.allclose(res_2, self.out_2_np) + + # same test between layer and functional in this op. + def test_static_graph_layer(self): + '''test_static_graph_layer''' + + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=[2, 1, 12, 12], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 12, 12, 1], dtype="float64") + # init instance + ps_1 = paddle.nn.PixelUnshuffle(3) + ps_2 = paddle.nn.PixelUnshuffle(3, "NHWC") + out_1 = ps_1(x_1) + out_2 = ps_2(x_2) + out_1_np = pixel_unshuffle_np(self.x_1_np, 3) + out_2_np = pixel_unshuffle_np(self.x_2_np, 3, "NHWC") + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_1_np}, + fetch_list=out_1, + use_prune=True) + + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_2_np}, + fetch_list=out_2, + use_prune=True) + + assert np.allclose(res_1, out_1_np) + assert np.allclose(res_2, out_2_np) + + def run_dygraph(self, down_factor, data_format): + '''run_dygraph''' + + n, c, h, w = 2, 1, 12, 12 + + if data_format == "NCHW": + shape = [n, c, h, w] + if data_format == "NHWC": + shape = [n, h, w, c] + + x = np.random.random(shape).astype("float64") + + npresult = pixel_unshuffle_np(x, down_factor, data_format) + + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.disable_static(place=place) + + pixel_unshuffle = paddle.nn.PixelUnshuffle( + down_factor, data_format=data_format) + result = pixel_unshuffle(paddle.to_tensor(x)) + + self.assertTrue(np.allclose(result.numpy(), npresult)) + + result_functional = F.pixel_unshuffle( + paddle.to_tensor(x), 3, data_format) + self.assertTrue(np.allclose(result_functional.numpy(), npresult)) + + pixel_unshuffle_str = 'downscale_factor={}'.format(down_factor) + if data_format != 'NCHW': + pixel_unshuffle_str += ', data_format={}'.format(data_format) + self.assertEqual(pixel_unshuffle.extra_repr(), pixel_unshuffle_str) + + def test_dygraph1(self): + '''test_dygraph1''' + + self.run_dygraph(3, "NCHW") + + def test_dygraph2(self): + '''test_dygraph2''' + + self.run_dygraph(3, "NHWC") + + +class TestPixelUnshuffleError(unittest.TestCase): + '''TestPixelUnshuffleError''' + + def test_error_functional(self): + '''test_error_functional''' + + def error_input(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([4, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 2) + + self.assertRaises(ValueError, error_input) + + def error_downscale_factor_1(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 3.33) + + self.assertRaises(TypeError, error_downscale_factor_1) + + def error_downscale_factor_2(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), -1) + + self.assertRaises(ValueError, error_downscale_factor_2) + + def error_data_format(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle( + paddle.to_tensor(x), 3, "WOW") + + self.assertRaises(ValueError, error_data_format) + + def test_error_layer(self): + '''test_error_layer''' + + def error_input_layer(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([4, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(2) + ps(paddle.to_tensor(x)) + + self.assertRaises(ValueError, error_input_layer) + + def error_downscale_factor_layer_1(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(3.33) + + self.assertRaises(TypeError, error_downscale_factor_layer_1) + + def error_downscale_factor_layer_2(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(-1) + + self.assertRaises(ValueError, error_downscale_factor_layer_2) + + def error_data_format_layer(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(3, "MEOW") + + self.assertRaises(ValueError, error_data_format_layer) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 70e3518a1af46..bceee4b964a33 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -138,6 +138,7 @@ from .layer.distance import PairwiseDistance # noqa: F401 from .layer.vision import PixelShuffle # noqa: F401 +from .layer.vision import PixelUnshuffle # noqa: F401 from .layer.vision import ChannelShuffle # noqa: F401 from .layer.container import LayerDict # noqa: F401 @@ -301,6 +302,7 @@ def weight_norm(*args): 'Swish', 'Mish', 'PixelShuffle', + 'PixelUnshuffle', 'ChannelShuffle', 'ELU', 'ReLU6', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 58251c2890430..68213d831c550 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -114,6 +114,7 @@ from .vision import affine_grid # noqa: F401 from .vision import grid_sample # noqa: F401 from .vision import pixel_shuffle # noqa: F401 +from .vision import pixel_unshuffle # noqa: F401 from .vision import channel_shuffle # noqa: F401 from .input import one_hot # noqa: F401 from .input import embedding # noqa: F401 @@ -214,6 +215,7 @@ 'grid_sample', 'local_response_norm', 'pixel_shuffle', + 'pixel_unshuffle', 'channel_shuffle', 'embedding', 'gather_tree', diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 07e68d71dc1f1..9a9c2ee4cf7d1 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -347,6 +347,64 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW", name=None): return out +def pixel_unshuffle(x, downscale_factor, data_format="NCHW", name=None): + """ + This API implements pixel unshuffle operation. + See more details in :ref:`api_nn_vision_PixelUnshuffle` . + + Parameters: + x (Tensor): 4-D tensor, the data type should be float32 or float64. + downscale_factor (int): Factor to decrease spatial resolution. + data_format (str): The data format of the input and output data. An optional string of NCHW or NHWC. The default is NCHW. When it is NCHW, the data is stored in the order of [batch_size, input_channels, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Out (Tensor): Reshaped tensor according to the new dimension. + + Examples: + .. code-block:: python + :name: pixel_unshuffle-example + + import paddle + import paddle.nn.functional as F + x = paddle.randn([2, 1, 12, 12]) + out = F.pixel_unshuffle(x, 3) + # out.shape = [2, 9, 4, 4] + """ + if len(x.shape) != 4: + raise ValueError( + "Input x should be 4D tensor, but received x with the shape of {}". + format(x.shape)) + + if not isinstance(downscale_factor, int): + raise TypeError("Downscale factor must be int type") + + if downscale_factor <= 0: + raise ValueError("Downscale factor must be positive") + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'." + "But recevie Attr(data_format): {} ".format( + data_format)) + + if _non_static_mode(): + return _C_ops.pixel_unshuffle(x, "downscale_factor", downscale_factor, + "data_format", data_format) + + helper = LayerHelper("pixel_unshuffle", **locals()) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'pixel_unshuffle') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type="pixel_unshuffle", + inputs={"X": x}, + outputs={"Out": out}, + attrs={ + "downscale_factor": downscale_factor, + "data_format": data_format + }) + return out + + def channel_shuffle(x, groups, data_format="NCHW", name=None): """ This API implements channel shuffle operation. diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 339feef8f32e6..31364f0281c8a 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -88,6 +88,7 @@ from .norm import LocalResponseNorm # noqa: F401 from .vision import PixelShuffle # noqa: F401 +from .vision import PixelUnshuffle # noqa: F401 from .vision import ChannelShuffle # noqa: F401 from .distance import PairwiseDistance # noqa: F401 from .container import LayerDict # noqa: F401 diff --git a/python/paddle/nn/layer/vision.py b/python/paddle/nn/layer/vision.py index e775d4fcf6dfb..6d5c112d75703 100644 --- a/python/paddle/nn/layer/vision.py +++ b/python/paddle/nn/layer/vision.py @@ -89,6 +89,69 @@ def extra_repr(self): return main_str +class PixelUnshuffle(Layer): + """ + This operator rearranges elements in a tensor of shape :math:`[N, C, H, W]` + to a tensor of shape :math:`[N, r^2C, H/r, W/r]`, or from shape + :math:`[N, H, W, C]` to :math:`[N, H/r, W/r, r^2C]`, where :math:`r` is the + downscale factor. This operation is the reversion of PixelShuffle operation. + Please refer to the paper: `Real-Time Single Image and Video Super-Resolution + Using an Efficient Sub-Pixel Convolutional Neural Network `_ . + by Shi et. al (2016) for more details. + + Parameters: + downscale_factor (int): Factor to decrease spatial resolution. + data_format (str): The data format of the input and output data. An optional string of NCHW or NHWC. The default is NCHW. When it is NCHW, the data is stored in the order of [batch_size, input_channels, input_height, input_width]. + name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - **x**: 4-D tensor with shape of :math:`[N, C, H, W]` or :math:`[N, C, H, W]`. + - **out**: 4-D tensor with shape of :math:`[N, r^2C, H/r, W/r]` or :math:`[N, H/r, W/r, r^2C]`, where :math:`r` is :attr:`downscale_factor`. + + Examples: + .. code-block:: python + :name: PixelUnshuffle-example + + import paddle + import paddle.nn as nn + + x = paddle.randn([2, 1, 12, 12]) + pixel_unshuffle = nn.PixelUnshuffle(3) + out = pixel_unshuffle(x) + # out.shape = [2, 9, 4, 4] + + """ + + def __init__(self, downscale_factor, data_format="NCHW", name=None): + super(PixelUnshuffle, self).__init__() + + if not isinstance(downscale_factor, int): + raise TypeError("Downscale factor must be int type") + + if downscale_factor <= 0: + raise ValueError("Downscale factor must be positive") + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError("Data format should be 'NCHW' or 'NHWC'." + "But recevie data format: {}".format(data_format)) + + self._downscale_factor = downscale_factor + self._data_format = data_format + self._name = name + + def forward(self, x): + return functional.pixel_unshuffle(x, self._downscale_factor, + self._data_format, self._name) + + def extra_repr(self): + main_str = 'downscale_factor={}'.format(self._downscale_factor) + if self._data_format != 'NCHW': + main_str += ', data_format={}'.format(self._data_format) + if self._name is not None: + main_str += ', name={}'.format(self._name) + return main_str + + class ChannelShuffle(Layer): """ This operator divides channels in a tensor of shape [N, C, H, W] or [N, H, W, C] into g groups, diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 5dcff12c2c87e..aaa667595f94c 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -387,6 +387,7 @@ 'test_partial_sum_op', 'test_pass_builder', 'test_pixel_shuffle', + 'test_pixel_unshuffle', 'test_polygon_box_transform', 'test_pool1d_api', 'test_pool2d_api',