Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon 3rd No.22 ] add paddle.incubate.sparse.reshape #46694

Conversation

OccupyMars2025
Copy link
Contributor

@OccupyMars2025 OccupyMars2025 commented Oct 1, 2022

PR types

New features

PR changes

OPs

Describe

add paddle.incubate.sparse.reshape
rfc:PaddlePaddle/community#246
中文文档:PaddlePaddle/docs#5360

csr复用coo算法,反向复用前向,coo算法先展平计算出location位置,再通过stride计算出新的index

本PR的前身是 #46333 , 因为之前的PR中我的comment太多,所以开了这个新的PR。本PR是在 #45849 基础上修改得到 ( sparse reshape 和 sparse transpose 有相似性,我基本上只修改了kernel中的计算逻辑 ),在这里感谢 https://github.com/zrr1999

…-version002' of https://github.com/OccupyMars2025/Paddle into hackathon-3rd-task22-add-paddle.incubate.sparse.reshape-version002
@OccupyMars2025
Copy link
Contributor Author

OccupyMars2025 commented Oct 14, 2022

@zhangting2020 飞桨专家你好,能否解释一下为什么PR-CI-OP-benchmark 会报错吗?而且这个CI报错内容一直都没变。因为我自己看不懂报错内容,所以单纯想了解一下报错原因,谢谢。

luotao1
luotao1 previously approved these changes Oct 14, 2022
@luotao1
Copy link
Contributor

luotao1 commented Oct 14, 2022

能否解释一下为什么PR-CI-OP-benchmark 会报错吗

因为缺少sparse_utils的op-benchmark脚本所以会报错。这个脚本是历史遗留问题,和新增API无关,所以下豁免了。

@OccupyMars2025
Copy link
Contributor Author

OccupyMars2025 commented Oct 14, 2022

能否解释一下为什么PR-CI-OP-benchmark 会报错吗

因为缺少sparse_utils的op-benchmark脚本所以会报错。这个脚本是历史遗留问题,和新增API无关,所以下豁免了。

明白了,谢谢。

The failure of PR-CI-OP-benchmark is attributed to the following highlighted warning. The failure has nothing to do with your code. You can just ignore this failure.

image



@dygraph_only
def reshape(x, shape, name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the design of reshape API needs to be added to rfc like other APIs.

func : ReshapeInferMeta
kernel :
func : reshape_coo{sparse_coo -> sparse_coo},
reshape_csr{sparse_csr -> sparse_csr}
Copy link
Contributor

@jeff41404 jeff41404 Oct 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the design of reshape_csr needs to be added to rfc, also ReshapeCsrKernel and ReshapeCsrGradKernel below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean I need to modify the RFC to make the content of the RFC agree with my actual implementation ? Ok, I will do it.

@OccupyMars2025
Copy link
Contributor Author

Now I have modified the RFC to make the content of RFC agree with my actual implementation:
PaddlePaddle/community#301

@OccupyMars2025
Copy link
Contributor Author

OccupyMars2025 commented Oct 14, 2022

PR-CI-Coverage 失败,显示我所有的 cpu kernel 都没有被 cover 到

image

我能想到的可能的解决方法就是:

使用 numpy array, 在 paddle.to_tensor中明确指明 place,代码示例如下:

image
image

在使用CUDAPlace 之前,增加paddle.device.is_compiled_with_cuda() 进行环境判断

但是我使用的 单测代码和 #45849 是一样的,也许不是单测代码的问题

@zhwesky2010
Copy link
Contributor

zhwesky2010 commented Oct 14, 2022

@OccupyMars2025 这个应该是coverage流水线没有跑到CPU版本的问题,但是其他的CPU流水线(比如py3、windows-openblas)会跑这个单测的CPU版本,所以不会有问题,我建议可以豁免。

@OccupyMars2025
Copy link
Contributor Author

OccupyMars2025 commented Oct 14, 2022

@zhouwei25 不好意思,没及时看到你的回复。我已经按照我评论中的想法修改了,即明确指定tensor 的place,这样就算coverage流水线跑的是CUDA版本的PaddlePaddle,一样会用到 cpu kernel。这是我的理解
#46694 (comment)

按照我说的方法,PR-CI-Coverage 报错的问题解决了。(2022/10/14 22:00)

zhwesky2010
zhwesky2010 previously approved these changes Oct 14, 2022
Xreki
Xreki previously approved these changes Oct 14, 2022
Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for op benchmark ci

@OccupyMars2025
Copy link
Contributor Author

OccupyMars2025 commented Oct 14, 2022

感谢所有飞桨专家的审核和解答,一些报错我怎么想都想不明白,但是飞桨专家一两句话就解释清楚了,让我豁然开朗。本PR是在 #45849 基础上修改得到 ( sparse reshape 和 sparse transpose 有相似性,我基本上只修改了kernel中的计算逻辑 ),在这里感谢 https://github.com/zrr1999

jeff41404
jeff41404 previously approved these changes Oct 17, 2022
Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Ligoml
Ligoml previously approved these changes Oct 17, 2022
Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

@OccupyMars2025
Copy link
Contributor Author

抱歉,因为我操作不当,修改了docstring后,commit重新触发了CI,可能要麻烦飞桨专家再approve一遍,以后我会注意。

@OccupyMars2025
Copy link
Contributor Author

@zhangting2020 @Xreki 麻烦再approve一下 PR-CI-OP-benchmark ,谢谢

@Ligoml Ligoml merged commit abb3813 into PaddlePaddle:develop Oct 17, 2022
zhwesky2010 pushed a commit to zhwesky2010/Paddle that referenced this pull request Oct 17, 2022
…le#46694)

* add sparse reshape

* change the dtype in all test cases to int64

* just one test case

* modify comments

* Update test_sparse_reshape_op.py

* chang the type of "shape"  from  vector<int64_t>  to  IntArray

* check whether sp_out.to_dense() is the cause  of error

* print sp_out

* Update reshape_kernel.cc

* use numpy to generate the equal paddle tensor

* just check dense_tensor.numpy()

* check cpu and cuda versions

* Update test_sparse_reshape_op.py

* supply all test cases for cpu forward coo kernel

* test forward coo cuda kernel

* change configuration of cuda kernel

* keep only one test case

* test coo cpu kernel (forward and backward)

* row major or column major ???

* test cuda coo forward kernel

* complete declaration and registration

* Update __init__.py

* rebuild

* retrigger CI

* add cudaMalloc and cudaMemcpy  in  ReshapeCooKernel  and change back to row major order in a cuda dense tensor

* midify minor error

* test only cpu coo forward kernel

* add all test cases for coo forward kernel  (both cpu and gpu)

* test all forward kernels (coo, csr; cpu, gpu)

* add all test cases for all kinds of kernels

* just retrigger CI

* Update sparse_ops.yaml

* Update sparse_ops.yaml

* Update sparse_ops.yaml

* resolve conflicts

* Update sparse_ops.yaml

* don't specify tensor place

* new shape has -1 or 0 in it

* Update unary_grad_kernel.h

* correct lvalue error

* code style

* Update sparse_backward.yaml

* Update sparse_ops.yaml

* Update unary_kernel.h

* Update unary.py

* Update sparse_backward.yaml

* Update unary.py

* code style

* code style

* code style

* Update unary.py

* specify tensor place explicitly

* do not use numpy array

* use numpy array in unit test again

* modify example code in docstring
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants