Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.21】为 Paddle 新增 paddle.incubate.sparse.transpose 稀疏 API #45849

Merged
merged 40 commits into from Sep 30, 2022

Conversation

zrr1999
Copy link
Member

@zrr1999 zrr1999 commented Sep 7, 2022

@paddle-bot
Copy link

paddle-bot bot commented Sep 7, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

# Conflicts:
#	paddle/phi/api/yaml/sparse_bw_api.yaml
@zrr1999 zrr1999 marked this pull request as ready for review September 9, 2022 10:20
@zrr1999 zrr1999 force-pushed the sparse_transpose branch 4 times, most recently from 10161e9 to a0d092c Compare September 24, 2022 16:11
const DenseTensor& x_values = x.values();
// return a copy of x
if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) {
out->SetMember(x_crows, x_cols, x_values, x.dims());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用DenseTensor a=b; 触发左值构造,底层会共享Allocation内存,不要直接就完全用同一个DenseTensor,这个可能会引入风险

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,已经修改

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

"""
Changes the perm order of ``x`` without changing its data, requiring x to be a SparseCooTensor or SparseCsrTensor.

.. math::
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

中英文API文档需要保持一致哈,例如API描述和公式,最好不要一边有一边没有

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,这里已经在中文文档中修改

Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for fluid header using

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

args : (Tensor x, int[] perm)
output : Tensor(out)
infer_meta :
func : TransposeInferMeta
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the used " func : TransposeInferMeta" is actually the one for dense tensor. But the TransposeInferMeta for dense tensor is also applicable to sparse tensor. "func : TransposeGradInferMeta" has the same situation. So I submit a PR to delete maybe unused code in paddle\phi\infermeta\sparse\unary.h

#46844

def check_result(self, x_shape, dims, format):
with _test_eager_guard():
mask = paddle.randint(0, 2, x_shape).astype("float32")
origin_x = paddle.rand(x_shape, dtype='float32') * mask
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to make sure that all zero elements in "origin_x" is caused by multiplying by "mask",
or the following backward checks may fail:

       np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
                                   (dense_x.grad * mask).numpy(),
                                   rtol=1e-05)

Here is an example:
origin_x = [2, 0, 3, 4, 5] * [1, 1, 0, 1, 1]
so origin_x is [2, 0, 0, 4, 5]
sp_x.grad.to_dense().numpy() will be [1, 0, 0, 1, 1] ,
while (dense_x.grad * mask).numpy() will be [1, 1, 0, 1, 1],
then the backward check fails.

I have stated a PR to modify it.
#47017

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感谢修改

@zhwesky2010
Copy link
Contributor

@zrr1999 你好,还需要在 https://github.com/PaddlePaddle/docs 中增加transpose的中文文档

@zrr1999
Copy link
Member Author

zrr1999 commented Oct 17, 2022

@zrr1999 你好,还需要在 https://github.com/PaddlePaddle/docs 中增加transpose的中文文档

您好,这部分之前已经提交pr,,但是一直还没有merge,已经解决了最新的冲突

Copy link
Contributor

@zkh2016 zkh2016 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#53008 单测出现随机现象可能是使用错误的cudaMalloc和cudaMemcpy导致的,麻烦改下后再测下。 @zrr1999

hipMemcpy(
d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice);
#else
cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里使用的有点问题,如果用cudaMalloc应该手动调用cudaFree释放。 但是建议使用paddle::memory::Alloc

d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice);
#else
cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
cudaMemcpy(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要使用异步的方式,否则这里和kernel执行不在一个stream上。可以用 GpuMemcpyAsync

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#53008 单测出现随机现象可能是使用错误的cudaMalloc和cudaMemcpy导致的,麻烦改下后再测下。 @zrr1999

我添加了一个新的pr用来修复这个bug,看起来reshape算子好像也存在类似的问题 https://github.com/PaddlePaddle/Paddle/pull/53038/files

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,辛苦一块改下呗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants