Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[traced-graph][sparse] propagate sparsity metadata into traced graph #117907

Closed
wants to merge 4 commits into from

Conversation

aartbik
Copy link
Contributor

@aartbik aartbik commented Jan 20, 2024

Propagate sparsity metadata from sparse tensors of torch.sparse into the traced graph representation (with would be useful for a JIT backend that supports a "sparse compiler"). This is a first careful attempt, since the actual "meta" feature seem still incomplete for coo and completely lacking for csr/csc/bsr/bsc.

For background see forum postings (with examples):
https://discuss.pytorch.org/t/connecting-pytorch-sparse-tensors-with-mlir/195145
https://dev-discuss.pytorch.org/t/connecting-pytorch-sparse-tensors-with-mlir/1803

And feature request:
#117188

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @rohan-varma @aakhundov

Copy link

pytorch-bot bot commented Jan 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/117907

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit b313230 with merge base 1b29c16 (image):

NEW FAILURE - The following job has failed:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link

linux-foundation-easycla bot commented Jan 20, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

Copy link

pytorch-bot bot commented Jan 20, 2024

Please seek CI approval before scheduling CIFlow labels

Copy link

pytorch-bot bot commented Jan 24, 2024

Please seek CI approval before scheduling CIFlow labels

@aartbik
Copy link
Contributor Author

aartbik commented Jan 24, 2024

@ezyang

@ezyang
Copy link
Contributor

ezyang commented Jan 25, 2024

I am in principle in favor of make fake tensor work on sparse tensors, and adjusting our infra so that this works. I didn't carefully review the implementation details in this PR. What should next steps be?

@ezyang
Copy link
Contributor

ezyang commented Jan 25, 2024

cc @eellison @pearu

@aartbik
Copy link
Contributor Author

aartbik commented Jan 25, 2024

As for next steps, I can definitely work (and with my team) on bringing this to a higher quality, but we would really like to have a core PyTorch developer assist and look over our shoulders, since we are new to this repo and don't want to miss essential design principles when making first changes. Of course, if a core developer is willing and interested in helping out with actual code, that would be welcome too!

Copy link
Collaborator

@pearu pearu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a couple of comments. While the aim here is to support sparse tensor layouts, I think the corresponding changes ought address the layout-specific issues (such as the lack of strides, for instance) in a more sparse-layout-agnostic way whenever possible.

torch/_subclasses/fake_tensor.py Outdated Show resolved Hide resolved
torch/__init__.py Outdated Show resolved Hide resolved
torch/_dynamo/utils.py Outdated Show resolved Hide resolved
torch/_subclasses/fake_tensor.py Outdated Show resolved Hide resolved
torch/fx/graph.py Outdated Show resolved Hide resolved
Copy link

pytorch-bot bot commented Feb 2, 2024

Please seek CI approval before scheduling CIFlow labels

Copy link

pytorch-bot bot commented Feb 2, 2024

Please seek CI approval before scheduling CIFlow labels

@pearu
Copy link
Collaborator

pearu commented Feb 2, 2024

A related issue is #99404 that suggests looking at the problem from the perspective of creating FakeTensors using the fake tensor converter class.

@ezyang
Copy link
Contributor

ezyang commented Feb 3, 2024

@pearu are you willing to shepherd this PR to completion?

@aartbik
Copy link
Contributor Author

aartbik commented Feb 9, 2024

Thanks all!

Note that I am not stalling this PR ;-) but I am wrapping-up the required torch-mlir work so that we are pretty certain of all the metadata that should propagated into the traced graph. Also, we plan to do some "sanity" performance runs, just to see if this new path really has potential.

So please stay tuned, I will get back to this PR really soon....

@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) release notes: quantization release notes category labels May 22, 2024
@pearu
Copy link
Collaborator

pearu commented May 23, 2024

@pytorchbot merge -f "with two unrelated CI failures"

@pytorch pytorch deleted a comment from pytorch-bot bot May 23, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@aartbik
Copy link
Contributor Author

aartbik commented May 24, 2024

Thanks again, @pearu , for you truly exemplary mentorship to me getting this in!
I am enthusiastically starting on the next tasks for this feature!

@aartbik aartbik deleted the bik branch May 24, 2024 15:52
@pearu
Copy link
Collaborator

pearu commented May 24, 2024

@aartbik , congratulations to getting this PR merged! And thanks for your patience, landing a PR in 5 months requires plenty of it :) Looking forward to your next contributions!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: cpu CPU specific problem (e.g., perf, algorithm) module: dynamo module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue open source release notes: fx release notes category release notes: quantization release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants