Skip to content

Commit

Permalink
Implement approx traits for ArrayBase (#581)
Browse files Browse the repository at this point in the history
* Implement approx traits for ArrayBase

* Feature-gate approx trait implementations

* Use Zip::all where possible

* Mark `all_close` as deprecated

* Fix implementation.

* Fix issues with conditional execution based on activated features

* Remove all_close from all doc tests

* Update guide for NumPy users

* Replace all_close with abs_diff_eq in tests (currently failing)

* Fix typo, pin 0.3.2 to get latest changes

* Allow comparison between arrays with different ownership properties

* Fix assertions

* Fix test

* Move tests from all_close to approx

* Move tests from all_close to approx

* Impl approx traits for differing element types

* Fix unused import warning

* Remove duplicate type parameter

* Fix link in docs

* Fix tests

* Fix formatting

* Remove unnecessary &
  • Loading branch information
jturner314 authored and LukeMathWalker committed May 5, 2019
1 parent 7733b7c commit 97def32
Show file tree
Hide file tree
Showing 15 changed files with 274 additions and 68 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Expand Up @@ -35,6 +35,8 @@ itertools = { version = "0.8.0", default-features = false }

rayon = { version = "1.0.3", optional = true }

approx = { version = "0.3.2", optional = true }

# Use via the `blas` crate feature!
cblas-sys = { version = "0.1.4", optional = true, default-features = false }
blas-src = { version = "0.2.0", optional = true, default-features = false }
Expand All @@ -47,8 +49,8 @@ serde = { version = "1.0", optional = true }
defmac = "0.2"
quickcheck = { version = "0.8", default-features = false }
rawpointer = "0.1"
approx = "0.3.2"
itertools = { version = "0.8.0", default-features = false, features = ["use_std"] }
approx = "0.3"

[features]
# Enable blas usage
Expand All @@ -63,7 +65,7 @@ test-blas-openblas-sys = ["blas"]
test = ["test-blas-openblas-sys"]

# This feature is used for docs
docs = ["serde-1", "rayon"]
docs = ["approx", "serde-1", "rayon"]

[profile.release]
[profile.bench]
Expand Down
139 changes: 139 additions & 0 deletions src/array_approx.rs
@@ -0,0 +1,139 @@
use crate::imp_prelude::*;
use crate::Zip;
use approx::{AbsDiffEq, RelativeEq, UlpsEq};

/// **Requires crate feature `"approx"`**
impl<A, B, S, S2, D> AbsDiffEq<ArrayBase<S2, D>> for ArrayBase<S, D>
where
A: AbsDiffEq<B>,
A::Epsilon: Clone,
S: Data<Elem = A>,
S2: Data<Elem = B>,
D: Dimension,
{
type Epsilon = A::Epsilon;

fn default_epsilon() -> A::Epsilon {
A::default_epsilon()
}

fn abs_diff_eq(&self, other: &ArrayBase<S2, D>, epsilon: A::Epsilon) -> bool {
if self.shape() != other.shape() {
return false;
}
Zip::from(self)
.and(other)
.all(|a, b| A::abs_diff_eq(a, b, epsilon.clone()))
}
}

/// **Requires crate feature `"approx"`**
impl<A, B, S, S2, D> RelativeEq<ArrayBase<S2, D>> for ArrayBase<S, D>
where
A: RelativeEq<B>,
A::Epsilon: Clone,
S: Data<Elem = A>,
S2: Data<Elem = B>,
D: Dimension,
{
fn default_max_relative() -> A::Epsilon {
A::default_max_relative()
}

fn relative_eq(
&self,
other: &ArrayBase<S2, D>,
epsilon: A::Epsilon,
max_relative: A::Epsilon,
) -> bool {
if self.shape() != other.shape() {
return false;
}
Zip::from(self)
.and(other)
.all(|a, b| A::relative_eq(a, b, epsilon.clone(), max_relative.clone()))
}
}

/// **Requires crate feature `"approx"`**
impl<A, B, S, S2, D> UlpsEq<ArrayBase<S2, D>> for ArrayBase<S, D>
where
A: UlpsEq<B>,
A::Epsilon: Clone,
S: Data<Elem = A>,
S2: Data<Elem = B>,
D: Dimension,
{
fn default_max_ulps() -> u32 {
A::default_max_ulps()
}

fn ulps_eq(&self, other: &ArrayBase<S2, D>, epsilon: A::Epsilon, max_ulps: u32) -> bool {
if self.shape() != other.shape() {
return false;
}
Zip::from(self)
.and(other)
.all(|a, b| A::ulps_eq(a, b, epsilon.clone(), max_ulps))
}
}

#[cfg(test)]
mod tests {
use crate::prelude::*;
use approx::{
assert_abs_diff_eq, assert_abs_diff_ne, assert_relative_eq, assert_relative_ne,
assert_ulps_eq, assert_ulps_ne,
};

#[test]
fn abs_diff_eq() {
let a: Array2<f32> = array![[0., 2.], [-0.000010001, 100000000.]];
let mut b: Array2<f32> = array![[0., 1.], [-0.000010002, 100000001.]];
assert_abs_diff_ne!(a, b);
b[(0, 1)] = 2.;
assert_abs_diff_eq!(a, b);

// Check epsilon.
assert_abs_diff_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
assert_abs_diff_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);

// Make sure we can compare different shapes without failure.
let c = array![[1., 2.]];
assert_abs_diff_ne!(a, c);
}

#[test]
fn relative_eq() {
let a: Array2<f32> = array![[1., 2.], [-0.000010001, 100000000.]];
let mut b: Array2<f32> = array![[1., 1.], [-0.000010002, 100000001.]];
assert_relative_ne!(a, b);
b[(0, 1)] = 2.;
assert_relative_eq!(a, b);

// Check epsilon.
assert_relative_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
assert_relative_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);

// Make sure we can compare different shapes without failure.
let c = array![[1., 2.]];
assert_relative_ne!(a, c);
}

