Skip to content

Commit

Permalink
Guarantee that as_slice_memory_order_mut preserves strides
Browse files Browse the repository at this point in the history
This fixes bugs in `.map_mut()` and `.zip_mut_with_same_shape()`.
Before this commit, strides obtained before calling
`.as_slice_memory_order_mut()` could not be used to correctly
interpret the data in the returned slice. Now, the strides are
preserved, so the implementations of `.map_mut()` and
`.zip_mut_with_same_shape()` work correctly. This also makes it much
easier for users of the crate to use `.as_slice_memory_order_mut()`
correctly in generic code.

Fixes #1018.
  • Loading branch information
jturner314 committed May 31, 2021
1 parent 15b0808 commit 37645bd
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 12 deletions.
18 changes: 8 additions & 10 deletions src/data_traits.rs
Expand Up @@ -50,8 +50,11 @@ pub unsafe trait RawData: Sized {
pub unsafe trait RawDataMut: RawData {
/// If possible, ensures that the array has unique access to its data.
///
/// If `Self` provides safe mutable access to array elements, then it
/// **must** panic or ensure that the data is unique.
/// The implementer must ensure that if the input is contiguous, then the
/// output has the same strides as input.
///
/// Additionally, if `Self` provides safe mutable access to array elements,
/// then this method **must** panic or ensure that the data is unique.
#[doc(hidden)]
fn try_ensure_unique<D>(_: &mut ArrayBase<Self, D>)
where
Expand Down Expand Up @@ -230,14 +233,9 @@ where
return;
}
if self_.dim.size() <= self_.data.0.len() / 2 {
// Create a new vec if the current view is less than half of
// backing data.
unsafe {
*self_ = ArrayBase::from_shape_vec_unchecked(
self_.dim.clone(),
self_.iter().cloned().collect(),
);
}
// Clone only the visible elements if the current view is less than
// half of backing data.
*self_ = self_.to_owned().into_shared();
return;
}
let rcvec = &mut self_.data.0;
Expand Down
4 changes: 4 additions & 0 deletions src/impl_methods.rs
Expand Up @@ -1536,6 +1536,10 @@ where

/// Return the array’s data as a slice if it is contiguous,
/// return `None` otherwise.
///
/// In the contiguous case, in order to return a unique reference, this
/// method unshares the data if necessary, but it preserves the existing
/// strides.
pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]>
where
S: DataMut,
Expand Down
40 changes: 38 additions & 2 deletions tests/array.rs
Expand Up @@ -990,8 +990,8 @@ fn map1() {
}

#[test]
fn as_slice_memory_order() {
// test that mutation breaks sharing
fn as_slice_memory_order_mut_arcarray() {
// Test that mutation breaks sharing for `ArcArray`.
let a = rcarr2(&[[1., 2.], [3., 4.0f32]]);
let mut b = a.clone();
for elt in b.as_slice_memory_order_mut().unwrap() {
Expand All @@ -1000,6 +1000,38 @@ fn as_slice_memory_order() {
assert!(a != b, "{:?} != {:?}", a, b);
}

#[test]
fn as_slice_memory_order_mut_cowarray() {
// Test that mutation breaks sharing for `CowArray`.
let a = arr2(&[[1., 2.], [3., 4.0f32]]);
let mut b = CowArray::from(a.view());
for elt in b.as_slice_memory_order_mut().unwrap() {
*elt = 0.;
}
assert!(a != b, "{:?} != {:?}", a, b);
}

#[test]
fn as_slice_memory_order_mut_contiguous_arcarray() {
// Test that unsharing preserves the strides in the contiguous case for `ArcArray`.
let a = rcarr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes();
let mut b = a.clone().slice_move(s![.., ..2]);
assert_eq!(b.strides(), &[1, 2]);
b.as_slice_memory_order_mut().unwrap();
assert_eq!(b.strides(), &[1, 2]);
}

#[test]
fn as_slice_memory_order_mut_contiguous_cowarray() {
// Test that unsharing preserves the strides in the contiguous case for `CowArray`.
let a = arr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes();
let mut b = CowArray::from(a.slice(s![.., ..2]));
assert!(b.is_view());
assert_eq!(b.strides(), &[1, 2]);
b.as_slice_memory_order_mut().unwrap();
assert_eq!(b.strides(), &[1, 2]);
}

#[test]
fn array0_into_scalar() {
// With this kind of setup, the `Array`'s pointer is not the same as the
Expand Down Expand Up @@ -1809,6 +1841,10 @@ fn map_mut_with_unsharing() {
// `.map_mut()` unshares the data. Earlier versions of `ndarray` failed
// this assertion. See #1018.
assert_eq!(b.map_mut(|&mut x| x + 10), array![[10, 11], [15, 16]]);

// The strides should be preserved.
assert_eq!(b.shape(), &[2, 2]);
assert_eq!(b.strides(), &[1, 2]);
}

#[test]
Expand Down

0 comments on commit 37645bd

Please sign in to comment.