
Batching rules matrix exp #116377

Closed · wants to merge 11 commits

Conversation

@Xiao215 commented Dec 24, 2023

Fixes #115992

Result

The issue is fixed: the warning no longer appears and a speed improvement is observed. This is my first open-source contribution, so this might not be the best fix; let me know if there is anything I should change.

Cause of the issue

It seems that a batching rule is already implemented for torch.linalg.matrix_exp in BatchRulesLinearAlgebra.cpp, but somehow torch.matrix_exp is not using it. These two calls should be exactly the same in terms of functionality, since one is just an alias of the other, so I simply assigned the batching rule for torch.linalg.matrix_exp to torch.matrix_exp. The change does seem to solve the problem. A corresponding test for the matrix_exp batching rule is added.
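For reference, a minimal sketch of the reproduction from #115992 (assuming a recent PyTorch build with torch.func; before this change, the torch.matrix_exp call falls back to a slow per-sample loop and emits a performance warning about the missing batching rule):

```python
import torch
from torch.func import vmap

t = torch.randn(9000, 3, 3)

# Before the fix: no batching rule is found for aten::matrix_exp,
# so vmap falls back to a sequential loop and warns about it.
slow = vmap(torch.matrix_exp)(t)

# The linalg alias already has a batching rule and runs fast.
fast = vmap(torch.linalg.matrix_exp)(t)

# Both entry points must agree, since one is an alias of the other.
torch.testing.assert_close(slow, fast)
```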

Validation of the change

On top of passing test_vmap.py, I compared the run time of the old torch.matrix_exp to the new one. I computed vmap(torch.linalg.matrix_exp)(t) and vmap(torch.matrix_exp)(t) for a tensor t = torch.randn(9000, 3, 3), and observed the following speed improvement (a sketch of the timing harness appears after the table).

| Implementation | Time Before Change (seconds) | Time After Change (seconds) |
| --- | --- | --- |
| torch.linalg.matrix_exp | 0.6841 | 0.5284 |
| torch.matrix_exp | 84.8330 | 0.1462 |

torch.matrix_exp shows a huge speed improvement, and the run time stays fast regardless of the batch size. Note that no change was made to torch.linalg.matrix_exp; its row is included only as a reference.
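The exact script was not included in the PR; a minimal timing harness along these lines (an illustrative sketch using timeit, as mentioned below) reproduces the comparison:

```python
import timeit

import torch
from torch.func import vmap

t = torch.randn(9000, 3, 3)

for name, fn in [
    ("torch.linalg.matrix_exp", torch.linalg.matrix_exp),
    ("torch.matrix_exp", torch.matrix_exp),
]:
    batched = vmap(fn)
    batched(t)  # warm-up run so one-time dispatch overhead is excluded
    seconds = timeit.timeit(lambda: batched(t), number=1)
    print(f"{name}: {seconds:.4f} s")
```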

Additional question

After implementing the batching rule for torch.matrix_exp, I observed that it consistently performs about 5 times faster than torch.linalg.matrix_exp, despite them being functionally identical. Could there be an underlying reason for this discrepancy?

THANK YOU! ❤️

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @mcarilli @ptrblck @leslie-fang-intel @EikanWang @zou3519 @Chillee @samdow @kshitij12345 @janeyx99 @voznesenskym @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @LucasLLC


pytorch-bot bot commented Dec 24, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116377

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 62f2528 with merge base f316c35:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Dec 24, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@Skylion007 (Collaborator)

Okay, so I checked, and it's confusing where this should go. Matrix_Pow had the same issue, and both batch rules were put in the top-level BatchRulesDecomposition. Maybe we should start a dedicated area in the file for these aliases. Are there any other aliases in the list that you think are already implemented somewhere else? If so, we should open more PRs to fix them.

@Skylion007 (Collaborator)

torch.lu has the same problem, but it's deprecated.

@Skylion007 (Collaborator) commented Dec 24, 2023

I think torch.lu_solve has the same problem, but it's also deprecated

@Xiao215 (Author) commented Dec 24, 2023

I agree, the direction is kinda confusing. So the standard right now is to use torch.linalg.blahblah?

Also, why does torch.matrix_exp take less time than torch.linalg.matrix_exp, even though one is just an alias for the other? (I used timeit to estimate the run time.) Are they exactly the same? I thought they were exactly the same because of this line in BatchRulesLinearAlgebra.cpp.
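A quick sanity check (a sketch) that the two entry points really do compute the same thing:

```python
import torch

a = torch.randn(4, 4, dtype=torch.float64)

# torch.matrix_exp is documented as an alias of torch.linalg.matrix_exp,
# so the two calls should agree exactly on the same input.
torch.testing.assert_close(torch.matrix_exp(a), torch.linalg.matrix_exp(a))
```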

@Skylion007 added the topic: bug fixes, module: functorch, and release notes: functorch labels Dec 24, 2023
@Skylion007 (Collaborator)

Hmm, looks like this wasn't the right type of fix, sadly. @Xiao215
" FAILED [0.0011s] functorch/test_vmap_registrations.py::TestFunctorchDispatcher::test_register_a_batching_rule_for_composite_implicit_autograd_[aten::matrix_exp] - AssertionError: You've added a batching rule for a CompositeImplicitAutograd operator aten::matrix_exp. The correct way to add vmap support for it is to put it into BatchRulesDecomposition to reuse the CompositeImplicitAutograd decomposition

To execute this test, run the following from the base repo dir:
PYTORCH_TEST_WITH_DYNAMO=1 python test/functorch/test_vmap_registrations.py -k test_register_a_batching_rule_for_composite_implicit_autograd_[aten::matrix_exp]
"

@albanD added the triaged label Dec 28, 2023
@zou3519 (Contributor) left a review comment:

Need to add it to BatchRulesDecomposition (instead of the approach in this PR) as per test error messages.


pytorch-bot bot commented Jan 19, 2024

Please seek CI approval before scheduling CIFlow labels

@pytorch-bot bot added the ciflow/mps label and removed the ciflow/inductor label Jan 19, 2024

pytorch-bot bot commented Jan 19, 2024

Please seek CI approval before scheduling CIFlow labels

@pytorch-bot bot removed the ciflow/mps label Jan 19, 2024
@Xiao215 marked this pull request as a draft January 19, 2024 17:36
@Xiao215 (Author) commented Jan 19, 2024

uh ohhh, did I somehow add people as reviewers? I'm so sorry if you got notified; I didn't mean to do it 😂

@Xiao215 marked this pull request as ready for review January 19, 2024 17:58
@Xiao215 requested a review from zou3519 January 20, 2024 00:42
@lezcano (Collaborator) commented Jan 20, 2024

Too many people subscribed to this PR. Can you please open a new one?

Development

Successfully merging this pull request may close these issues:

- Please implement the batching rule for torch.matrix_exp.

6 participants