Skip to content

Commit

Permalink
Add conditional to ensure only CUDA goes through SVD code path as fal…
Browse files Browse the repository at this point in the history
…lback.
  • Loading branch information
ZelboK committed Apr 28, 2024
1 parent 6e8b3fd commit c71e504
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 2 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
Expand Up @@ -3429,7 +3429,7 @@ static void linalg_lstsq_out_info(
auto input_working_copy = copyBatchedColumnMajor(input);

// now the actual call that computes the result in-place (apply_lstsq)
if (driver == "gelss") {
if (driver == "gelss" input.device() != at::kCPU) {
auto [U, S, Vh] = at::_linalg_svd(input, false, true, "gesvd");
auto S_pinv = S.reciprocal();
auto s1 = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
Expand Down
1 change: 0 additions & 1 deletion test/test_linalg.py
Expand Up @@ -344,7 +344,6 @@ def check_correctness(a, b):
for m, batch in itertools.product(ms, batches):
a = random_well_conditioned_matrix(m, m, dtype=dtype, device=device).view(*([1] * len(batch)), m, m)
b = torch.rand(*(batch + (m, m)), dtype=dtype, device=device)

check_correctness(a, b)

# cases with broadcastable shapes
Expand Down

0 comments on commit c71e504

Please sign in to comment.