diff --git a/lax/src/eig.rs b/lax/src/eig.rs index bf3b9f8c..f11f5287 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -41,12 +41,12 @@ macro_rules! impl_eig_complex { } else { (EigenVectorFlag::Not, EigenVectorFlag::Not) }; - let mut eigs = unsafe { vec_uninit(n as usize) }; - let mut rwork: Vec = unsafe { vec_uninit(2 * n as usize) }; + let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; + let mut rwork: Vec> = unsafe { vec_uninit(2 * n as usize) }; - let mut vl: Option> = + let mut vl: Option>> = jobvl.then(|| unsafe { vec_uninit((n * n) as usize) }); - let mut vr: Option> = + let mut vr: Option>> = jobvr.then(|| unsafe { vec_uninit((n * n) as usize) }); // calc work size @@ -74,7 +74,7 @@ macro_rules! impl_eig_complex { // actal ev let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; let lwork = lwork as i32; unsafe { $ev( @@ -96,10 +96,14 @@ macro_rules! impl_eig_complex { }; info.as_lapack_result()?; + let eigs = unsafe { eigs.assume_init() }; + let vr = unsafe { vr.map(|v| v.assume_init()) }; + let mut vl = unsafe { vl.map(|v| v.assume_init()) }; + // Hermite conjugate if jobvl.is_calc() { for c in vl.as_mut().unwrap().iter_mut() { - c.im = -c.im + c.im = -c.im; } } @@ -145,12 +149,12 @@ macro_rules! impl_eig_real { } else { (EigenVectorFlag::Not, EigenVectorFlag::Not) }; - let mut eig_re: Vec = unsafe { vec_uninit(n as usize) }; - let mut eig_im: Vec = unsafe { vec_uninit(n as usize) }; + let mut eig_re: Vec> = unsafe { vec_uninit(n as usize) }; + let mut eig_im: Vec> = unsafe { vec_uninit(n as usize) }; - let mut vl: Option> = + let mut vl: Option>> = jobvl.then(|| unsafe { vec_uninit((n * n) as usize) }); - let mut vr: Option> = + let mut vr: Option>> = jobvr.then(|| unsafe { vec_uninit((n * n) as usize) }); // calc work size @@ -178,7 +182,7 @@ macro_rules! impl_eig_real { // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; let lwork = lwork as i32; unsafe { $ev( @@ -200,6 +204,11 @@ macro_rules! impl_eig_real { }; info.as_lapack_result()?; + let eig_re = unsafe { eig_re.assume_init() }; + let eig_im = unsafe { eig_im.assume_init() }; + let vl = unsafe { vl.map(|v| v.assume_init()) }; + let vr = unsafe { vr.map(|v| v.assume_init()) }; + // reconstruct eigenvalues let eigs: Vec = eig_re .iter() @@ -228,14 +237,14 @@ macro_rules! impl_eig_real { let n = n as usize; let v = vr.or(vl).unwrap(); - let mut eigvecs = unsafe { vec_uninit(n * n) }; + let mut eigvecs: Vec> = unsafe { vec_uninit(n * n) }; let mut col = 0; while col < n { if eig_im[col] == 0. { // The corresponding eigenvalue is real. for row in 0..n { let re = v[row + col * n]; - eigvecs[row + col * n] = Self::complex(re, 0.); + eigvecs[row + col * n].write(Self::complex(re, 0.)); } col += 1; } else { @@ -247,12 +256,13 @@ macro_rules! impl_eig_real { if jobvl.is_calc() { im = -im; } - eigvecs[row + col * n] = Self::complex(re, im); - eigvecs[row + (col + 1) * n] = Self::complex(re, -im); + eigvecs[row + col * n].write(Self::complex(re, im)); + eigvecs[row + (col + 1) * n].write(Self::complex(re, -im)); } col += 2; } } + let eigvecs = unsafe { eigvecs.assume_init() }; Ok((eigs, eigvecs)) } diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index a8403e90..0692f921 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -42,10 +42,10 @@ macro_rules! impl_eigh { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; - let mut eigs = unsafe { vec_uninit(n as usize) }; + let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; $( - let mut $rwork_ident: Vec = unsafe { vec_uninit(3 * n as usize - 2 as usize) }; + let mut $rwork_ident: Vec> = unsafe { vec_uninit(3 * n as usize - 2 as usize) }; )* // calc work size @@ -69,7 +69,7 @@ macro_rules! impl_eigh { // actual ev let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; let lwork = lwork as i32; unsafe { $ev( @@ -86,6 +86,8 @@ macro_rules! impl_eigh { ); } info.as_lapack_result()?; + + let eigs = unsafe { eigs.assume_init() }; Ok(eigs) } @@ -99,10 +101,10 @@ macro_rules! impl_eigh { assert_eq!(layout.len(), layout.lda()); let n = layout.len(); let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not }; - let mut eigs = unsafe { vec_uninit(n as usize) }; + let mut eigs: Vec> = unsafe { vec_uninit(n as usize) }; $( - let mut $rwork_ident: Vec = unsafe { vec_uninit(3 * n as usize - 2) }; + let mut $rwork_ident: Vec> = unsafe { vec_uninit(3 * n as usize - 2) }; )* // calc work size @@ -129,7 +131,7 @@ macro_rules! impl_eigh { // actual evg let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; let lwork = lwork as i32; unsafe { $evg( @@ -149,6 +151,7 @@ macro_rules! impl_eigh { ); } info.as_lapack_result()?; + let eigs = unsafe { eigs.assume_init() }; Ok(eigs) } } diff --git a/lax/src/layout.rs b/lax/src/layout.rs index e7ab1da4..e695d8e7 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -37,7 +37,7 @@ //! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`. //! -use cauchy::Scalar; +use super::*; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MatrixLayout { @@ -153,7 +153,7 @@ impl MatrixLayout { /// ------ /// - If size of `a` and `layout` size mismatch /// -pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { +pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { let (m, n) = layout.size(); let n = n as usize; let m = m as usize; @@ -162,23 +162,78 @@ pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { for j in (i + 1)..n { let a_ij = a[i * n + j]; let a_ji = a[j * m + i]; - a[i * n + j] = a_ji.conj(); - a[j * m + i] = a_ij.conj(); + a[i * n + j] = a_ji; + a[j * m + i] = a_ij; } } } /// Out-place transpose for general matrix /// -/// Inplace transpose of non-square matrices is hard. -/// See also: https://en.wikipedia.org/wiki/In-place_matrix_transposition +/// Examples +/// --------- +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::C { row: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let (l, b) = transpose(layout, &a); +/// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::F { col: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let (l, b) = transpose(layout, &a); +/// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// Panics +/// ------ +/// - If input array size and `layout` size mismatch +/// +pub fn transpose(layout: MatrixLayout, input: &[T]) -> (MatrixLayout, Vec) { + let (m, n) = layout.size(); + let transposed = layout.resized(n, m).t(); + let m = m as usize; + let n = n as usize; + assert_eq!(input.len(), m * n); + + let mut out: Vec> = unsafe { vec_uninit(m * n) }; + + match layout { + MatrixLayout::C { .. } => { + for i in 0..m { + for j in 0..n { + out[j * m + i].write(input[i * n + j]); + } + } + } + MatrixLayout::F { .. } => { + for i in 0..m { + for j in 0..n { + out[i * n + j].write(input[j * m + i]); + } + } + } + } + (transposed, unsafe { out.assume_init() }) +} + +/// Out-place transpose for general matrix +/// +/// Examples +/// --------- /// /// ```rust /// # use lax::layout::*; /// let layout = MatrixLayout::C { row: 2, lda: 3 }; /// let a = vec![1., 2., 3., 4., 5., 6.]; /// let mut b = vec![0.0; a.len()]; -/// let l = transpose(layout, &a, &mut b); +/// let l = transpose_over(layout, &a, &mut b); /// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 }); /// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); /// ``` @@ -188,16 +243,16 @@ pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { /// let layout = MatrixLayout::F { col: 2, lda: 3 }; /// let a = vec![1., 2., 3., 4., 5., 6.]; /// let mut b = vec![0.0; a.len()]; -/// let l = transpose(layout, &a, &mut b); +/// let l = transpose_over(layout, &a, &mut b); /// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 }); /// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); /// ``` /// /// Panics /// ------ -/// - If size of `a` and `layout` size mismatch +/// - If input array sizes and `layout` size mismatch /// -pub fn transpose(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout { +pub fn transpose_over(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout { let (m, n) = layout.size(); let transposed = layout.resized(n, m).t(); let m = m as usize; diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 97f9a839..6be44f33 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -68,8 +68,9 @@ macro_rules! impl_least_squares { let mut a_t = None; let a_layout = match a_layout { MatrixLayout::C { .. } => { - a_t = Some(unsafe { vec_uninit( a.len()) }); - transpose(a_layout, a, a_t.as_mut().unwrap()) + let (layout, t) = transpose(a_layout, a); + a_t = Some(t); + layout } MatrixLayout::F { .. } => a_layout, }; @@ -78,14 +79,15 @@ macro_rules! impl_least_squares { let mut b_t = None; let b_layout = match b_layout { MatrixLayout::C { .. } => { - b_t = Some(unsafe { vec_uninit( b.len()) }); - transpose(b_layout, b, b_t.as_mut().unwrap()) + let (layout, t) = transpose(b_layout, b); + b_t = Some(t); + layout } MatrixLayout::F { .. } => b_layout, }; let rcond: Self::Real = -1.; - let mut singular_values: Vec = unsafe { vec_uninit( k as usize) }; + let mut singular_values: Vec> = unsafe { vec_uninit( k as usize) }; let mut rank: i32 = 0; // eval work size @@ -118,12 +120,12 @@ macro_rules! impl_least_squares { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; let liwork = iwork_size[0].to_usize().unwrap(); - let mut iwork = unsafe { vec_uninit(liwork) }; + let mut iwork: Vec> = unsafe { vec_uninit(liwork) }; $( let lrwork = $rwork[0].to_usize().unwrap(); - let mut $rwork: Vec = unsafe { vec_uninit(lrwork) }; + let mut $rwork: Vec> = unsafe { vec_uninit(lrwork) }; )* unsafe { $gelsd( @@ -140,16 +142,18 @@ macro_rules! impl_least_squares { AsPtr::as_mut_ptr(&mut work), &(lwork as i32), $(AsPtr::as_mut_ptr(&mut $rwork),)* - iwork.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut iwork), &mut info, ); } info.as_lapack_result()?; + let singular_values = unsafe { singular_values.assume_init() }; + // Skip a_t -> a transpose because A has been destroyed // Re-transpose b if let Some(b_t) = b_t { - transpose(b_layout, &b_t, b); + transpose_over(b_layout, &b_t, b); } Ok(LeastSquaresOutput { diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 26b740bb..c8d2264d 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -100,6 +100,7 @@ pub use self::triangular::*; pub use self::tridiagonal::*; use cauchy::*; +use std::mem::MaybeUninit; pub type Pivot = Vec; @@ -146,10 +147,31 @@ macro_rules! impl_as_ptr { } }; } +impl_as_ptr!(i32, i32); impl_as_ptr!(f32, f32); impl_as_ptr!(f64, f64); impl_as_ptr!(c32, lapack_sys::__BindgenComplex); impl_as_ptr!(c64, lapack_sys::__BindgenComplex); +impl_as_ptr!(MaybeUninit, i32); +impl_as_ptr!(MaybeUninit, f32); +impl_as_ptr!(MaybeUninit, f64); +impl_as_ptr!(MaybeUninit, lapack_sys::__BindgenComplex); +impl_as_ptr!(MaybeUninit, lapack_sys::__BindgenComplex); + +pub(crate) trait VecAssumeInit { + type Target; + unsafe fn assume_init(self) -> Self::Target; +} + +impl VecAssumeInit for Vec> { + type Target = Vec; + unsafe fn assume_init(self) -> Self::Target { + // FIXME use Vec::into_raw_parts instead after stablized + // https://doc.rust-lang.org/std/vec/struct.Vec.html#method.into_raw_parts + let mut me = std::mem::ManuallyDrop::new(self); + Vec::from_raw_parts(me.as_mut_ptr() as *mut T, me.len(), me.capacity()) + } +} /// Upper/Lower specification for seveal usages #[derive(Debug, Clone, Copy)] @@ -247,7 +269,7 @@ impl EigenVectorFlag { /// ------ /// - Memory is not initialized. Do not read the memory before write. /// -unsafe fn vec_uninit(n: usize) -> Vec { +unsafe fn vec_uninit(n: usize) -> Vec> { let mut v = Vec::with_capacity(n); v.set_len(n); v diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index ddcc2c85..fca7704c 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -18,7 +18,7 @@ macro_rules! impl_opnorm { MatrixLayout::F { .. } => t, MatrixLayout::C { .. } => t.transpose(), }; - let mut work: Vec = if matches!(t, NormType::Infinity) { + let mut work: Vec> = if matches!(t, NormType::Infinity) { unsafe { vec_uninit(m as usize) } } else { Vec::new() diff --git a/lax/src/qr.rs b/lax/src/qr.rs index 33de0372..553bb606 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -62,7 +62,7 @@ macro_rules! impl_qr { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; unsafe { match l { MatrixLayout::F { .. } => { @@ -93,6 +93,8 @@ macro_rules! impl_qr { } info.as_lapack_result()?; + let tau = unsafe { tau.assume_init() }; + Ok(tau) } @@ -134,7 +136,7 @@ macro_rules! impl_qr { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; unsafe { match l { MatrixLayout::F { .. } => $gqr( diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs index fcd4211f..dfc8a941 100644 --- a/lax/src/rcond.rs +++ b/lax/src/rcond.rs @@ -17,8 +17,8 @@ macro_rules! impl_rcond_real { let mut rcond = Self::Real::zero(); let mut info = 0; - let mut work: Vec = unsafe { vec_uninit(4 * n as usize) }; - let mut iwork = unsafe { vec_uninit(n as usize) }; + let mut work: Vec> = unsafe { vec_uninit(4 * n as usize) }; + let mut iwork: Vec> = unsafe { vec_uninit(n as usize) }; let norm_type = match l { MatrixLayout::C { .. } => NormType::Infinity, MatrixLayout::F { .. } => NormType::One, @@ -32,7 +32,7 @@ macro_rules! impl_rcond_real { &anorm, &mut rcond, AsPtr::as_mut_ptr(&mut work), - iwork.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut iwork), &mut info, ) }; @@ -54,8 +54,8 @@ macro_rules! impl_rcond_complex { let (n, _) = l.size(); let mut rcond = Self::Real::zero(); let mut info = 0; - let mut work: Vec = unsafe { vec_uninit(2 * n as usize) }; - let mut rwork: Vec = unsafe { vec_uninit(2 * n as usize) }; + let mut work: Vec> = unsafe { vec_uninit(2 * n as usize) }; + let mut rwork: Vec> = unsafe { vec_uninit(2 * n as usize) }; let norm_type = match l { MatrixLayout::C { .. } => NormType::Infinity, MatrixLayout::F { .. } => NormType::One, diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 9c19c874..ae76f190 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -41,11 +41,12 @@ macro_rules! impl_solve { &l.len(), AsPtr::as_mut_ptr(a), &l.lda(), - ipiv.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut ipiv), &mut info, ) }; info.as_lapack_result()?; + let ipiv = unsafe { ipiv.assume_init() }; Ok(ipiv) } @@ -74,7 +75,7 @@ macro_rules! impl_solve { // actual let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; unsafe { $getri( &l.len(), diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index c5259dda..9f65978d 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -34,7 +34,7 @@ macro_rules! impl_solveh { &n, AsPtr::as_mut_ptr(a), &l.lda(), - ipiv.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut ipiv), AsPtr::as_mut_ptr(&mut work_size), &(-1), &mut info, @@ -44,27 +44,28 @@ macro_rules! impl_solveh { // actual let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit(lwork) }; + let mut work: Vec> = unsafe { vec_uninit(lwork) }; unsafe { $trf( uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), &l.lda(), - ipiv.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut ipiv), AsPtr::as_mut_ptr(&mut work), &(lwork as i32), &mut info, ) }; info.as_lapack_result()?; + let ipiv = unsafe { ipiv.assume_init() }; Ok(ipiv) } fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); let mut info = 0; - let mut work: Vec = unsafe { vec_uninit(n as usize) }; + let mut work: Vec> = unsafe { vec_uninit(n as usize) }; unsafe { $tri( uplo.as_ptr(), diff --git a/lax/src/svd.rs b/lax/src/svd.rs index 0ee56428..8c731c7a 100644 --- a/lax/src/svd.rs +++ b/lax/src/svd.rs @@ -79,7 +79,7 @@ macro_rules! impl_svd { let mut s = unsafe { vec_uninit( k as usize) }; $( - let mut $rwork_ident: Vec = unsafe { vec_uninit( 5 * k as usize) }; + let mut $rwork_ident: Vec> = unsafe { vec_uninit( 5 * k as usize) }; )* // eval work size @@ -108,7 +108,7 @@ macro_rules! impl_svd { // calc let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit( lwork) }; + let mut work: Vec> = unsafe { vec_uninit( lwork) }; unsafe { $gesvd( ju.as_ptr(), @@ -129,6 +129,11 @@ macro_rules! impl_svd { ); } info.as_lapack_result()?; + + let s = unsafe { s.assume_init() }; + let u = u.map(|v| unsafe { v.assume_init() }); + let vt = vt.map(|v| unsafe { v.assume_init() }); + match l { MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index c1198286..f956d848 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -64,12 +64,12 @@ macro_rules! impl_svddc { UVTFlag::None => 7 * mn, _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), }; - let mut $rwork_ident: Vec = unsafe { vec_uninit( lrwork) }; + let mut $rwork_ident: Vec> = unsafe { vec_uninit( lrwork) }; )* // eval work size let mut info = 0; - let mut iwork = unsafe { vec_uninit( 8 * k as usize) }; + let mut iwork: Vec> = unsafe { vec_uninit( 8 * k as usize) }; let mut work_size = [Self::zero()]; unsafe { $gesdd( @@ -86,7 +86,7 @@ macro_rules! impl_svddc { AsPtr::as_mut_ptr(&mut work_size), &(-1), $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* - iwork.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut iwork), &mut info, ); } @@ -94,7 +94,7 @@ macro_rules! impl_svddc { // do svd let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec = unsafe { vec_uninit( lwork) }; + let mut work: Vec> = unsafe { vec_uninit( lwork) }; unsafe { $gesdd( jobz.as_ptr(), @@ -110,12 +110,16 @@ macro_rules! impl_svddc { AsPtr::as_mut_ptr(&mut work), &(lwork as i32), $(AsPtr::as_mut_ptr(&mut $rwork_ident),)* - iwork.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut iwork), &mut info, ); } info.as_lapack_result()?; + let s = unsafe { s.assume_init() }; + let u = u.map(|v| unsafe { v.assume_init() }); + let vt = vt.map(|v| unsafe { v.assume_init() }); + match l { MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), diff --git a/lax/src/triangular.rs b/lax/src/triangular.rs index 0288d6ba..e8825758 100644 --- a/lax/src/triangular.rs +++ b/lax/src/triangular.rs @@ -43,8 +43,9 @@ macro_rules! impl_triangular { let mut a_t = None; let a_layout = match a_layout { MatrixLayout::C { .. } => { - a_t = Some(unsafe { vec_uninit(a.len()) }); - transpose(a_layout, a, a_t.as_mut().unwrap()) + let (layout, t) = transpose(a_layout, a); + a_t = Some(t); + layout } MatrixLayout::F { .. } => a_layout, }; @@ -53,8 +54,9 @@ macro_rules! impl_triangular { let mut b_t = None; let b_layout = match b_layout { MatrixLayout::C { .. } => { - b_t = Some(unsafe { vec_uninit(b.len()) }); - transpose(b_layout, b, b_t.as_mut().unwrap()) + let (layout, t) = transpose(b_layout, b); + b_t = Some(t); + layout } MatrixLayout::F { .. } => b_layout, }; @@ -82,7 +84,7 @@ macro_rules! impl_triangular { // Re-transpose b if let Some(b_t) = b_t { - transpose(b_layout, &b_t, b); + transpose_over(b_layout, &b_t, b); } Ok(()) } diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index c80ad4b5..ef8dfdf6 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -164,11 +164,13 @@ macro_rules! impl_tridiagonal { AsPtr::as_mut_ptr(&mut a.d), AsPtr::as_mut_ptr(&mut a.du), AsPtr::as_mut_ptr(&mut du2), - ipiv.as_mut_ptr(), + AsPtr::as_mut_ptr(&mut ipiv), &mut info, ) }; info.as_lapack_result()?; + let du2 = unsafe { du2.assume_init() }; + let ipiv = unsafe { ipiv.assume_init() }; Ok(LUFactorizedTridiagonal { a, du2, @@ -180,9 +182,9 @@ macro_rules! impl_tridiagonal { fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; - let mut work: Vec = unsafe { vec_uninit( 2 * n as usize) }; + let mut work: Vec> = unsafe { vec_uninit( 2 * n as usize) }; $( - let mut $iwork = unsafe { vec_uninit( n as usize) }; + let mut $iwork: Vec> = unsafe { vec_uninit( n as usize) }; )* let mut rcond = Self::Real::zero(); let mut info = 0; @@ -198,7 +200,7 @@ macro_rules! impl_tridiagonal { &lu.a_opnorm_one, &mut rcond, AsPtr::as_mut_ptr(&mut work), - $($iwork.as_mut_ptr(),)* + $(AsPtr::as_mut_ptr(&mut $iwork),)* &mut info, ); } @@ -218,8 +220,9 @@ macro_rules! impl_tridiagonal { let mut b_t = None; let b_layout = match b_layout { MatrixLayout::C { .. } => { - b_t = Some(unsafe { vec_uninit( b.len()) }); - transpose(b_layout, b, b_t.as_mut().unwrap()) + let (layout, t) = transpose(b_layout, b); + b_t = Some(t); + layout } MatrixLayout::F { .. } => b_layout, }; @@ -242,7 +245,7 @@ macro_rules! impl_tridiagonal { } info.as_lapack_result()?; if let Some(b_t) = b_t { - transpose(b_layout, &b_t, b); + transpose_over(b_layout, &b_t, b); } Ok(()) }