#[test]
fn ulps_eq() {
let a: Array2<f32> = array![[1., 2.], [-0.000010001, 100000000.]];
let mut b: Array2<f32> = array![[1., 1.], [-0.000010002, 100000001.]];
assert_ulps_ne!(a, b);
b[(0, 1)] = 2.;
assert_ulps_eq!(a, b);

// Check epsilon.
assert_ulps_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32);
assert_ulps_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32);

// Make sure we can compare different shapes without failure.
let c = array![[1., 2.]];
assert_ulps_ne!(a, c);
}
}
4 changes: 2 additions & 2 deletions src/doc/ndarray_for_numpy_users/mod.rs
Expand Up @@ -473,7 +473,7 @@
//!
//! </td><td>
//!
//! [`a.all_close(&b, 1e-8)`][.all_close()]
//! [`a.abs_diff_eq(&b, 1e-8)`][.abs_diff_eq()]
//!
//! </td><td>
//!
Expand Down Expand Up @@ -557,7 +557,7 @@
//! `a[:,4]` | [`a.column(4)`][.column()] or [`a.column_mut(4)`][.column_mut()] | view (or mutable view) of column 4 in a 2-D array
//! `a.shape[0] == a.shape[1]` | [`a.is_square()`][.is_square()] | check if the array is square
//!
//! [.all_close()]: ../../struct.ArrayBase.html#method.all_close
//! [.abs_diff_eq()]: ../../struct.ArrayBase.html#impl-AbsDiffEq<ArrayBase<S2%2C%20D>>
//! [ArcArray]: ../../type.ArcArray.html
//! [arr2()]: ../../fn.arr2.html
//! [array!]: ../../macro.array.html
Expand Down
13 changes: 8 additions & 5 deletions src/geomspace.rs
Expand Up @@ -106,21 +106,24 @@ where
#[cfg(test)]
mod tests {
use super::geomspace;
use crate::{arr1, Array1};

#[test]
#[cfg(feature = "approx")]
fn valid() {
use approx::assert_abs_diff_eq;
use crate::{arr1, Array1};

let array: Array1<_> = geomspace(1e0, 1e3, 4).collect();
assert!(array.all_close(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5));
assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12);

let array: Array1<_> = geomspace(1e3, 1e0, 4).collect();
assert!(array.all_close(&arr1(&[1e3, 1e2, 1e1, 1e0]), 1e-5));
assert_abs_diff_eq!(array, arr1(&[1e3, 1e2, 1e1, 1e0]), epsilon = 1e-12);

let array: Array1<_> = geomspace(-1e3, -1e0, 4).collect();
assert!(array.all_close(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5));
assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12);

let array: Array1<_> = geomspace(-1e0, -1e3, 4).collect();
assert!(array.all_close(&arr1(&[-1e0, -1e1, -1e2, -1e3]), 1e-5));
assert_abs_diff_eq!(array, arr1(&[-1e0, -1e1, -1e2, -1e3]), epsilon = 1e-12);
}

