
Please implement the batching rule for torch.matrix_exp. #115992

Open
LongZ-A opened this issue Dec 17, 2023 · 5 comments · May be fixed by #117934 or #125086
Labels
actionable good first issue module: functorch Pertaining to torch.func or pytorch/functorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@LongZ-A

LongZ-A commented Dec 17, 2023

🚀 The feature, motivation and pitch

Matrix exponentials are extremely expensive to compute, yet very important in many computing and ML problems. It would be great if such computationally expensive functions could support batched implementations. Many thanks in advance to anyone who can help with this issue!

Alternatives

No response

Additional context

Below is the warning message from the PyTorch kernel:

:3: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::matrix_exp. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at C:\cb\pytorch_1000000000000\work\aten\src\ATen\functorch\BatchedFallback.cpp:84.)

cc @zou3519 @Chillee @samdow @kshitij12345 @janeyx99

@zou3519 zou3519 added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: functorch Pertaining to torch.func or pytorch/functorch good first issue actionable labels Dec 20, 2023
@Xiao215

Xiao215 commented Dec 21, 2023

Hi, can I give this issue a try? :)

@LongZ-A
Author

LongZ-A commented Dec 22, 2023

Many thanks to everyone for the quick response! @zou3519 @Xiao215. Hopefully this will be successfully implemented soon 👍

@Xiao215

Xiao215 commented Dec 24, 2023

BTW, I believe this is a repro of the issue:

import torch

t = torch.randn(900, 3, 3)
# vmap has no batching rule for aten::matrix_exp, so it falls back to a
# per-sample loop and emits the UserWarning above.
o = torch.vmap(torch.matrix_exp)(t)
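(Editor's note, not from the thread: as a possible interim workaround, torch.linalg.matrix_exp already accepts inputs of shape (*, n, n), so the batch dimension can be passed directly instead of going through vmap. A minimal sketch, assuming the same (900, 3, 3) input as above:)

import torch

t = torch.randn(900, 3, 3)

# matrix_exp handles leading batch dimensions natively, so no vmap is needed.
o_direct = torch.linalg.matrix_exp(t)

# Sanity check against the vmap version, which currently hits the slow
# fallback path and emits the UserWarning quoted earlier.
o_vmap = torch.vmap(torch.matrix_exp)(t)
print(torch.allclose(o_direct, o_vmap, atol=1e-6))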

@amk16

amk16 commented Jan 3, 2024

Hey, could I attempt this if it's still open?

@zou3519
Contributor

zou3519 commented Jan 3, 2024

Issue still open, please feel free to attempt.

EDIT: nevermind, there already is a PR over at #116377

@Xiao215 Xiao215 linked a pull request Jan 20, 2024 that will close this issue
@samhu1 samhu1 linked a pull request Apr 27, 2024 that will close this issue
4 participants