Inductor can not fuse cat with a pointwise #125075

Open
shunting314 opened this issue Apr 26, 2024 · 2 comments
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: inductor oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module


shunting314 commented Apr 26, 2024

🐛 Describe the bug

Consider this example:

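    def test_concat_and_downcast(self):
        M = 30522  # rows of x
        N = 768  # columns of x
        PAD = 6  # zero rows appended by the cat

        @torch.compile
        def f(x):
            # The cat reads x to build the padded output z.
            z = torch.cat([x, torch.zeros([PAD, N])], dim=0)
            # The downcast reads x again.
            return z, x.to(torch.float16)

        x = torch.randn(M, N)
        f(x)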

The cat and the downcast cannot be fused right now. In principle, we should be able to fuse them and save one full load of tensor 'x'.
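To put a rough number on that saving, here is a back-of-the-envelope estimate of the memory traffic for the shapes in the example. It is only a sketch under stated assumptions: `x` is float32 (the default for `torch.randn`), the zero padding is generated inside the kernel rather than loaded from memory, and each kernel reads its inputs and writes its outputs exactly once. The variable names are made up for illustration.

```python
# Shapes from the example above.
M, N, PAD = 30522, 768, 6
FP32_BYTES, FP16_BYTES = 4, 2  # assumption: x is float32, output y is float16

x_bytes = M * N * FP32_BYTES          # input x
z_bytes = (M + PAD) * N * FP32_BYTES  # cat output z
y_bytes = M * N * FP16_BYTES          # downcast output y

# Unfused: the cat kernel reads x and writes z; the cast kernel reads x and writes y.
unfused_bytes = (x_bytes + z_bytes) + (x_bytes + y_bytes)

# Fused (hypothetical): x is read once and both outputs are written.
fused_bytes = x_bytes + z_bytes + y_bytes

saved_bytes = unfused_bytes - fused_bytes  # exactly one load of x
saved_fraction = saved_bytes / unfused_bytes

print(f"saved {saved_bytes} bytes ({saved_fraction:.1%} of unfused traffic)")
```

Under these assumptions the saving is exactly the size of `x`, which is a bit under 30% of the total traffic of the two separate kernels.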

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @jansel @Chillee @eellison as a FYI.

Error logs

No response

Minified repro

No response

Versions

.


shunting314 changed the title from "Inductor can not fuse cat with a followed pointwise" to "Inductor can not fuse cat with a pointwise" Apr 26, 2024


Chillee commented Apr 27, 2024

We'd fuse it with the pointwise cat codegen?

@jbschlosser jbschlosser added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: inductor labels Apr 29, 2024
shunting314 commented

> We'd fuse it with the pointwise cat codegen?

That would require knowing, when we lower cat as a pointwise, which other pointwise node could be fused with it. I think that's a bit hard, since the pointwise node that can be fused may show up somewhere else.
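For readers unfamiliar with the idea of lowering cat as a pointwise, the following pure-Python sketch shows the general shape of it: every output element is computed independently by picking the source input from the output row index. This is only an illustration of the concept, not Inductor's actual lowering, and `cat_rows_pointwise` is a hypothetical helper name.

```python
def cat_rows_pointwise(inputs):
    """Concatenate 2-D nested lists along dim 0, one output element at a time.

    Illustrative only: assumes every input is non-empty and has the same
    number of columns.
    """
    n_cols = len(inputs[0][0])

    # offsets[k] is the output row at which inputs[k] starts.
    offsets = [0]
    for t in inputs:
        offsets.append(offsets[-1] + len(t))

    def load(row, col):
        # Pick the source input purely from the output index.
        for k, t in enumerate(inputs):
            if offsets[k] <= row < offsets[k + 1]:
                return t[row - offsets[k]][col]
        raise IndexError(row)

    return [[load(r, c) for c in range(n_cols)] for r in range(offsets[-1])]


x = [[1.0, 2.0], [3.0, 4.0]]
zeros = [[0.0, 0.0]]
z = cat_rows_pointwise([x, zeros])
```

Because the body of `load` only reads `x`, a second computation on `x` (such as the downcast) could in principle share that read, but the lowering of cat would need to know that such a consumer exists.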

An alternative is to do this in the scheduler, since the scheduler does a global search and can try to fuse the pointwise node with the cat node. The tricky part is that these two nodes have different numels.
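To make the numel mismatch concrete: in the example, the cat output has (M + PAD) * N elements while the downcast output has M * N. One conceivable way to reconcile them, sketched below in pure Python, is to iterate over the larger index space and mask the smaller output's stores. This is an assumption for illustration, not a description of what the scheduler does or would do; `fused_cat_and_downcast` and `to_fp16` are hypothetical helpers.

```python
import struct


def to_fp16(v):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", v))[0]


def fused_cat_and_downcast(x, pad_rows):
    """Hypothetical fused loop over the cat output's index space."""
    m, n = len(x), len(x[0])
    z = [[0.0] * n for _ in range(m + pad_rows)]  # numel (m + pad_rows) * n
    y = [[0.0] * n for _ in range(m)]             # numel m * n
    loads_of_x = 0
    for row in range(m + pad_rows):
        for col in range(n):
            if row < m:  # mask: only these indices exist in x and y
                v = x[row][col]
                loads_of_x += 1
                z[row][col] = v
                y[row][col] = to_fp16(v)
            # Rows >= m are padding: z stays zero and y has no element.
    return z, y, loads_of_x


z, y, loads_of_x = fused_cat_and_downcast([[1.0, 2.5], [0.1, -3.0]], pad_rows=1)
```

In this sketch each element of `x` is loaded once and feeds both outputs, at the cost of a few masked-out iterations for the padding rows.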
