Skip to content

Commit

Permalink
As standard layout method (#616)
Browse files Browse the repository at this point in the history
This adds an `.as_standard_layout()` method which returns a standard-layout array containing the data, cloning if necessary.
  • Loading branch information
Andrew authored and jturner314 committed Jun 19, 2019
1 parent 5b5d898 commit 5c2da21
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/impl_methods.rs
Expand Up @@ -1218,6 +1218,48 @@ where
D::is_contiguous(&self.dim, &self.strides)
}

/// Return a standard-layout array containing the data, cloning if
/// necessary.
///
/// If `self` is in standard layout, a COW view of the data is returned
/// without cloning. Otherwise, the data is cloned, and the returned array
/// owns the cloned data.
///
/// ```
/// use ndarray::Array2;
///
/// let standard = Array2::<f64>::zeros((3, 4));
/// assert!(standard.is_standard_layout());
/// let cow_view = standard.as_standard_layout();
/// assert!(cow_view.is_view());
/// assert!(cow_view.is_standard_layout());
///
/// let fortran = standard.reversed_axes();
/// assert!(!fortran.is_standard_layout());
/// let cow_owned = fortran.as_standard_layout();
/// assert!(cow_owned.is_owned());
/// assert!(cow_owned.is_standard_layout());
/// ```
pub fn as_standard_layout(&self) -> CowArray<'_, A, D>
where
S: Data<Elem = A>,
A: Clone,
{
if self.is_standard_layout() {
CowArray::from(self.view())
} else {
let v: Vec<A> = self.iter().cloned().collect();
let dim = self.dim.clone();
assert_eq!(v.len(), dim.size());
let owned_array: Array<A, D> = unsafe {
// Safe because the shape and element type are from the existing array
// and the strides are the default strides.
Array::from_shape_vec_unchecked(dim, v)
};
CowArray::from(owned_array)
}
}

/// Return a pointer to the first element in the array.
///
/// Raw access to array elements needs to follow the strided indexing
Expand Down
68 changes: 68 additions & 0 deletions tests/array.rs
Expand Up @@ -1970,6 +1970,74 @@ fn array_macros() {
assert_eq!(empty2, array![[]]);
}

#[cfg(test)]
mod as_standard_layout_tests {
use super::*;
use ndarray::Data;
use std::fmt::Debug;

fn test_as_standard_layout_for<S, D>(orig: ArrayBase<S, D>)
where
S: Data,
S::Elem: Clone + Debug + PartialEq,
D: Dimension,
{
let orig_is_standard = orig.is_standard_layout();
let out = orig.as_standard_layout();
assert!(out.is_standard_layout());
assert_eq!(out, orig);
assert_eq!(orig_is_standard, out.is_view());
}

#[test]
fn test_f_layout() {
let shape = (2, 2).f();
let arr = Array::<i32, Ix2>::from_shape_vec(shape, vec![1, 2, 3, 4]).unwrap();
assert!(!arr.is_standard_layout());
test_as_standard_layout_for(arr);
}

#[test]
fn test_c_layout() {
let arr = Array::<i32, Ix2>::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap();
assert!(arr.is_standard_layout());
test_as_standard_layout_for(arr);
}

#[test]
fn test_f_layout_view() {
let shape = (2, 2).f();
let arr = Array::<i32, Ix2>::from_shape_vec(shape, vec![1, 2, 3, 4]).unwrap();
let arr_view = arr.view();
assert!(!arr_view.is_standard_layout());
test_as_standard_layout_for(arr);
}

#[test]
fn test_c_layout_view() {
let arr = Array::<i32, Ix2>::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap();
let arr_view = arr.view();
assert!(arr_view.is_standard_layout());
test_as_standard_layout_for(arr_view);
}

#[test]
fn test_zero_dimensional_array() {
let arr_view = ArrayView1::<i32>::from(&[]);
assert!(arr_view.is_standard_layout());
test_as_standard_layout_for(arr_view);
}

#[test]
fn test_custom_layout() {
let shape = (1, 2, 3, 2).strides((12, 1, 2, 6));
let arr_data: Vec<i32> = (0..12).collect();
let arr = Array::<i32, Ix4>::from_shape_vec(shape, arr_data).unwrap();
assert!(!arr.is_standard_layout());
test_as_standard_layout_for(arr);
}
}

#[cfg(test)]
mod array_cow_tests {
use super::*;
Expand Down

0 comments on commit 5c2da21

Please sign in to comment.