Add logic for lstsq to be able to use the SVD driver as a backend for when matrices are rank deficient.
ZelboK committed Apr 28, 2024
1 parent 6bef5e9 commit 7372645
Showing 1 changed file with 33 additions and 12 deletions.
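In user-facing terms, the commit lets torch.linalg.lstsq handle rank-deficient inputs by routing the gelsd/gelss drivers through an SVD-based solve. A minimal libtorch sketch of that workflow — the rank-1 matrix is illustrative, and it assumes the generated torch::linalg_lstsq C++ binding:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // A rank-deficient system: every row is a multiple of the first.
  auto A = torch::tensor({{1.0, 2.0}, {2.0, 4.0}, {3.0, 6.0}});
  auto b = torch::tensor({{1.0}, {2.0}, {3.0}});
  // QR-based drivers such as "gels" assume full rank and can fail here;
  // "gelsd" goes through the SVD and returns the minimum-norm solution.
  auto result = torch::linalg_lstsq(A, b, /*rcond=*/c10::nullopt, "gelsd");
  std::cout << std::get<0>(result) << std::endl; // the solution tensor
  return 0;
}
```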
45 changes: 33 additions & 12 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -6,7 +6,6 @@
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
@@ -117,6 +116,8 @@
#include <ATen/ops/triu.h>
#include <ATen/ops/vdot.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/matmul.h>
#endif

// First the required LAPACK implementations are registered here.
@@ -1556,7 +1557,7 @@ void _linalg_check_errors(
": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ").");
} else if (api_name.find("lstsq") != api_name.npos) {
TORCH_CHECK_LINALG(false, api_name, batch_str,
": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, "). Specify an SVD-based driver (e.g. `gelsd`) if you want to solve a rank-deficient system.");
} else if (api_name.find("lu_factor") != api_name.npos) {
TORCH_CHECK(false, api_name, batch_str,
": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. "
@@ -3422,16 +3423,40 @@ static void linalg_lstsq_out_info(
TORCH_INTERNAL_ASSERT(singular_values.sizes().equals(singular_values_shape));
TORCH_INTERNAL_ASSERT(singular_values.is_contiguous());
}

// 'input' is modified in-place so we need a column-major copy
auto input_working_copy = copyBatchedColumnMajor(input);

// now the actual call(s) that compute the result in-place (the SVD path below, or apply_lstsq via lstsq_stub)
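// Math note: with input = U * S * V^H (reduced SVD), the minimum-norm
// least-squares solution is solution = V * S^+ * U^H * other, where S^+
// holds the reciprocals of the singular values above the cutoff and zeros
// elsewhere. Zeroing the small singular values is what keeps the solve
// well-defined when the input is rank deficient.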

if (driver == "gelsd" || driver == "gelss") {
auto k = std::min(m, n);
// note: size(0) is used as the batch dimension, so this path assumes a batched 3D input
auto U = at::zeros({input.size(0), m, k}, input.options());
auto Vh = at::zeros({input.size(0), k, n}, input.options());

svd_stub(input.device().type(),
input,
false, // full_matrices=false: the reduced SVD is enough for least squares
true, // compute_uv=true: U and Vh are needed to form the pseudoinverse
"gesvd",
U, singular_values, Vh, infos);

auto tol = 1e-5; // TODO: derive this cutoff from rcond instead of hardcoding it
auto mask = singular_values > tol;
auto pseudo_sv = at::zeros_like(singular_values);

// invert only the singular values above the cutoff; the rest stay zero,
// which discards the rank-deficient directions
pseudo_sv.masked_scatter_(mask, singular_values.masked_select(mask).reciprocal());
auto uhOther = at::matmul(U.adjoint(), other);
if (pseudo_sv.dim() != uhOther.dim()) {
pseudo_sv = pseudo_sv.unsqueeze(-1); // broadcast over the right-hand-side columns
}
auto pseudo_sv_other = pseudo_sv * uhOther;
solution = at::matmul(Vh.adjoint(), pseudo_sv_other);
}
else {
lstsq_stub(input.device().type(), input_working_copy, solution, rank, singular_values, infos, rcond, driver);
}
// residuals are available only if m > n and a driver other than gelsy was used
if (m > n && driver != "gelsy") {
// if the driver is gelss or gelsd then the residuals are available only if rank == n

bool compute_residuals = true;
if (driver == "gelss" || driver == "gelsd") {
if (input.dim() == 2) {
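For reference, the residuals guarded here are the per-column squared misfits

$$\text{residuals}_j = \lVert A x_j - b_j \rVert_2^2,$$

and for the gelss/gelsd drivers LAPACK only reports them when the matrix has full column rank, hence the extra rank == n condition described in the comment above.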
@@ -3490,8 +3515,8 @@ static std::string get_default_lstsq_driver(c10::optional<c10::string_view> driver
);
} else { // else if (input.is_cuda())
TORCH_CHECK(
(driver_str == "gels" || driver_str == "gelsd"),
"torch.linalg.lstsq: `driver` other than `gels` or `gelsd` is not supported on CUDA"
);
}
} else {
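The relaxed check means CUDA tensors now accept two drivers. A small sketch of what passes validation after this change — shapes are illustrative and a CUDA build is assumed:

```cpp
#include <torch/torch.h>

int main() {
  auto A = torch::rand({4, 3}, torch::kCUDA);
  auto b = torch::rand({4, 1}, torch::kCUDA);
  torch::linalg_lstsq(A, b, c10::nullopt, "gels");  // accepted before and after
  torch::linalg_lstsq(A, b, c10::nullopt, "gelsd"); // newly accepted by this commit
  // torch::linalg_lstsq(A, b, c10::nullopt, "gelsy"); // still rejected on CUDA
  return 0;
}
```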
@@ -3548,7 +3573,6 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
// to be consistent with torch.linalg.matrix_rank output dtype
ScalarType rank_expected_type = ScalarType::Long;
checkLinalgCompatibleDtype("torch.linalg.lstsq", rank.scalar_type(), rank_expected_type, "rank");
// 'singular_values' is expected to have real float dtype
checkLinalgCompatibleDtype("torch.linalg.lstsq", singular_values.scalar_type(), real_dtype, "singular_values");

@@ -3560,7 +3584,6 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
: _get_epsilon(c10::toRealValueType(input.scalar_type())) * std::max<int64_t>(input.size(-2), input.size(-1));

auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));
// now check whether the provided output tensors can be used directly

// Two types of 'other' tensors are supported:
@@ -3576,7 +3599,6 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
// 1. the shape matches the expected shape
// 2. the dtype matches the expected dtype
// 3. the tensor is contiguous
// Checks for the 'solution' tensor
std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
// the actual shape of the solution returned is (*, n,) or (*, n, nrhs)
@@ -3661,7 +3683,6 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
// else use the provided output storage directly
linalg_lstsq_out_info(solution, residuals, rank, singular_values, infos, input, other, rcond_value, driver_name);
}
at::_linalg_check_errors(infos, "torch.linalg.lstsq", infos.numel() <= 1);
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(solution, residuals, rank, singular_values);
}
