Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Matrix::pow and make it work with integer matrices #1055

Merged
merged 2 commits into from Dec 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
90 changes: 39 additions & 51 deletions src/linalg/pow.rs
@@ -1,83 +1,71 @@
//! This module provides the matrix exponential (pow) function to square matrices.

use std::ops::DivAssign;

use crate::{
allocator::Allocator,
storage::{Storage, StorageMut},
DefaultAllocator, DimMin, Matrix, OMatrix,
DefaultAllocator, DimMin, Matrix, OMatrix, Scalar,
};
use num::PrimInt;
use simba::scalar::ComplexField;
use num::{One, Zero};
use simba::scalar::{ClosedAdd, ClosedMul};

impl<T: ComplexField, D, S> Matrix<T, D, D, S>
impl<T, D, S> Matrix<T, D, D, S>
where
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
D: DimMin<D, Output = D>,
S: StorageMut<T, D, D>,
DefaultAllocator: Allocator<T, D, D> + Allocator<T, D>,
{
/// Attempts to raise this matrix to an integral power `e` in-place. If this
/// matrix is non-invertible and `e` is negative, it leaves this matrix
/// untouched and returns `false`. Otherwise, it returns `true` and
/// overwrites this matrix with the result.
pub fn pow_mut<I: PrimInt + DivAssign>(&mut self, mut e: I) -> bool {
let zero = I::zero();

/// Raises this matrix to an integral power `exp` in-place.
pub fn pow_mut(&mut self, mut exp: u32) {
// A matrix raised to the zeroth power is just the identity.
if e == zero {
if exp == 0 {
self.fill_with_identity();
return true;
}

// If e is negative, we compute the inverse matrix, then raise it to the
// power of -e.
if e < zero && !self.try_inverse_mut() {
return false;
}
} else if exp > 1 {
// We use the buffer to hold the result of multiplier^2, thus avoiding
// extra allocations.
let mut x = self.clone_owned();
let mut workspace = self.clone_owned();

let one = I::one();
let two = I::from(2u8).unwrap();
if exp % 2 == 0 {
self.fill_with_identity();
} else {
// Avoid an useless multiplication by the identity
// if the exponent is odd.
exp -= 1;
}

// We use the buffer to hold the result of multiplier ^ 2, thus avoiding
// extra allocations.
let mut multiplier = self.clone_owned();
let mut buf = self.clone_owned();
// Exponentiation by squares.
loop {
if exp % 2 == 1 {
self.mul_to(&x, &mut workspace);
self.copy_from(&workspace);
}

// Exponentiation by squares.
loop {
if e % two == one {
self.mul_to(&multiplier, &mut buf);
self.copy_from(&buf);
}
exp /= 2;

e /= two;
multiplier.mul_to(&multiplier, &mut buf);
multiplier.copy_from(&buf);
if exp == 0 {
break;
}

if e == zero {
return true;
x.mul_to(&x, &mut workspace);
x.copy_from(&workspace);
}
}
}
}

impl<T: ComplexField, D, S: Storage<T, D, D>> Matrix<T, D, D, S>
impl<T, D, S: Storage<T, D, D>> Matrix<T, D, D, S>
where
T: Scalar + Zero + One + ClosedAdd + ClosedMul,
D: DimMin<D, Output = D>,
S: StorageMut<T, D, D>,
DefaultAllocator: Allocator<T, D, D> + Allocator<T, D>,
{
/// Attempts to raise this matrix to an integral power `e`. If this matrix
/// is non-invertible and `e` is negative, it returns `None`. Otherwise, it
/// returns the result as a new matrix. Uses exponentiation by squares.
/// Raise this matrix to an integral power `exp`.
#[must_use]
pub fn pow<I: PrimInt + DivAssign>(&self, e: I) -> Option<OMatrix<T, D, D>> {
let mut clone = self.clone_owned();

if clone.pow_mut(e) {
Some(clone)
} else {
None
}
pub fn pow(&self, exp: u32) -> OMatrix<T, D, D> {
let mut result = self.clone_owned();
result.pow_mut(exp);
result
}
}
1 change: 1 addition & 0 deletions tests/linalg/mod.rs
Expand Up @@ -9,6 +9,7 @@ mod full_piv_lu;
mod hessenberg;
mod inverse;
mod lu;
mod pow;
mod qr;
mod schur;
mod solve;
Expand Down
49 changes: 49 additions & 0 deletions tests/linalg/pow.rs
@@ -0,0 +1,49 @@
#[cfg(feature = "proptest-support")]
mod proptest_tests {
macro_rules! gen_tests(
($module: ident, $scalar: expr, $scalar_type: ty) => {
mod $module {
use na::DMatrix;
#[allow(unused_imports)]
use crate::core::helper::{RandScalar, RandComplex};
use std::cmp;

use crate::proptest::*;
use proptest::{prop_assert, proptest};

proptest! {
#[test]
fn pow(n in PROPTEST_MATRIX_DIM, p in 0u32..=4) {
let n = cmp::max(1, cmp::min(n, 10));
let m = DMatrix::<$scalar_type>::new_random(n, n).map(|e| e.0);
let m_pow = m.pow(p);
let mut expected = m.clone();
expected.fill_with_identity();

for _ in 0..p {
expected = &m * &expected;
}

prop_assert!(relative_eq!(m_pow, expected, epsilon = 1.0e-5))
}

#[test]
fn pow_static_square_4x4(m in matrix4_($scalar), p in 0u32..=4) {
let mut expected = m.clone();
let m_pow = m.pow(p);
expected.fill_with_identity();

for _ in 0..p {
expected = &m * &expected;
}

prop_assert!(relative_eq!(m_pow, expected, epsilon = 1.0e-5))
}
}
}
}
);

gen_tests!(complex, complex_f64(), RandComplex<f64>);
gen_tests!(f64, PROPTEST_F64, RandScalar<f64>);
}