Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
【PaddlePaddle Hackathon 2】29、为 Paddle 新增 PixelUnshuffle 组网 API (#40728)
* 增加PixelUnshuffle的形状推断 * 增加PixelUnshuffle的算子注册 * 增加PixelUnshuffle及其梯度的核函数 * 增加PixelUnshuffle算子的描述 * 增加PixelUnshuffle算子的签名 * 在Python层面增加PixelUnshuffle * 增加PixelUnshuffle的单测 * Update test_pixel_unshuffle.py * test=document_fix * Update test_pixel_unshuffle.py 增加对extra_repr的测试 * 修正代码格式 * Update test_pixel_unshuffle.py 修正对extra_repr的测试 * 修改pixel_unshuffle核函数的实现位置 * 修正代码格式 * 完善对输入的检查 * Update test_pixel_unshuffle.py * 完善pixel_unshuffle的输入检查 * Update pixel_unshuffle_op.cc * Update unary.cc * add pixel_unshuffle * Update test_pixel_unshuffle.py * Update vision.py * 调整代码格式 * Update vision.py * Delete extra spaces * Update pixel_unshuffle_sig.cc * Update vision.py * Update vision.py * add PixelUnshuffleGradInferMeta * remove PixelUnshuffleOpArgumentMapping * Update pixel_unshuffle_op.cc * 调整pixel_unshuffle及其梯度的核函数的实现位置 * Update pixel_unshuffle_op.cc
- Loading branch information
1 parent
3cdc7a0
commit 5be9b82
Showing
21 changed files
with
931 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/backward.h" | ||
#include "paddle/phi/infermeta/unary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class PixelUnshuffleOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
}; | ||
|
||
class PixelUnshuffleOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("X", | ||
"(Tensor, default Tensor<float>), " | ||
"the input feature data of PixelUnshuffleOp, the layout is " | ||
"[N, C, H, W] or [N, H, W, C]."); | ||
AddOutput("Out", | ||
"(Tensor, default Tensor<float>), the output of " | ||
"PixelUnshuffleOp. The layout is [N, C*factor^2, H/factor, " | ||
"W/factor] or [N, H/factor, W/factor, C*factor^2]."); | ||
AddAttr<int>("downscale_factor", | ||
"the factor to decrease spatial resolution by.") | ||
.SetDefault(1); | ||
AddAttr<std::string>( | ||
"data_format", | ||
"An optional string from: \"NHWC\", \"NCHW\". " | ||
"Defaults to \"NHWC\", Specify the data format of the input data.") | ||
.SetDefault("NCHW"); | ||
|
||
AddComment(R"DOC( | ||
Pixel Unshuffle operator | ||
This operator rearranges elements in a tensor of shape :math:`(*, C, H, W)` | ||
to a tensor of shape :math:`(*, C\times r^2, H / r, W / r)`. | ||
This operation is the reversion of PixelShuffle operation. | ||
Please refer to the paper: | ||
`Real-Time Single Image and Video Super-Resolution Using an Efficient | ||
Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_ | ||
by Shi et. al (2016) for more details. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class PixelUnshuffleGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> op) const override { | ||
op->SetType("pixel_unshuffle_grad"); | ||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
op->SetAttrMap(this->Attrs()); | ||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
} | ||
}; | ||
|
||
class PixelUnshuffleGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle, PixelUnshuffleInferShapeFunctor, | ||
PD_INFER_META(phi::PixelUnshuffleInferMeta)); | ||
|
||
REGISTER_OPERATOR(pixel_unshuffle, ops::PixelUnshuffleOp, | ||
ops::PixelUnshuffleOpMaker, | ||
ops::PixelUnshuffleGradOpMaker<paddle::framework::OpDesc>, | ||
ops::PixelUnshuffleGradOpMaker<paddle::imperative::OpBase>, | ||
PixelUnshuffleInferShapeFunctor); | ||
|
||
DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle_grad, | ||
PixelUnshuffleGradInferShapeFunctor, | ||
PD_INFER_META(phi::PixelUnshuffleGradInferMeta)); | ||
|
||
REGISTER_OPERATOR(pixel_unshuffle_grad, ops::PixelUnshuffleGradOp, | ||
PixelUnshuffleGradInferShapeFunctor); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" | ||
#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(pixel_unshuffle_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::PixelUnshuffleGradKernel, | ||
float, | ||
double) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" | ||
#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(pixel_unshuffle, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::PixelUnshuffleKernel, | ||
float, | ||
double) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" | ||
#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(pixel_unshuffle_grad, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::PixelUnshuffleGradKernel, | ||
float, | ||
double) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" | ||
#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(pixel_unshuffle, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::PixelUnshuffleKernel, | ||
float, | ||
double) {} |
58 changes: 58 additions & 0 deletions
58
paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "paddle/phi/core/dense_tensor.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void PixelUnshuffleGradKernel(const Context& dev_ctx, | ||
const DenseTensor& out_grad, | ||
int downscale_factor, | ||
const std::string& data_format, | ||
DenseTensor* x_grad) { | ||
auto* dout = &out_grad; | ||
auto* dx = x_grad; | ||
dev_ctx.template Alloc<T>(dx); | ||
int factor = downscale_factor; | ||
bool channel_last = (data_format == "NHWC"); | ||
auto do_dims = dout->dims(); | ||
auto dx_dims = dx->dims(); | ||
|
||
DenseTensor t(*dout); | ||
if (!channel_last) { | ||
t.Resize({do_dims[0], dx_dims[1], factor, factor, do_dims[2], do_dims[3]}); | ||
} else { | ||
t.Resize({do_dims[0], do_dims[1], do_dims[2], dx_dims[3], factor, factor}); | ||
} | ||
std::vector<int> axis = {0, 1, 4, 2, 5, 3}; | ||
|
||
DenseTensor o(*dx); | ||
if (!channel_last) { | ||
o.Resize({do_dims[0], dx_dims[1], do_dims[2], factor, do_dims[3], factor}); | ||
} else { | ||
o.Resize({do_dims[0], do_dims[1], factor, do_dims[2], factor, dx_dims[3]}); | ||
} | ||
phi::funcs::Transpose<Context, T, 6> trans; | ||
trans(dev_ctx, t, &o, axis); | ||
dx->Resize(dx_dims); | ||
} | ||
|
||
} // namespace phi |
Oops, something went wrong.