Skip to content

Commit

Permalink
Add conditional to ensure only CUDA goes through SVD code path as fal…
Browse files Browse the repository at this point in the history
…lback.
  • Loading branch information
ZelboK committed Apr 28, 2024
1 parent 6e8b3fd commit da93358
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
Expand Up @@ -3429,7 +3429,7 @@ static void linalg_lstsq_out_info(
auto input_working_copy = copyBatchedColumnMajor(input);

// now the actual call that computes the result in-place (apply_lstsq)
if (driver == "gelss") {
if (driver == "gelss" input.device() != at::kCPU) {
auto [U, S, Vh] = at::_linalg_svd(input, false, true, "gesvd");
auto S_pinv = S.reciprocal();
auto s1 = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
Expand Down

0 comments on commit da93358

Please sign in to comment.