Skip to content

Commit

Permalink
Merge pull request dimforge#1043 from dimforge/svd3-fix
Browse files Browse the repository at this point in the history
Fix the special-case for 3x3 Real SVD
  • Loading branch information
sebcrozet committed Dec 9, 2021
2 parents a9890e2 + 88dd544 commit 507ead2
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 18 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/nalgebra-ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: Jimver/cuda-toolkit@v0.2.4
- name: Install nightly-2021-10-17
- name: Install nightly-2021-12-04
uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2021-10-17
toolchain: nightly-2021-12-04
override: true
- uses: actions/checkout@v2
- run: rustup target add nvptx64-nvidia-cuda
Expand Down
34 changes: 24 additions & 10 deletions src/linalg/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ where
+ Allocator<T::RealField, DimMinimum<R, C>>
+ Allocator<T::RealField, DimDiff<DimMinimum<R, C>, U1>>,
{
/// Tells whether this instantiation is exactly the real 2x2 case.
///
/// Returns `true` only when `OMatrix<T, R, C>` is the very same type as
/// `Matrix2<T::RealField>` and `Self` is the very same type as
/// `SVD<T::RealField, U2, U2>`.
fn use_special_always_ordered_svd2() -> bool {
    let matrix_is_real_2x2 =
        TypeId::of::<OMatrix<T, R, C>>() == TypeId::of::<Matrix2<T::RealField>>();
    let svd_is_real_2x2 = TypeId::of::<Self>() == TypeId::of::<SVD<T::RealField, U2, U2>>();

    matrix_is_real_2x2 && svd_is_real_2x2
}

/// Tells whether this instantiation is exactly the real 3x3 case.
///
/// Returns `true` only when `OMatrix<T, R, C>` is the very same type as
/// `Matrix3<T::RealField>` and `Self` is the very same type as
/// `SVD<T::RealField, U3, U3>`.
fn use_special_always_ordered_svd3() -> bool {
    let matrix_is_real_3x3 =
        TypeId::of::<OMatrix<T, R, C>>() == TypeId::of::<Matrix3<T::RealField>>();
    let svd_is_real_3x3 = TypeId::of::<Self>() == TypeId::of::<SVD<T::RealField, U3, U3>>();

    matrix_is_real_3x3 && svd_is_real_3x3
}

/// Computes the Singular Value Decomposition of `matrix` using implicit shift.
/// The singular values are not guaranteed to be sorted in any particular order.
/// If a descending order is required, consider using `new` instead.
Expand Down Expand Up @@ -120,20 +130,16 @@ where
let (nrows, ncols) = matrix.shape_generic();
let min_nrows_ncols = nrows.min(ncols);

if TypeId::of::<OMatrix<T, R, C>>() == TypeId::of::<Matrix2<T::RealField>>()
&& TypeId::of::<Self>() == TypeId::of::<SVD<T::RealField, U2, U2>>()
{
if Self::use_special_always_ordered_svd2() {
// SAFETY: the reference transmutes are OK since we checked that the types match exactly.
let matrix: &Matrix2<T::RealField> = unsafe { std::mem::transmute(&matrix) };
let result = super::svd2::svd2(matrix, compute_u, compute_v);
let result = super::svd2::svd_ordered2(matrix, compute_u, compute_v);
let typed_result: &Self = unsafe { std::mem::transmute(&result) };
return Some(typed_result.clone());
} else if TypeId::of::<OMatrix<T, R, C>>() == TypeId::of::<Matrix3<T::RealField>>()
&& TypeId::of::<Self>() == TypeId::of::<SVD<T::RealField, U3, U3>>()
{
} else if Self::use_special_always_ordered_svd3() {
// SAFETY: the reference transmutes are OK since we checked that the types match exactly.
let matrix: &Matrix3<T::RealField> = unsafe { std::mem::transmute(&matrix) };
let result = super::svd3::svd3(matrix, compute_u, compute_v, eps, max_niter);
let result = super::svd3::svd_ordered3(matrix, compute_u, compute_v, eps, max_niter);
let typed_result: &Self = unsafe { std::mem::transmute(&result) };
return Some(typed_result.clone());
}
Expand Down Expand Up @@ -657,7 +663,11 @@ where
/// If this order is not required consider using `new_unordered`.
pub fn new(matrix: OMatrix<T, R, C>, compute_u: bool, compute_v: bool) -> Self {
    let mut svd = Self::new_unordered(matrix, compute_u, compute_v);

    // Sort only when neither real 2x2 nor real 3x3 special case applies.
    // NOTE(review): this relies on those special-case routines
    // (`svd_ordered2` / `svd_ordered3`) already returning their singular
    // values in descending order -- confirm against `new_unordered`.
    if !Self::use_special_always_ordered_svd3() && !Self::use_special_always_ordered_svd2() {
        svd.sort_by_singular_values();
    }

    svd
}

Expand All @@ -681,7 +691,11 @@ where
max_niter: usize,
) -> Option<Self> {
Self::try_new_unordered(matrix, compute_u, compute_v, eps, max_niter).map(|mut svd| {
svd.sort_by_singular_values();
if !Self::use_special_always_ordered_svd3() && !Self::use_special_always_ordered_svd2()
{
svd.sort_by_singular_values();
}

svd
})
}
Expand Down
9 changes: 8 additions & 1 deletion src/linalg/svd2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ use crate::{Matrix2, RealField, Vector2, SVD, U2};

// Implementation of the 2D SVD from https://ieeexplore.ieee.org/document/486688
// See also https://scicomp.stackexchange.com/questions/8899/robust-algorithm-for-2-times-2-svd
pub fn svd2<T: RealField>(m: &Matrix2<T>, compute_u: bool, compute_v: bool) -> SVD<T, U2, U2> {
pub fn svd_ordered2<T: RealField>(
m: &Matrix2<T>,
compute_u: bool,
compute_v: bool,
) -> SVD<T, U2, U2> {
let half: T = crate::convert(0.5);
let one: T = crate::convert(1.0);

Expand All @@ -12,6 +16,9 @@ pub fn svd2<T: RealField>(m: &Matrix2<T>, compute_u: bool, compute_v: bool) -> S
let h = (m.m21.clone() - m.m12.clone()) * half.clone();
let q = (e.clone() * e.clone() + h.clone() * h.clone()).sqrt();
let r = (f.clone() * f.clone() + g.clone() * g.clone()).sqrt();

// Note that the singular values are always sorted: sx >= sy holds
// since q >= 0 and r >= 0.
let sx = q.clone() + r.clone();
let sy = q - r;
let sy_sign = if sy < T::zero() { -one.clone() } else { one };
Expand Down
40 changes: 35 additions & 5 deletions src/linalg/svd3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,53 @@ use simba::scalar::RealField;

// For the 3x3 case, on the GPU, it is much more efficient to compute the SVD
// using an eigendecomposition followed by a QR decomposition.
pub fn svd3<T: RealField>(
//
// This is based on the paper "Computing the Singular Value Decomposition of 3 x 3 matrices with
// minimal branching and elementary floating point operations" from McAdams, et al.
pub fn svd_ordered3<T: RealField>(
m: &Matrix3<T>,
compute_u: bool,
compute_v: bool,
eps: T,
niter: usize,
) -> Option<SVD<T, U3, U3>> {
let s = m.tr_mul(&m);
let v = s.try_symmetric_eigen(eps, niter)?.eigenvectors;
let b = m * &v;
let mut v = s.try_symmetric_eigen(eps, niter)?.eigenvectors;
let mut b = m * &v;

// Sort singular values. This is a necessary step to ensure that
// the QR decompositions R matrix ends up diagonal.
let mut rho0 = b.column(0).norm_squared();
let mut rho1 = b.column(1).norm_squared();
let mut rho2 = b.column(2).norm_squared();

if rho0 < rho1 {
b.swap_columns(0, 1);
b.column_mut(1).neg_mut();
v.swap_columns(0, 1);
v.column_mut(1).neg_mut();
std::mem::swap(&mut rho0, &mut rho1);
}
if rho0 < rho2 {
b.swap_columns(0, 2);
b.column_mut(2).neg_mut();
v.swap_columns(0, 2);
v.column_mut(2).neg_mut();
std::mem::swap(&mut rho0, &mut rho2);
}
if rho1 < rho2 {
b.swap_columns(1, 2);
b.column_mut(2).neg_mut();
v.swap_columns(1, 2);
v.column_mut(2).neg_mut();
std::mem::swap(&mut rho0, &mut rho2);
}

let qr = b.qr();
let singular_values = qr.diag_internal().map(|e| e.abs());

Some(SVD {
u: if compute_u { Some(qr.q()) } else { None },
singular_values,
singular_values: qr.diag_internal().map(|e| e.abs()),
v_t: if compute_v { Some(v.transpose()) } else { None },
})
}
7 changes: 7 additions & 0 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,10 @@ mod proptest;
//#[cfg(all(feature = "debug", feature = "compare", feature = "rand"))]
//#[cfg(feature = "sparse")]
//mod sparse;

mod utils {
    /// Returns `true` when every element of `slice` is greater than or equal
    /// to the element that follows it.
    ///
    /// Slices with fewer than two elements have no adjacent pair to compare
    /// and therefore yield `true`.
    pub fn is_sorted_descending<T: PartialOrd>(slice: &[T]) -> bool {
        let successors = slice.iter().skip(1);
        slice
            .iter()
            .zip(successors)
            .all(|(current, next)| current >= next)
    }
}
36 changes: 36 additions & 0 deletions tests/linalg/svd.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::utils::is_sorted_descending;
use na::{DMatrix, Matrix6};

#[cfg(feature = "proptest-support")]
Expand All @@ -14,6 +15,7 @@ mod proptest_tests {
use crate::core::helper::{RandScalar, RandComplex};
use crate::proptest::*;
use proptest::{prop_assert, proptest};
use crate::utils::is_sorted_descending;

proptest! {
#[test]
Expand All @@ -26,6 +28,7 @@ mod proptest_tests {
prop_assert!(s.iter().all(|e| *e >= 0.0));
prop_assert!(relative_eq!(&u * ds * &v_t, recomp_m, epsilon = 1.0e-5));
prop_assert!(relative_eq!(m, recomp_m, epsilon = 1.0e-5));
prop_assert!(is_sorted_descending(s.as_slice()));
}

#[test]
Expand All @@ -38,6 +41,7 @@ mod proptest_tests {
prop_assert!(relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5));
prop_assert!(u.is_orthogonal(1.0e-5));
prop_assert!(v_t.is_orthogonal(1.0e-5));
prop_assert!(is_sorted_descending(s.as_slice()));
}

#[test]
Expand All @@ -50,6 +54,7 @@ mod proptest_tests {
prop_assert!(relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5));
prop_assert!(u.is_orthogonal(1.0e-5));
prop_assert!(v_t.is_orthogonal(1.0e-5));
prop_assert!(is_sorted_descending(s.as_slice()));
}

#[test]
Expand All @@ -61,6 +66,7 @@ mod proptest_tests {

prop_assert!(s.iter().all(|e| *e >= 0.0));
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
prop_assert!(is_sorted_descending(s.as_slice()));
}

#[test]
Expand All @@ -71,6 +77,7 @@ mod proptest_tests {

prop_assert!(s.iter().all(|e| *e >= 0.0));
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
prop_assert!(is_sorted_descending(s.as_slice()));
}

#[test]
Expand All @@ -83,6 +90,7 @@ mod proptest_tests {
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
prop_assert!(u.is_orthogonal(1.0e-5));
prop_assert!(v_t.is_orthogonal(1.0e-5));
prop_assert!(is_sorted_descending(s.as_slice()));
}

#[test]
Expand All @@ -95,6 +103,7 @@ mod proptest_tests {
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
prop_assert!(u.is_orthogonal(1.0e-5));
prop_assert!(v_t.is_orthogonal(1.0e-5));
prop_assert!(is_sorted_descending(s.as_slice()));
}

#[test]
Expand All @@ -107,6 +116,7 @@ mod proptest_tests {
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
prop_assert!(u.is_orthogonal(1.0e-5));
prop_assert!(v_t.is_orthogonal(1.0e-5));
prop_assert!(is_sorted_descending(s.as_slice()));
}

#[test]
Expand Down Expand Up @@ -187,6 +197,7 @@ fn svd_singular() {
let ds = DMatrix::from_diagonal(&s);

assert!(s.iter().all(|e| *e >= 0.0));
assert!(is_sorted_descending(s.as_slice()));
assert!(u.is_orthogonal(1.0e-5));
assert!(v_t.is_orthogonal(1.0e-5));
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
Expand Down Expand Up @@ -229,6 +240,7 @@ fn svd_singular_vertical() {
let ds = DMatrix::from_diagonal(&s);

assert!(s.iter().all(|e| *e >= 0.0));
assert!(is_sorted_descending(s.as_slice()));
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
}

Expand Down Expand Up @@ -267,6 +279,7 @@ fn svd_singular_horizontal() {
let ds = DMatrix::from_diagonal(&s);

assert!(s.iter().all(|e| *e >= 0.0));
assert!(is_sorted_descending(s.as_slice()));
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
}

Expand Down Expand Up @@ -350,6 +363,29 @@ fn svd_fail() {
assert_relative_eq!(m, recomp, epsilon = 1.0e-5);
}

#[test]
#[rustfmt::skip]
fn svd3_fail() {
    // Regression input for the dedicated 3x3 SVD code path, which used to
    // fail on this matrix. It comes from an actual application that used the
    // SVD while minimizing a quadratic error function.
    let m = nalgebra::matrix![
        0.0, 1.0,                0.0;
        0.0, 1.7320508075688772, 0.0;
        0.0, 0.0,                0.0
    ];

    // The unordered decomposition must recompose into the input ...
    let recomp_unordered = m.svd_unordered(true, true).recompose().unwrap();
    assert_relative_eq!(m, recomp_unordered, epsilon = 1.0e-5);

    // ... and so must the ordered one.
    let recomp_ordered = m.svd(true, true).recompose().unwrap();
    assert_relative_eq!(m, recomp_ordered, epsilon = 1.0e-5);
}

#[test]
fn svd_err() {
let m = DMatrix::from_element(10, 10, 0.0);
Expand Down

0 comments on commit 507ead2

Please sign in to comment.