Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 2】24、为 Paddle 新增 nn.ChannelShuffle 组网 API #40743

Merged
merged 34 commits into from Apr 25, 2022

Conversation

BrilliantYuKaimin
Copy link
Contributor

@BrilliantYuKaimin BrilliantYuKaimin commented Mar 20, 2022

PR types

New features

PR changes

APIs

Describe

解决了Issue: #40309
增加了nn.ChannelShuffle和nn.functional.channel_shuffle,其行为是把其行为是将形为[N,C,H,W]重塑成[N,g,C/g,H,W]的形状,再转置成[N,C/g,g,H,W]的形状,最后重塑回[N,C,H,W]的形状。
设计文档: PaddlePaddle/community#40
中文文档: PaddlePaddle/docs#4501

@paddle-bot-old paddle-bot-old bot added contributor External developers status: proposed labels Mar 20, 2022
@paddle-bot-old
Copy link

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot-old
Copy link

PR格式检查通过,你的PR将接受Paddle专家以及开源社区的review,请及时关注PR动态。
The format inspection passed. Your PR will be reviewed by experts of Paddle and developers from the open-source community. Stay tuned.

@paddle-bot-old
Copy link

PR格式检查通过,你的PR将接受Paddle专家以及开源社区的review,请及时关注PR动态。
The format inspection passed. Your PR will be reviewed by experts of Paddle and developers from the open-source community. Stay tuned.

Copy link
Contributor

@shiyutang shiyutang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API的撰写已经很完善,部分细节需要修改

Comment on lines 43 to 44
PADDLE_ENFORCE_GE(groups, 1, platform::errors::InvalidArgument(
"groups should be larger than 0."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在infermeta中判断

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

Comment on lines 71 to 83
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@Grad) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("Output(X@Grad) should not be null"));

auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, "
"H, W] or [N, H, W, C], but got %u.",
do_dims.size()));

auto dx_dims = do_dims;
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放入infermeta

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看其他算子也都是只把前向的形状推断放在infermeta中,而反向的形状推断放在xxx_op.cc中。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以放到infermeta里,参考下这个PR的review
#40728

See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有看到对这个库函数的使用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删去

"The number of groups to divide channels in[%u] "
"should divide the number of channel[%u]",
groups,
input_dims[3]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

报错信息可以增加空格方便报错阅读

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是指把in[%u]channel[%u]改成in [%u]channel [%u]吗?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成。

