diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 532d1e6b..fa730150 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -12,6 +12,14 @@ pub struct LeastSquaresOwned { pub rank: i32, } +/// Result of LeastSquares +pub struct LeastSquaresRef<'work, A: Scalar> { + /// singular values + pub singular_values: &'work [A::Real], + /// The rank of the input matrix A + pub rank: i32, +} + #[cfg_attr(doc, katexit::katexit)] /// Solve least square problem pub trait LeastSquaresSvdDivideConquer_: Scalar { @@ -29,7 +37,89 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar { a: &mut [Self], b_layout: MatrixLayout, b: &mut [Self], - ) -> Result>; + ) -> Result>; +} + +pub struct LeastSquaresWork { + pub a_layout: MatrixLayout, + pub b_layout: MatrixLayout, + pub singular_values: Vec>, + pub work: Vec>, + pub iwork: Vec>, + pub rwork: Option>>, +} + +pub trait LeastSquaresWorkImpl: Sized { + type Elem: Scalar; + fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self], b: &mut [Self]) -> Result>; + fn eval(self, a: &mut [Self], b: &mut [Self]) -> Result>; +} + +impl LeastSquaresWorkImpl for LeastSquaresWork { + type Elem = c64; + + fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result { + let (m, n) = a_layout.size(); + let (m_, nrhs) = b_layout.size(); + let k = m.min(n); + assert!(m_ >= m); + + let rcond = -1.; + let mut singular_values = vec_uninit(k as usize); + let mut rank: i32 = 0; + + // eval work size + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + let mut iwork_size = [0]; + let mut rwork = [::Real::zero()]; + unsafe { + lapack_sys::zgelsd_( + &m, + &n, + &nrhs, + std::ptr::null_mut(), + &a_layout.lda(), + std::ptr::null_mut(), + &b_layout.lda(), + AsPtr::as_mut_ptr(&mut singular_values), + &rcond, + &mut rank, + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + AsPtr::as_mut_ptr(&mut rwork), + iwork_size.as_mut_ptr(), + &mut info, + ) + }; + info.as_lapack_result()?; + + let lwork = work_size[0].to_usize().unwrap(); + let liwork = iwork_size[0].to_usize().unwrap(); + let lrwork = rwork[0].to_usize().unwrap(); + + let work = vec_uninit(lwork); + let iwork = vec_uninit(liwork); + let rwork = vec_uninit(lrwork); + + Ok(LeastSquaresWork { + a_layout, + b_layout, + work, + iwork, + rwork: Some(rwork), + singular_values, + }) + } + + fn calc(&mut self, a: &mut [Self], b: &mut [Self]) -> Result> { + todo!() + } + + fn eval(self, a: &mut [Self], b: &mut [Self]) -> Result> { + todo!() + } } macro_rules! impl_least_squares {