Skip to content

Commit

Permalink
Merge pull request #1249 from LazaroHurtado/windows_stride_feature
Browse files Browse the repository at this point in the history
Added stride support to `Windows`
  • Loading branch information
Nil Goyette committed Jun 11, 2023
2 parents 7342ca8 + 1b00771 commit c5bb8b6
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 27 deletions.
66 changes: 44 additions & 22 deletions src/impl_methods.rs
Expand Up @@ -1418,43 +1418,65 @@ where
/// The windows are all distinct overlapping views of size `window_size`
/// that fit into the array's shape.
///
/// This produces no elements if the window size is larger than the actual array size along any
/// axis.
/// This is essentially equivalent to [`.windows_with_stride()`] with unit stride.
pub fn windows<E>(&self, window_size: E) -> Windows<'_, A, D>
where
E: IntoDimension<Dim = D>,
S: Data,
{
Windows::new(self.view(), window_size)
}

/// Return a window producer and iterable.
///
/// The windows are all distinct views of size `window_size`
/// that fit into the array's shape.
///
/// The stride is ordered by the outermost axis.<br>
/// Hence, a (x₀, x₁, ..., xₙ) stride will be applied to
/// (A₀, A₁, ..., Aₙ) where Aₓ stands for `Axis(x)`.
///
/// This produces all windows that fit within the array for the given stride,
/// assuming the window size is not larger than the array size.
///
/// The produced element is an `ArrayView<A, D>` with exactly the dimension
/// `window_size`.
///
/// Note that passing a stride of only ones is similar to
/// calling [`ArrayBase::windows()`].
///
/// **Panics** if any dimension of `window_size` is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `window_size` does not match the
/// **Panics** if any dimension of `window_size` or `stride` is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `window_size` or `stride` does not match the
/// number of array axes.)
///
/// This is an illustration of the 2×2 windows in a 3×4 array:
/// This is the same illustration found in [`ArrayBase::windows()`],
/// 2×2 windows in a 3×4 array, but now with a (1, 2) stride:
///
/// ```text
/// ──▶ Axis(1)
///
/// │ ┏━━━━━┳━━━━━┱─────┬─────┐ ┌─────┲━━━━━┳━━━━━┱─────┐ ┌─────┬─────┲━━━━━┳━━━━━┓
/// ▼ ┃ a₀₀ ┃ a₀₁ ┃ │ │ │ ┃ a₀₁ ┃ a₀₂ ┃ │ │ │ ┃ a₀₂ ┃ a₀₃ ┃
/// Axis(0) ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────╊━━━━━╋━━━━━╉─────┤ ├─────┼─────╊━━━━━╋━━━━━┫
/// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ ┃ a₁₁ ┃ a₁₂ ┃ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃
/// ┡━━━━━╇━━━━━╃─────┼─────┤ ├─────╄━━━━━╇━━━━━╃─────┤ ├─────┼─────╄━━━━━╇━━━━━┩
/// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │
/// └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘
///
/// ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐
/// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │
/// ┢━━━━━╈━━━━━╅─────┼─────┤ ├─────╆━━━━━╈━━━━━╅─────┤ ├─────┼─────╆━━━━━╈━━━━━┪
/// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ ┃ a₁₁ ┃ a₁₂ ┃ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃
/// ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────╊━━━━━╋━━━━━╉─────┤ ├─────┼─────╊━━━━━╋━━━━━┫
/// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ ┃ a₂₁ ┃ a₂₂ ┃ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃
/// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┺━━━━━┻━━━━━┹─────┘ └─────┴─────┺━━━━━┻━━━━━┛
/// │ ┏━━━━━┳━━━━━┱─────┬─────┐ ┌─────┬─────┲━━━━━┳━━━━━┓
/// ▼ ┃ a₀₀ ┃ a₀₁ ┃ │ │ │ │ ┃ a₀₂ ┃ a₀₃ ┃
/// Axis(0) ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫
/// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃
/// ┡━━━━━╇━━━━━╃─────┼─────┤ ├─────┼─────╄━━━━━╇━━━━━┩
/// │ │ │ │ │ │ │ │ │ │
/// └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘
///
/// ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐
/// │ │ │ │ │ │ │ │ │ │
/// ┢━━━━━╈━━━━━╅─────┼─────┤ ├─────┼─────╆━━━━━╈━━━━━┪
/// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃
/// ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫
/// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃
/// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┴─────┺━━━━━┻━━━━━┛
/// ```
pub fn windows<E>(&self, window_size: E) -> Windows<'_, A, D>
pub fn windows_with_stride<E>(&self, window_size: E, stride: E) -> Windows<'_, A, D>
where
E: IntoDimension<Dim = D>,
S: Data,
{
Windows::new(self.view(), window_size)
Windows::new_with_stride(self.view(), window_size, stride)
}

/// Returns a producer which traverses over all windows of a given length along an axis.
Expand Down
46 changes: 42 additions & 4 deletions src/iterators/windows.rs
Expand Up @@ -20,6 +20,20 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {
E: IntoDimension<Dim = D>,
{
let window = window_size.into_dimension();
let ndim = window.ndim();

let mut unit_stride = D::zeros(ndim);
unit_stride.slice_mut().fill(1);

Windows::new_with_stride(a, window, unit_stride)
}

pub(crate) fn new_with_stride<E>(a: ArrayView<'a, A, D>, window_size: E, strides: E) -> Self
where
E: IntoDimension<Dim = D>,
{
let window = window_size.into_dimension();
let strides_d = strides.into_dimension();
ndassert!(
a.ndim() == window.ndim(),
concat!(
Expand All @@ -30,18 +44,42 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {
a.ndim(),
a.shape()
);
ndassert!(
a.ndim() == strides_d.ndim(),
concat!(
"Stride dimension {} does not match array dimension {} ",
"(with array of shape {:?})"
),
strides_d.ndim(),
a.ndim(),
a.shape()
);
let mut size = a.dim;
for (sz, &ws) in size.slice_mut().iter_mut().zip(window.slice()) {
for ((sz, &ws), &stride) in size
.slice_mut()
.iter_mut()
.zip(window.slice())
.zip(strides_d.slice())
{
assert_ne!(ws, 0, "window-size must not be zero!");
assert_ne!(stride, 0, "stride cannot have a dimension as zero!");
// cannot use std::cmp::max(0, ..) since arithmetic underflow panics
*sz = if *sz < ws { 0 } else { *sz - ws + 1 };
*sz = if *sz < ws {
0
} else {
((*sz - (ws - 1) - 1) / stride) + 1
};
}

let window_strides = a.strides.clone();

let mut array_strides = a.strides.clone();
for (arr_stride, ix_stride) in array_strides.slice_mut().iter_mut().zip(strides_d.slice()) {
*arr_stride *= ix_stride;
}

unsafe {
Windows {
base: ArrayView::new(a.ptr, size, a.strides),
base: ArrayView::new(a.ptr, size, array_strides),
window,
strides: window_strides,
}
Expand Down
72 changes: 71 additions & 1 deletion tests/windows.rs
Expand Up @@ -30,7 +30,7 @@ fn windows_iterator_zero_size() {
a.windows(Dim((0, 0, 0)));
}

/// Test that verifites that no windows are yielded on oversized window sizes.
/// Test that verifies that no windows are yielded on oversized window sizes.
#[test]
fn windows_iterator_oversized() {
let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap();
Expand Down Expand Up @@ -95,6 +95,76 @@ fn windows_iterator_3d() {
);
}

/// Test that verifies the `Windows` iterator panics when stride has an axis equal to zero.
#[test]
#[should_panic]
fn windows_iterator_stride_axis_zero() {
let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap();
a.windows_with_stride((2, 2, 2), (0, 2, 2));
}

/// Test that verifies that only first window is yielded when stride is oversized on every axis.
#[test]
fn windows_iterator_only_one_valid_window_for_oversized_stride() {
let a = Array::from_iter(10..135).into_shape((5, 5, 5)).unwrap();
let mut iter = a.windows_with_stride((2, 2, 2), (8, 8, 8)).into_iter(); // (4,3,2) doesn't fit into (3,3,3) => oversized!
itertools::assert_equal(
iter.next(),
Some(arr3(&[[[10, 11], [15, 16]], [[35, 36], [40, 41]]])),
);
}

/// Simple test for iterating 1d-arrays via `Windows` with stride.
#[test]
fn windows_iterator_1d_with_stride() {
let a = Array::from_iter(10..20).into_shape(10).unwrap();
itertools::assert_equal(
a.windows_with_stride(4, 2),
vec![
arr1(&[10, 11, 12, 13]),
arr1(&[12, 13, 14, 15]),
arr1(&[14, 15, 16, 17]),
arr1(&[16, 17, 18, 19]),
],
);
}

/// Simple test for iterating 2d-arrays via `Windows` with stride.
#[test]
fn windows_iterator_2d_with_stride() {
let a = Array::from_iter(10..30).into_shape((5, 4)).unwrap();
itertools::assert_equal(
a.windows_with_stride((3, 2), (2, 1)),
vec![
arr2(&[[10, 11], [14, 15], [18, 19]]),
arr2(&[[11, 12], [15, 16], [19, 20]]),
arr2(&[[12, 13], [16, 17], [20, 21]]),
arr2(&[[18, 19], [22, 23], [26, 27]]),
arr2(&[[19, 20], [23, 24], [27, 28]]),
arr2(&[[20, 21], [24, 25], [28, 29]]),
],
);
}

/// Simple test for iterating 3d-arrays via `Windows` with stride.
#[test]
fn windows_iterator_3d_with_stride() {
let a = Array::from_iter(10..74).into_shape((4, 4, 4)).unwrap();
itertools::assert_equal(
a.windows_with_stride((2, 2, 2), (2, 2, 2)),
vec![
arr3(&[[[10, 11], [14, 15]], [[26, 27], [30, 31]]]),
arr3(&[[[12, 13], [16, 17]], [[28, 29], [32, 33]]]),
arr3(&[[[18, 19], [22, 23]], [[34, 35], [38, 39]]]),
arr3(&[[[20, 21], [24, 25]], [[36, 37], [40, 41]]]),
arr3(&[[[42, 43], [46, 47]], [[58, 59], [62, 63]]]),
arr3(&[[[44, 45], [48, 49]], [[60, 61], [64, 65]]]),
arr3(&[[[50, 51], [54, 55]], [[66, 67], [70, 71]]]),
arr3(&[[[52, 53], [56, 57]], [[68, 69], [72, 73]]]),
],
);
}

#[test]
fn test_window_zip() {
let a = Array::from_iter(0..64).into_shape((4, 4, 4)).unwrap();
Expand Down

0 comments on commit c5bb8b6

Please sign in to comment.