From e377e825e114abe2ba6f70f1bb5e4864af86162b Mon Sep 17 00:00:00 2001 From: bluss Date: Sat, 27 Mar 2021 14:05:13 +0100 Subject: [PATCH] FIX: Fix .windows() producer for negative stride arrays Avoid using ArrayView::from_shape_ptr (public constructor) because it does not allow negative stride arrays. Use the internal constructor ArrayView::new, which is easier. --- src/iterators/windows.rs | 2 +- tests/windows.rs | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index 2d9095303..4538f7abb 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -41,7 +41,7 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> { unsafe { Windows { - base: ArrayView::from_shape_ptr(size.strides(a.strides), a.ptr.as_ptr()), + base: ArrayView::new(a.ptr, size, a.strides), window, strides: window_strides, } diff --git a/tests/windows.rs b/tests/windows.rs index 9fb8cc8ae..87fdf6ecb 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -116,3 +116,39 @@ fn test_window_zip() { } } } + +#[test] +fn test_window_neg_stride() { + let array = Array::from_iter(1..10).into_shape((3, 3)).unwrap(); + + // window neg/pos stride combinations + + // Make a 2 x 2 array of the windows of the 3 x 3 array + // and compute test answers from here + let mut answer = Array::from_iter(array.windows((2, 2)).into_iter().map(|a| a.to_owned())) + .into_shape((2, 2)).unwrap(); + + answer.invert_axis(Axis(1)); + answer.map_inplace(|a| a.invert_axis(Axis(1))); + + itertools::assert_equal( + array.slice(s![.., ..;-1]).windows((2, 2)), + answer.iter().map(|a| a.view()) + ); + + answer.invert_axis(Axis(0)); + answer.map_inplace(|a| a.invert_axis(Axis(0))); + + itertools::assert_equal( + array.slice(s![..;-1, ..;-1]).windows((2, 2)), + answer.iter().map(|a| a.view()) + ); + + answer.invert_axis(Axis(1)); + answer.map_inplace(|a| a.invert_axis(Axis(1))); + + itertools::assert_equal( + array.slice(s![..;-1, ..]).windows((2, 2)), + answer.iter().map(|a| a.view()) + ); +}