
Allow linalg.lstsq to use svd to compute the result for rank deficient matrices. #125110

Open · wants to merge 13 commits into base: main
Conversation

ZelboK
Contributor

@ZelboK ZelboK commented Apr 28, 2024

Fixes #117122

This PR adds logic so that, for rank-deficient matrices, lstsq can fall back to an SVD-based backend in batched mode. A big thank you to @tvercaut for the well-written issue and the suggestion on how to approach the problem.

Summary:

  1. At the time of writing I haven't touched the non-batched path yet. I'm hoping to get some feedback before proceeding.
  2. The exact way we fall back to SVD when we run into rank-deficient matrices deserves extra eyes.

Please keep in mind this is my second PR to PyTorch, and I've never really used PyTorch before. I'm learning independently by digging deep into the internals, so I may make some obvious mistakes. Please forgive!
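For context, the gelss/gelsd drivers solve the rank-deficient problem via an SVD with an rcond cutoff. Below is a minimal NumPy sketch of that computation for the 2-D, non-batched case; `lstsq_svd` and its rcond default are illustrative, not the ATen implementation:

```python
import numpy as np

def lstsq_svd(A, B, rcond=1e-10):
    """Minimum-norm least-squares solution via SVD (gelss-style sketch)."""
    U, S, Vh = np.linalg.svd(A, full_matrices=False)
    cutoff = rcond * S.max()
    # Invert only the singular values above the cutoff; zero out the rest.
    S_pinv = np.zeros_like(S)
    keep = S > cutoff
    S_pinv[keep] = 1.0 / S[keep]
    rank = int(np.count_nonzero(S_pinv))
    X = Vh.T @ (S_pinv[:, None] * (U.T @ B))
    return X, rank

# Rank-deficient example: the second column is a multiple of the first,
# so a plain QR-based driver would hit a singular triangular factor.
A = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])
B = np.array([[1.0], [2.0], [3.0]])
X, rank = lstsq_svd(A, B)
```

Here B lies in the column space of A, so `A @ X` reproduces B exactly even though A only has rank 1.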


pytorch-bot bot commented Apr 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125110

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3006f30 with merge base e5e623a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the "release notes: sparse" label Apr 28, 2024
@ZelboK ZelboK marked this pull request as draft April 28, 2024 00:57
@ZelboK ZelboK marked this pull request as ready for review April 28, 2024 01:22
@ZelboK
Contributor Author

ZelboK commented Apr 28, 2024

@janeyx99 @ptrblck

Second PR! 🎉 Sorry it took me a few days; I'm learning the PyTorch internals independently, so I'm still getting familiar with the codebase.

Also, I'm curious: is there a community Slack or Discord for PyTorch?

@lezcano
Collaborator

lezcano commented Apr 28, 2024

Mind cleaning up all the spurious newlines, and the PR in general?

@ZelboK ZelboK force-pushed the feat-improve-driver-linalg-lstq branch from 06e42e8 to 7372645 Compare April 28, 2024 09:46
@ZelboK
Contributor Author

ZelboK commented Apr 28, 2024

@lezcano

My apologies! I've cleaned it up. I missed some newlines left over from the debugging/experimenting code I was using to understand the codebase.

Collaborator

@lezcano lezcano left a comment


It looks mostly good.

It needs tests in test_linalg.py, and the docs should be updated to note that the gelss driver is also supported.

ZelboK and others added 2 commits April 28, 2024 09:24
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
@ZelboK
Contributor Author

ZelboK commented Apr 28, 2024

@lezcano

So when it comes to the tests, what kind of test did you have in mind, aside from checking that it no longer throws? I could add gelss as a driver in test_linalg_lstsq_batch_broadcasting, for example, and assert its results are as expected. I'm not too familiar with the test suites yet, so I'm hoping for guidance here.

Edit: The workflow runs exposed two failing tests for CPU and complex lstsq computations. I hadn't noticed that I built without LAPACK, so those tests were skipped locally. Will look into it now.
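As a rough illustration of the kind of check meant here, a hypothetical property test, using NumPy as a stand-in for the torch drivers (`check_lstsq_driver` is an invented name):

```python
import numpy as np

def check_lstsq_driver(solve, m=5, n=3, nrhs=2, seed=0):
    """Property test: the returned X must achieve the minimal ||AX - B||,
    even when A is rank deficient."""
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((m, n))
    A[:, -1] = A[:, 0]                 # force rank deficiency
    B = rng.standard_normal((m, nrhs))
    X = solve(A, B)
    ref = np.linalg.lstsq(A, B, rcond=None)[0]
    # Both solutions must project B onto the column space of A.
    return np.allclose(A @ X, A @ ref)

# Any correct driver should pass; here, an SVD-based pseudoinverse solve.
ok = check_lstsq_driver(lambda A, B: np.linalg.pinv(A) @ B)
```

The same shape of test, parametrized over the real drivers, would exercise the new gelss fallback alongside gelsy/gelsd.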

@ZelboK ZelboK force-pushed the feat-improve-driver-linalg-lstq branch from da93358 to c71e504 Compare April 28, 2024 23:43
Collaborator

@lezcano lezcano left a comment


For testing, just add a path that exercises this driver in the relevant tests that test the other drivers. We may even already have a test that covers this driver on CPU.

Collaborator

@lezcano lezcano left a comment


Please try to keep the changes to the bare minimum

Comment on lines 3433 to 3436
if (input.numel() == 0 || input.size(-2) == 0 || input.size(-1) == 0) {
auto output_shape = input.sizes().vec();
output_shape.back() = other.size(-1);
solution = at::zeros(output_shape, input.options());
Collaborator


Is this necessary?

Contributor Author

@ZelboK ZelboK Apr 29, 2024


Sorry, I should've communicated these changes.

The tests will actually fail without this check, because they generate tensors like torch.empty((0, 1)). The narrow call then leads to

  File "/home/ksm/pytorch/test/test_linalg.py", line 316, in test_linalg_lstsq
    res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
RuntimeError: start (0) + length (1) exceeds dimension size (0).

Do we want to remove this edge-case handling to simplify the logic, and document that this will occur?

If so, I'll also have to look at the tests again.
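In NumPy terms, the guard being discussed looks roughly like this (`solve_or_zeros` is an illustrative name, not the ATen code; the output shape mirrors the quoted C++ snippet):

```python
import numpy as np

def solve_or_zeros(A, B):
    # For degenerate (size-0) inputs, return an all-zero solution of the
    # expected shape instead of letting a slice on an empty dim throw.
    if A.size == 0 or A.shape[-2] == 0 or A.shape[-1] == 0:
        out_shape = list(A.shape)
        out_shape[-1] = B.shape[-1]   # mirror: output_shape.back() = other.size(-1)
        return np.zeros(out_shape, dtype=A.dtype)
    return np.linalg.lstsq(A, B, rcond=None)[0]

X_empty = solve_or_zeros(np.empty((0, 1)), np.empty((0, 2)))  # guarded path
X_ident = solve_or_zeros(np.eye(2), np.ones((2, 1)))          # normal path
```

The guarded path returns a zero tensor rather than tripping the `start (0) + length (1) exceeds dimension size (0)` error from narrow.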

// LAPACK stores residuals data for postprocessing in rows n:(m-n)
if (compute_residuals) {
// LAPACK stores residuals data for postprocessing in rows n:(m-n)
if (solution.size(-2) >= n + (m - n)) {
Collaborator


why is all this necessary?

Contributor Author

@ZelboK ZelboK Apr 30, 2024


The tests exposed some problems, mainly originating from the use of narrow on certain tensor shapes, so these new conditionals were added to handle those situations.

  File "/home/ksm/pytorch/test/test_linalg.py", line 316, in test_linalg_lstsq
    res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
RuntimeError: start (1) + length (1) exceeds dimension size (1)

Would you prefer I revert these conditionals, run the test workflows, see what fails, and go from there?
Going forward I'll keep code changes to a minimum and document why each change is necessary. Didn't mean to make your job harder, my bad 😅

Collaborator


But this path was already working before. I don't understand why we should touch it at all.

Contributor Author

@ZelboK ZelboK Apr 30, 2024


It's not needed. When adding tests and working through the exceptions, I found narrow was the main cause, so I added guards around all the narrow calls in the code.
I added this one because if the second-to-last dimension is less than n + (m - n) (which I now realize is just m), the narrow would throw.

I tried to produce a scenario where this would throw an exception, but it's always caught earlier on. This check is redundant and can be removed.

Edit: Also to clarify, the exception I pasted in the above comment was not from this line of code.

Contributor Author

@ZelboK ZelboK Apr 30, 2024


Wait. I actually just produced an exception from this. But I really don't understand how.

auto raw_residuals = solution.narrow(/*dim=*/-2, /*start=*/n, /*length*/m - n);

this actually raises an exception if I leave in the line rank.fill_(0) from earlier in the code, but when I remove it, the line above no longer raises an exception. Why? I understand that line is redundant, but I'm still really curious about this.
https://pastebin.com/BfysJQv4
^ backtrace (quite long) in case it's useful
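For background on the rows being narrowed here: the gels-family drivers overwrite B in place, with rows 0:n holding the solution and (for m > n at full rank) rows n:m holding data whose squared column norms are the residuals. A NumPy sketch of that post-processing, using an explicit QR in place of the LAPACK buffers:

```python
import numpy as np

m, n = 4, 2
A = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]])
B = np.array([[1.0], [2.0], [0.0], [0.0]])

# gels applies Q^T to B in place; rows 0:n then hold the solution system
# and rows n:m hold the residual components.
Q, _ = np.linalg.qr(A, mode='complete')   # Q is m x m
QtB = Q.T @ B
residuals = (QtB[n:m] ** 2).sum(axis=0)

# Cross-check against the direct definition of the residual.
X = np.linalg.lstsq(A, B, rcond=None)[0]
direct = ((A @ X - B) ** 2).sum(axis=0)
```

This is why the narrow over `start=n, length=m - n` only makes sense when the solution buffer really has m rows, i.e. in the overwritten-B layout.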

Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
solution.set_(solution.storage(), solution_view.storage_offset(), solution_view.sizes(), solution_view.strides());
}
if (solution.size(-2) >= n) {
auto solution_view = solution.narrow(/*dim=*/-2, /*start=*/0, /*length*/n);
Contributor Author


@lezcano

I added this because test_linalg_lstsq was failing for:
python test_linalg.py -k test_linalg_lstsq_cuda_float32

with tensors like:

a shape: torch.Size([2, 1])
a contents: tensor([[0.1540],
        [0.9887]], device='cuda:0')

@cpuhrsch cpuhrsch added the "triaged" label Apr 30, 2024
@ZelboK
Contributor Author

ZelboK commented May 1, 2024

@cpuhrsch @lezcano Could one of you please run the pipeline / test workflows, to check whether these tests are flaky only in my environment for whatever reason?

@@ -1080,7 +1080,7 @@

Keyword args:
driver (str, optional): name of the LAPACK/MAGMA method to be used.
-        If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs.
+        If `None`, `'gelsy'` is used for CPU inputs, `'gels'` and `'gelss'` for CUDA inputs.
Collaborator


revert

// residuals are available only if m > n and drivers other than gelsy used
if (m > n && driver != "gelsy") {
// if the driver is gelss or gelsd then the residuals are available only if rank == n
bool compute_residuals = true;
if (driver == "gelss" || driver == "gelsd") {
if (input.dim() == 2) {
-      compute_residuals = (rank.item().toInt() == n);
+      compute_residuals = (rank.item().toDouble() == n);
Collaborator


why is this change necessary?

} else {
auto [U, S, Vh] = at::_linalg_svd(input, false, true, "gesvd");
rank = at::zeros({1}, at::kLong);
rank[0] = (S > rcond).sum();
Collaborator


This is not correct. Compute the rank by looking at the zeros of S_pinv.
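A sketch of the suggestion, assuming rcond is a relative cutoff (names are illustrative): build the pseudoinverted spectrum first, then read the rank off its nonzeros, so the reported rank always agrees with the values actually used in the solve.

```python
import numpy as np

S = np.array([5.0, 1e-3, 0.0])   # singular values of a near-rank-1 matrix
rcond = 1e-2

# Relative cutoff: compare against rcond * largest singular value,
# not against rcond directly as in a bare (S > rcond).sum().
cutoff = rcond * S.max(axis=-1, keepdims=True)
S_pinv = np.zeros_like(S)
keep = S > cutoff
S_pinv[keep] = 1.0 / S[keep]
# Rank = number of nonzeros of S_pinv.
rank = int(np.count_nonzero(S_pinv, axis=-1))
```

With the absolute test `(S > rcond).sum()`, the same spectrum would report rank 0 if all singular values were scaled below rcond, even for a well-conditioned matrix.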

Successfully merging this pull request may close these issues.

Improve behaviour of torch.linalg.lstsq on CUDA GPU for rank deficient matrices