Batching rules matrix exp #116377
Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/116377. ✅ No failures as of commit 62f2528 with merge base f316c35. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Okay, so I checked and it's confusing where this should go. Matrix_Pow had the same issue, and they put both batch rules in the top-level BatchRulesDecomposition. Maybe we should start a specific area in the file for these aliases. Are there any other aliases in the list that you think are already implemented somewhere else? If so, we should open more PRs to fix them.
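The alias problem discussed above can be illustrated with a dependency-free sketch (plain Python, no PyTorch; none of these names are real PyTorch APIs): a batching rule registered under the canonical op name is invisible to an alias until the alias is pointed at the same rule, so vmap falls back to a slow per-example loop.

```python
# Hypothetical sketch: a rule table keyed by op name, with a slow fallback.
batching_rules = {}

def vmap_like(op_name, per_example_fn, batch):
    """Use the op's batch rule if one exists, else fall back to a slow loop."""
    rule = batching_rules.get(op_name)
    if rule is not None:
        return rule(batch)                      # fast: one call over the whole batch
    return [per_example_fn(x) for x in batch]   # slow fallback (the warning case)

# The rule exists only under the canonical name...
batching_rules["linalg.matrix_exp"] = lambda batch: [x + 1 for x in batch]
# ...so the alias keeps hitting the fallback until we point it at the same rule:
batching_rules["matrix_exp"] = batching_rules["linalg.matrix_exp"]

print(vmap_like("matrix_exp", lambda x: x + 1, [1, 2, 3]))  # [2, 3, 4]
```

The real fix is analogous: register the alias so it resolves to the same batching behavior as the canonical op, rather than the fallback path.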
torch.lu has the same problem, but it's deprecated.
I think
I agree, the direction is kinda confusing. So the standard right now is to use
Also, why does
They are the same: https://pytorch.org/docs/stable/generated/torch.matrix_exp.html
Hmm, looks like this wasn't the right type of fix, sadly. @Xiao215 To execute this test, run the following from the base repo dir:
Need to add it to BatchRulesDecomposition (instead of the approach in this PR) as per test error messages.
Please seek CI approval before scheduling CIFlow labels
uh ohhh, did I somehow add people as my reviewers? I'm so sorry if you got notified, I didn't mean to do it 😂
Too many people subscribed to this PR. Can you please open a new one?
Fixes #115992
Result
The issue is fixed: the warning no longer appears and a speed improvement is observed. This is my first open source contribution, so this might not be the best fix; let me know if there is anything I should change.
Cause of the issue
It seems that there is already a batching rule implemented for `torch.linalg.matrix_exp` in `BatchRulesLinearAlgebra.cpp`, but somehow `torch.matrix_exp` is not using it. These two calls should be exactly the same in functionality, since one is just an alias of the other, so I simply assigned the batching rule for `torch.linalg.matrix_exp` to `torch.matrix_exp`. The change does seem to solve the problem, and a corresponding test for the `matrix_exp` batching rule is added.

Validation of the change
On top of passing `test_vmap.py`, I compared the run time of the old `torch.matrix_exp` to the new one. I computed `vmap(torch.linalg.matrix_exp)(t)` and `vmap(torch.matrix_exp)(t)` for a tensor `t = torch.randn(9000, 3, 3)`, and the following speed improvement was observed.

[Timing table comparing `torch.linalg.matrix_exp` with `torch.matrix_exp` before and after the change; the numbers were lost in extraction.]
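The comparison can be reproduced with a minimal sketch (this is not the PR author's exact script; it assumes a torch build where `torch.func.vmap` is available, i.e. torch ≥ 2.0, and uses a smaller batch than the 9000 above for a quick check):

```python
import time
import torch
from torch.func import vmap

t = torch.randn(64, 3, 3)  # batched 3x3 matrices

# Time vmap over the canonical op and its alias; since they are aliases,
# the outputs should match, and with the batching rule both should be fast.
for name, fn in [("torch.linalg.matrix_exp", torch.linalg.matrix_exp),
                 ("torch.matrix_exp", torch.matrix_exp)]:
    start = time.perf_counter()
    out = vmap(fn)(t)
    print(f"{name}: {(time.perf_counter() - start) * 1e3:.2f} ms")

assert torch.allclose(vmap(torch.matrix_exp)(t),
                      vmap(torch.linalg.matrix_exp)(t))
```

Absolute timings vary by machine, so the interesting signal is the relative gap between the two calls, not the numbers themselves.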
`torch.matrix_exp` shows a very large speed improvement, and no matter how many batches there are, the run times are all fast. Note that no change was made to `torch.linalg.matrix_exp`; it is included only as a reference.

Additional question
After implementing the batching rule for `torch.matrix_exp`, I observed that it consistently performs about 5 times faster than `torch.linalg.matrix_exp`, despite them being functionally identical. Could there be an underlying reason for this discrepancy?

THANK YOU! ❤️
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @mcarilli @ptrblck @leslie-fang-intel @EikanWang @zou3519 @Chillee @samdow @kshitij12345 @janeyx99 @voznesenskym @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @LucasLLC