#[test]
Expand Down
14 changes: 10 additions & 4 deletions src/impl_constructors.rs
Expand Up @@ -111,13 +111,16 @@ impl<S, A> ArrayBase<S, Ix1>
/// **Panics** if the length is greater than `isize::MAX`.
///
/// ```rust
/// use approx::assert_abs_diff_eq;
/// use ndarray::{Array, arr1};
///
/// # #[cfg(feature = "approx")] {
/// let array = Array::logspace(10.0, 0.0, 3.0, 4);
/// assert!(array.all_close(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5));
/// assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]));
///
/// let array = Array::logspace(-10.0, 3.0, 0.0, 4);
/// assert!(array.all_close(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5));
/// assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]));
/// # }
/// ```
pub fn logspace(base: A, start: A, end: A, n: usize) -> Self
where
Expand All @@ -136,13 +139,16 @@ impl<S, A> ArrayBase<S, Ix1>
/// **Panics** if `n` is greater than `isize::MAX`.
///
/// ```rust
/// use approx::assert_abs_diff_eq;
/// use ndarray::{Array, arr1};
///
/// # #[cfg(feature = "approx")] {
/// let array = Array::geomspace(1e0, 1e3, 4);
/// assert!(array.all_close(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5));
/// assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12);
///
/// let array = Array::geomspace(-1e3, -1e0, 4);
/// assert!(array.all_close(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5));
/// assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12);
/// # }
/// ```
pub fn geomspace(start: A, end: A, n: usize) -> Self
where
Expand Down
11 changes: 8 additions & 3 deletions src/impl_methods.rs
Expand Up @@ -2039,15 +2039,20 @@ where
/// Elements are visited in arbitrary order.
///
/// ```
/// use approx::assert_abs_diff_eq;
/// use ndarray::arr2;
///
/// # #[cfg(feature = "approx")] {
/// let mut a = arr2(&[[ 0., 1.],
/// [-1., 2.]]);
/// a.mapv_inplace(f32::exp);
/// assert!(
/// a.all_close(&arr2(&[[1.00000, 2.71828],
/// [0.36788, 7.38906]]), 1e-5)
/// assert_abs_diff_eq!(
/// a,
/// arr2(&[[1.00000, 2.71828],
/// [0.36788, 7.38906]]),
/// epsilon = 1e-5,
/// );
/// # }
/// ```
pub fn mapv_inplace<F>(&mut self, mut f: F)
where S: DataMut,
Expand Down
8 changes: 8 additions & 0 deletions src/lib.rs
Expand Up @@ -68,6 +68,9 @@
//! - `rayon`
//! - Optional, compatible with Rust stable
//! - Enables parallel iterators, parallelized methods and [`par_azip!`].
//! - `approx`
//! - Optional, compatible with Rust stable
//! - Enables implementations of traits from the [`approx`] crate.
//! - `blas`
//! - Optional and experimental, compatible with Rust stable
//! - Enable transparent BLAS support for matrix multiplication.
Expand All @@ -90,6 +93,9 @@ extern crate serde;
#[cfg(feature="rayon")]
extern crate rayon;

#[cfg(feature="approx")]
extern crate approx;

#[cfg(feature="blas")]
extern crate cblas_sys;
#[cfg(feature="blas")]
Expand Down Expand Up @@ -146,6 +152,8 @@ mod aliases;
mod arraytraits;
#[cfg(feature = "serde-1")]
mod array_serde;
#[cfg(feature = "approx")]
mod array_approx;
mod arrayformat;
mod data_traits;

Expand Down
13 changes: 8 additions & 5 deletions src/logspace.rs
Expand Up @@ -96,21 +96,24 @@ where
#[cfg(test)]
mod tests {
use super::logspace;
use crate::{arr1, Array1};

#[test]
#[cfg(feature = "approx")]
fn valid() {
use approx::assert_abs_diff_eq;
use crate::{arr1, Array1};

let array: Array1<_> = logspace(10.0, 0.0, 3.0, 4).collect();
assert!(array.all_close(&arr1(&[1e0, 1e1, 1e2, 1e3]), 1e-5));
assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]));

let array: Array1<_> = logspace(10.0, 3.0, 0.0, 4).collect();
assert!(array.all_close(&arr1(&[1e3, 1e2, 1e1, 1e0]), 1e-5));
assert_abs_diff_eq!(array, arr1(&[1e3, 1e2, 1e1, 1e0]));

let array: Array1<_> = logspace(-10.0, 3.0, 0.0, 4).collect();
assert!(array.all_close(&arr1(&[-1e3, -1e2, -1e1, -1e0]), 1e-5));
assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]));

let array: Array1<_> = logspace(-10.0, 0.0, 3.0, 4).collect();
assert!(array.all_close(&arr1(&[-1e0, -1e1, -1e2, -1e3]), 1e-5));
assert_abs_diff_eq!(array, arr1(&[-1e0, -1e1, -1e2, -1e3]));
}

#[test]
Expand Down
1 change: 1 addition & 0 deletions src/numeric/impl_numeric.rs
Expand Up @@ -306,6 +306,7 @@ impl<A, S, D> ArrayBase<S, D>
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting to the same shape isn’t possible.
#[deprecated(note="Use `abs_diff_eq` - it requires the `approx` crate feature", since="0.13")]
pub fn all_close<S2, E>(&self, rhs: &ArrayBase<S2, E>, tol: A) -> bool
where A: Float,
S2: Data<Elem=A>,
Expand Down
5 changes: 4 additions & 1 deletion tests/array-construct.rs
Expand Up @@ -21,14 +21,17 @@ fn test_dimension_zero() {
}

#[test]
#[cfg(feature = "approx")]
fn test_arc_into_owned() {
use approx::assert_abs_diff_ne;

let a = Array2::from_elem((5, 5), 1.).into_shared();
let mut b = a.clone();
b.fill(0.);
let mut c = b.into_owned();
c.fill(2.);
// test that they are unshared
assert!(!a.all_close(&c, 0.01));
assert_abs_diff_ne!(a, c, epsilon = 0.01);
}

#[test]
Expand Down

0 comments on commit 97def32

Please sign in to comment.