int groups,
const std::string& data_format,
MetaTensor* out) {
auto input_dims = x.dims();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加对groups,dataformat的检查

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可删除

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

import paddle.fluid.core as core
import paddle.fluid as fluid

paddle.enable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要一开始指定

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成


self.assertRaises(ValueError, error_data_format)

def test_error_layer(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加group 不是正数、tensor不是4d的报错信息

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

@dingjiaweiww
Copy link
Contributor

请先通过CI噢~

@@ -69,10 +65,14 @@ def init_data_format(self):
self.format = "NCHW"

def test_check_output(self):
paddle.enable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要单独在这里增加静态图的约束

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删去,原先这么写是因为不设置为静态图就会报错。后来发现tools/static_mode_white_list.py中也有test_pixel_shuffle。经验证,在tools/static_mode_white_list.py删去test_pixel_shufflepixel_shuffle的单测也会报错。于是我在tools/static_mode_white_list.py中也增加了test_channel_shuffle


def test_check_grad(self):
paddle.enable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要单独在这里增加静态图的约束

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

回复同上。

@shiyutang
Copy link
Contributor

  1. 需要增加设计文档链接,如图所示:
    06b8a2904adf5ec3c229031619c25f57

  2. 需要替换 in_dygraph_mode 如下图所示:
    4ac43d3d67cdea50a08659a006f90a19

@BrilliantYuKaimin
Copy link
Contributor Author

  • 需要增加设计文档链接,如图所示:
    06b8a2904adf5ec3c229031619c25f57
  • 需要替换 in_dygraph_mode 如下图所示:
    4ac43d3d67cdea50a08659a006f90a19

完成

@shiyutang
Copy link
Contributor

根据最新说明,_non_static_mode的调用路径需要更改为paddle.framework._non_static_mode,不再和fluid耦合。

@BrilliantYuKaimin
Copy link
Contributor Author

CI显示报错:
image

已经解决

XiaoguangHu01
XiaoguangHu01 previously approved these changes Apr 18, 2022
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG API

"""
This API implements channel shuffle operation.
See more details in :ref:`api_nn_vision_ChannelShuffle` .
Parameters:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

  • 需要在每个关键字段前后增加空行,以保证官网正确解析文档~参考文档
  • 中英文文档应保持一致

name (str, optional): The default value is None. Normally there is no need for user to set this property.
Returns:
Out(tensor): Rearranged tensor keeping the original tensor shape.
Raises:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要增加raises,这个可以删除~


Shape:
- x: 4-D tensor with shape: (N, C, H, W) or (N, H, W, C).
- out: 4-D tensor with shape same as x.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

  • 「-」--> 「:」
  • 参数需要加粗,最好不要在参数描述中出现其他「:」,容易引起误解
  • 另外,其他文档中有shape这个关键字段吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PixelShuffle有shape这个关键字段。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

源码里用的是:,好像是生成文档时自动变成-的。

Ligoml
Ligoml previously approved these changes Apr 20, 2022
Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

jeff41404
jeff41404 previously approved these changes Apr 20, 2022
Copy link
Contributor

@shiyutang shiyutang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类似问题可参照修改:#40728 (review)

@@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考:#40728 (comment) 进行同理修改。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

@@ -0,0 +1,73 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考:#40728 (comment) 进行同理修改。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

@BrilliantYuKaimin BrilliantYuKaimin dismissed stale reviews from jeff41404 and Ligoml via e29fb30 April 21, 2022 04:02
@BrilliantYuKaimin
Copy link
Contributor Author

类似问题可参照修改:#40728 (review)

完成

@BrilliantYuKaimin
Copy link
Contributor Author

BrilliantYuKaimin commented Apr 21, 2022

研发老师可以帮忙看一下 PR-CI-Kunlun-KP-Build 不通过的原因吗?

WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照这个修改下吧~
image

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@shiyutang shiyutang requested a review from Ligoml April 24, 2022 07:38
Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

@jeff41404 jeff41404 merged commit bbaaf21 into PaddlePaddle:develop Apr 25, 2022
QingshuChen pushed a commit that referenced this pull request May 6, 2022
* bind elementwise_mod_op_xpu *test=kunlun

* add more supported dtypes and UTs *test=kunlun

* fix datatype error

* add op to in xpu1_op_list

* Update Mac cmake version >=3.15 (#41456)

* Update Mac cmake version >=3.15

* notest;read test1

notest;read test2

notest;read test3

* fix inference link error

* fix inference link error

* fix windows link error

* fix cmake_policy

* fix build big size

* Add paddle::variant and replace paddle::any (#42139)

* add variant and replace any

* split attribute

* disable unittest failed in eager CI in temporary (#42101)

* test=py3-eager

* test=py3-eager

* test=py3-eager

* combine graph_table and feature_table in graph_engine (#42134)

* extract sub-graph

* graph-engine merging

* fix

* fix

* fix heter-ps config

* test performance

* test performance

* test performance

* test

* test

* update bfs

* change cmake

* test

* test gpu speed

* gpu_graph_engine optimization

* add dsm sample method

* add graph_neighbor_sample_v2

* Add graph_neighbor_sample_v2

* fix for loop

* add cpu sample interface

* fix kernel judgement

* add ssd layer to graph_engine

* fix allocation

* fix syntax error

* fix syntax error

* fix pscore class

* fix

* change index settings

* recover test

* recover test

* fix spelling

* recover

* fix

* move cudamemcpy after cuda stream sync

* fix linking problem

* remove comment

* add cpu test

* test

* add cpu test

* change comment

* combine feature table and graph table

* test

* test

* pybind

* test

* test

* test

* test

* pybind

* pybind

* fix cmake

* pybind

* fix

* fix

* add pybind

* add pybind

Co-authored-by: DesmonDay <908660116@qq.com>

* [CustomDevice] add eager mode support (#42034)

* fix FlattenContiguousRangeOpConverter out dim error (#42087)

* fix FlattenContiguousRangeOpConverter out dim error

* update code

* fix python3.10 compile bug on windows (#42140)

* Optimize dygraph GetExpectedKernelType perf (#42154)

* opt dygraph scheduling

* revert part impl

* fix incorrect usages of std::move and other compile errors (#41045)

* fix bug of std::move and others

* fix an compile error in debug mode

* fix wrong copy assignment operator

Signed-off-by: tiancaishaonvjituizi <452565578@qq.com>

* reformat

Signed-off-by: tiancaishaonvjituizi <452565578@qq.com>

* reformat

Signed-off-by: tiancaishaonvjituizi <452565578@qq.com>

* fix ArrayRef constructor following llvm

* fix format

* fix conflict with master

* fix variant compile error (#42203)

* [Eager] Support numpy.ndarry in CastNumpy2Scalar (#42136)

* [Eager] Remove redundancy code, fix fp16 case (#42169)

* [Eager] Support div(scalar) in eager mode (#42148)

* [Eager] Support div scalar in eager mode

* Updated and remove debug logs

* Remove list, use 'or' directly

* Remove useless statement

* fix recompute (#42128)

* fix recompute

* modify return

* add LICENSE in wheel dist-info package (#42187)

* replace any by variant in infermeta (#42181)

* 【PaddlePaddle Hackathon 2】24、为 Paddle 新增 nn.ChannelShuffle 组网 API (#40743)

* Add infermeta for ChannelShuffle

* Create channel_shuffle_grad_kernel.h

* Create channel_shuffle_kernel.h

* Create channel_shuffle_sig.cc

* Create channel_shuffle_op.cc

ChannelShuffle算子的描述

* Create channel_shuffle_kernel_impl.h

ChannelShuffle核函数的实现

* Create channel_shuffle_grad_kernel_impl.h

ChannelShuffle反向核函数的实现

* Add kernel register of channel shuffle and grad

注册ChannelShuffle及其反向的核函数

* add nn.functional.channel_shuffle

* add nn.ChannelShuffle

* Create test_channel_shuffle.py

* Update example of ChannelShuffle in vision.py

* Update test_channel_shuffle.py

* 修改channel_shuffle核函数的实现位置

* 修正代码格式

* 删除多余空格

* 完善channel_shuffle的错误检查

* Update unary.cc

* Update channel_shuffle_op.cc

* Update test_channel_shuffle.py

* Update unary.cc

* add channel_shuffle

* Update test_channel_shuffle.py

* Update vision.py

* 调整代码格式

* Update channel_shuffle_sig.cc

* 更新ChannelShuffle的文档

* 更新channel_shuffle的文档

* remove ChannelShuffleOpArgumentMapping

* add ChannelShuffleGradInferMeta

* Update channel_shuffle_op.cc

* 调整channel_shuffle及其梯度的核函数的位置

* Do not reset default stream for StreamSafeCUDAAllocator (#42149)

* remove redundant computation in Categorical.probs (#42114)

* Downloading data for test_analyzer_vit_ocr (#42041)

* Change server URL

* update config

* add test to parallel UT rule

* add checksum to ensure files are downloaded

* change downloading target

* reuse existing variable

* change target directory

* fix en docs of some Apis (gradients, scope_guard, cuda_places, name_scope, device_guard, load_program_state, scale, ParamAttr and WeightNormParamAttr) (#41604)

* Update scope_guard; test=document_fix

* gradients; test=document_fix

* gradients; test=document_fix

* name_scope; test=document_fix

* cpu_places; test=document_fix

* WeightNormParamAttr; test=document_fix

* cuda_places; test=document_fix

* load_program_state; test=document_fix

* device_guard; test=document_fix

* device_guard; test=document_fix

* ParamAttr; test=document_fix

* scale; test=document_fix

* scale; test=document_fix

* update code example;test=document_fix

Co-authored-by: Chen Long <1300851984@qq.com>

* fix datatype error

add op to in xpu1_op_list

*test=kunlun

* fix elementwise_mod op path error  *test=kunlun

* fix elementwise_mod UT error  *test=kunlun

* fix datatype error

add op to in xpu1_op_list

*test=kunlun

add op to in xpu1_op_list

fix elementwise_mod op path error  *test=kunlun

fix elementwise_mod UT error  *test=kunlun

Co-authored-by: tianshuo78520a <707759223@qq.com>
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: seemingwang <seemingwang@users.noreply.github.com>
Co-authored-by: DesmonDay <908660116@qq.com>
Co-authored-by: ronnywang <524019753@qq.com>
Co-authored-by: baoachun <962571062@qq.com>
Co-authored-by: Zhou Wei <1183042833@qq.com>
Co-authored-by: tiancaishaonvjituizi <452565578@qq.com>
Co-authored-by: Weilong Wu <veyron_wu@163.com>
Co-authored-by: Roc <30228238+sljlp@users.noreply.github.com>
Co-authored-by: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com>
Co-authored-by: Ruibiao Chen <chenruibiao@baidu.com>
Co-authored-by: Feiyu Chan <chenfeiyu@baidu.com>
Co-authored-by: Sławomir Siwek <slawomir.siwek@intel.com>
Co-authored-by: Yilingyelu <103369238+Yilingyelu@users.noreply.github.com>
Co-authored-by: Chen Long <1300851984@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet