Add .split_at() methods for AxisChunksIter/Mut #691

Merged
merged 4 commits on Sep 4, 2019
132 changes: 83 additions & 49 deletions src/iterators/mod.rs
@@ -825,6 +825,19 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
};
(left, right)
}

/// Does the same thing as `.next()` but also returns the index of the item
/// relative to the start of the axis.
fn next_with_index(&mut self) -> Option<(usize, *mut A)> {
let index = self.index;
self.next().map(|ptr| (index, ptr))
}

/// Does the same thing as `.next_back()` but also returns the index of the
/// item relative to the start of the axis.
fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> {
self.next_back().map(|ptr| (self.end, ptr))
}
}

impl<A, D> Iterator for AxisIterCore<A, D>
@@ -1182,9 +1195,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
/// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information.
pub struct AxisChunksIter<'a, A, D> {
iter: AxisIterCore<A, D>,
n_whole_chunks: usize,
/// Dimension of the last (and possibly uneven) chunk
last_dim: D,
/// Index of the partial chunk (the chunk smaller than the specified chunk
/// size due to the axis length not being evenly divisible). If the axis
/// length is evenly divisible by the chunk size, this index is larger than
/// the maximum valid index.
partial_chunk_index: usize,
Member

Would it be beneficial to rephrase this as an Option, to make it clearer that we might (or might not) have a partial chunk? Something along the lines of:

pub struct AxisChunksIter<'a, A, D> {
   iter: AxisIterCore<A, D>,
   partial_chunk: Option<PartialChunk>,
   life: PhantomData<&'a A>
}

struct PartialChunk {
   partial_chunk_index: usize,
   partial_chunk_dim: D
}

Member Author

I don't think it makes sense to use both the Option variant and the value of partial_chunk_index to represent whether or not there's a partial chunk. (The biggest reason is that I prefer data structures where there's a single source of truth, rather than having to keep multiple things in sync. There might also be a small performance cost to accessing partial_chunk_index through the Option (since accessing it requires checking whether the Option is the Some variant), but we'd need to test to determine if that would really be noticeable.) IMO, putting the fields in an Option would be additional complication over the current approach without much benefit.

It would be reasonable to eliminate partial_chunk_index and just use the Option variant to represent the presence of a partial chunk, like this:

pub struct AxisChunksIter<'a, A, D> {
   iter: AxisIterCore<A, D>,
   partial_chunk: Option<D>,
   life: PhantomData<&'a A>
}

or to always store the shape of the last chunk (regardless of whether or not it's a partial chunk):

pub struct AxisChunksIter<'a, A, D> {
   iter: AxisIterCore<A, D>,
   last_chunk_dim: D,
   life: PhantomData<&'a A>
}

These approaches have two disadvantages since they rely on checking whether the iterator is at its end to handle the partial chunk instead of checking whether the current index is equal to partial_chunk_index:

  1. .split_at() needs to check whether or not the partial chunk is in the left piece and determine partial_chunk or last_chunk_dim of the left piece accordingly. (The partial chunk is in the left piece when index == self.iter.len().)

  2. .next_back() needs to set partial_chunk to None or last_chunk_dim to self.iter.inner_dim each time it's called.

So, I'd rather keep the current approach and add more comments if necessary to make it clear.
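
To make the invariant concrete, here is a small standalone sketch of the arithmetic behind partial_chunk_index (the chunk_layout helper is made up for illustration and is not part of the diff):

// Illustrative sketch: the partial chunk, if any, sits right after the whole
// chunks, so its index equals the number of whole chunks. When the axis length
// divides evenly, that index is past the end and is simply never reached.
fn chunk_layout(axis_len: usize, size: usize) -> (usize, usize, usize) {
    let n_whole_chunks = axis_len / size; // chunks of exactly `size`
    let chunk_remainder = axis_len % size; // length of the partial chunk (0 if none)
    let partial_chunk_index = n_whole_chunks;
    (n_whole_chunks, chunk_remainder, partial_chunk_index)
}

fn main() {
    assert_eq!(chunk_layout(13, 3), (4, 1, 4)); // partial chunk of length 1 at index 4
    assert_eq!(chunk_layout(12, 3), (4, 0, 4)); // no partial chunk; index 4 is never reached
}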

/// Dimension of the partial chunk.
partial_chunk_dim: D,
life: PhantomData<&'a A>,
}

@@ -1193,10 +1210,10 @@ clone_bounds!(
AxisChunksIter['a, A, D] {
@copy {
life,
n_whole_chunks,
partial_chunk_index,
}
iter,
last_dim,
partial_chunk_dim,
}
);

@@ -1233,12 +1250,9 @@ fn chunk_iter_parts<A, D: Dimension>(
let mut inner_dim = v.dim.clone();
inner_dim[axis] = size;

let mut last_dim = v.dim;
last_dim[axis] = if chunk_remainder == 0 {
size
} else {
chunk_remainder
};
let mut partial_chunk_dim = v.dim;
partial_chunk_dim[axis] = chunk_remainder;
let partial_chunk_index = n_whole_chunks;

let iter = AxisIterCore {
index: 0,
@@ -1249,16 +1263,16 @@
ptr: v.ptr,
};

(iter, n_whole_chunks, last_dim)
(iter, partial_chunk_index, partial_chunk_dim)
}

impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> {
pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self {
let (iter, n_whole_chunks, last_dim) = chunk_iter_parts(v, axis, size);
let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size);
AxisChunksIter {
iter,
n_whole_chunks,
last_dim,
partial_chunk_index,
partial_chunk_dim,
life: PhantomData,
}
}
@@ -1270,30 +1284,49 @@ macro_rules! chunk_iter_impl {
where
D: Dimension,
{
fn get_subview(
&self,
iter_item: Option<*mut A>,
is_uneven: bool,
) -> Option<$array<'a, A, D>> {
iter_item.map(|ptr| {
if !is_uneven {
unsafe {
$array::new_(
ptr,
self.iter.inner_dim.clone(),
self.iter.inner_strides.clone(),
)
}
} else {
unsafe {
$array::new_(
ptr,
self.last_dim.clone(),
self.iter.inner_strides.clone(),
)
}
fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> {
if index != self.partial_chunk_index {
unsafe {
$array::new_(
ptr,
self.iter.inner_dim.clone(),
self.iter.inner_strides.clone(),
)
}
} else {
unsafe {
$array::new_(
ptr,
self.partial_chunk_dim.clone(),
self.iter.inner_strides.clone(),
)
}
})
}
}

/// Splits the iterator at index, yielding two disjoint iterators.
///
/// `index` is relative to the current state of the iterator (which is not
/// necessarily the start of the axis).
///
/// **Panics** if `index` is strictly greater than the iterator's remaining
/// length.
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
Self {
iter: left,
partial_chunk_index: self.partial_chunk_index,
partial_chunk_dim: self.partial_chunk_dim.clone(),
life: self.life,
},
Self {
iter: right,
partial_chunk_index: self.partial_chunk_index,
Member

I haven't read the whole code, unfortunately (the parts that aren't visible in the diff). Why doesn't this partial_chunk_index require adjusting? The right part of the iter now starts at index, so I'd expect this to be offset by -index.

Member Author

Here's an example:

use ndarray::prelude::*;

fn main() {
    let a: Array1<i32> = (0..13).collect();
    let mut iter = a.axis_chunks_iter(Axis(0), 3);
    iter.next();  // skip the first element so that we consider a partially-consumed iterator
    println!("before_split = {:#?}", iter);
    let (left, right) = iter.split_at(2);
    println!("left = {:#?}", left);
    println!("right = {:#?}", right);
}

which gives the output

before_split = AxisChunksIter {
    iter: AxisIterCore {
        index: 1,
        end: 5,
        stride: 3,
        inner_dim: [3],
        inner_strides: [1],
        ptr: 0x00005634af728b40,
    },
    partial_chunk_index: 4,
    partial_chunk_dim: [1],
    life: PhantomData,
}
left = AxisChunksIter {
    iter: AxisIterCore {
        index: 1,
        end: 3,
        stride: 3,
        inner_dim: [3],
        inner_strides: [1],
        ptr: 0x00005634af728b40,
    },
    partial_chunk_index: 4,
    partial_chunk_dim: [1],
    life: PhantomData,
}
right = AxisChunksIter {
    iter: AxisIterCore {
        index: 3,
        end: 5,
        stride: 3,
        inner_dim: [3],
        inner_strides: [1],
        ptr: 0x00005634af728b40,
    },
    partial_chunk_index: 4,
    partial_chunk_dim: [1],
    life: PhantomData,
}

We can visualize the situation like this:

               0 1 2 3 4
before split:    ^      |
after split:     ^  |^  |

The ^s represent the indexes and the |s represent the ends of the iterators. (The |s appear just before the corresponding end indices.) There are 4 full chunks (indices 0..=3) and 1 partial chunk (index 4). Note that all indices are relative to the start of the axis, so any given index value represents the same location before and after the split. This is why partial_chunk_index is the same before and after splitting. Before splitting, the index of the partial chunk is 4, and it stays 4 in the split pieces. (The left piece will never actually reach index 4 since its end is 3; that's the desired behavior.)

Member

Thanks. If the only use of index is counting up to the partial_chunk_index, it makes total sense.

Member Author

index is also used in AxisIterCore (which AxisChunksIter wraps) to compute the pointer of each element/chunk and to check for the end of the iterator; see AxisIterCore's implementation of .next(). (.split_at() on AxisIterCore doesn't change the ptr value; ptr always corresponds to the start of the axis. This was part of #669.)
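
For readers without the full source handy, here is a simplified, hypothetical stand-in for AxisIterCore showing how index drives both the pointer computation and the end-of-iteration check (the real struct also carries inner_dim and inner_strides; this is an approximation, not the actual implementation):

// Hypothetical, simplified model of AxisIterCore::next().
struct AxisIterCoreSketch<A> {
    index: usize,
    end: usize,
    stride: isize,
    ptr: *mut A, // always points at the start of the axis
}

impl<A> AxisIterCoreSketch<A> {
    fn next(&mut self) -> Option<*mut A> {
        if self.index >= self.end {
            None
        } else {
            // The element/chunk pointer is recomputed from `index` on every
            // call, which is why `.split_at()` never needs to adjust `ptr`.
            let elem = unsafe { self.ptr.offset(self.index as isize * self.stride) };
            self.index += 1;
            Some(elem)
        }
    }
}

fn main() {
    let mut data = [10i32, 20, 30, 40, 50];
    // A partially consumed iterator, like the example above: index 1, end 5.
    let mut it = AxisIterCoreSketch { index: 1, end: 5, stride: 1, ptr: data.as_mut_ptr() };
    assert_eq!(unsafe { *it.next().unwrap() }, 20);
    assert_eq!(unsafe { *it.next().unwrap() }, 30);
}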

partial_chunk_dim: self.partial_chunk_dim,
life: self.life,
},
)
}
}

@@ -1304,9 +1337,9 @@ macro_rules! chunk_iter_impl {
type Item = $array<'a, A, D>;

fn next(&mut self) -> Option<Self::Item> {
let res = self.iter.next();
let is_uneven = self.iter.index > self.n_whole_chunks;
self.get_subview(res, is_uneven)
self.iter
.next_with_index()
.map(|(index, ptr)| self.get_subview(index, ptr))
}

fn size_hint(&self) -> (usize, Option<usize>) {
@@ -1319,9 +1352,9 @@ macro_rules! chunk_iter_impl {
D: Dimension,
{
fn next_back(&mut self) -> Option<Self::Item> {
let is_uneven = self.iter.end > self.n_whole_chunks;
let res = self.iter.next_back();
self.get_subview(res, is_uneven)
self.iter
.next_back_with_index()
.map(|(index, ptr)| self.get_subview(index, ptr))
}
}

@@ -1342,18 +1375,19 @@ macro_rules! chunk_iter_impl {
/// for more information.
pub struct AxisChunksIterMut<'a, A, D> {
iter: AxisIterCore<A, D>,
n_whole_chunks: usize,
last_dim: D,
partial_chunk_index: usize,
partial_chunk_dim: D,
life: PhantomData<&'a mut A>,
}

impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> {
pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self {
let (iter, len, last_dim) = chunk_iter_parts(v.into_view(), axis, size);
let (iter, partial_chunk_index, partial_chunk_dim) =
chunk_iter_parts(v.into_view(), axis, size);
AxisChunksIterMut {
iter,
n_whole_chunks: len,
last_dim,
partial_chunk_index,
partial_chunk_dim,
life: PhantomData,
}
}
41 changes: 41 additions & 0 deletions tests/iterators.rs
@@ -13,6 +13,20 @@ use itertools::assert_equal;
use itertools::{enumerate, rev};
use std::iter::FromIterator;

macro_rules! assert_panics {
($body:expr) => {
if let Ok(v) = ::std::panic::catch_unwind(|| $body) {
panic!("assertion failed: should_panic; \
non-panicking result: {:?}", v);
}
};
($body:expr, $($arg:tt)*) => {
if let Ok(_) = ::std::panic::catch_unwind(|| $body) {
panic!($($arg)*);
}
};
}

#[test]
fn double_ended() {
let a = ArcArray::linspace(0., 7., 8);
@@ -585,6 +599,33 @@ fn axis_chunks_iter_zero_axis_len() {
assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none());
}

#[test]
fn axis_chunks_iter_split_at() {
let mut a = Array2::<usize>::zeros((11, 3));
a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i);
for source in &[
a.slice(s![..0, ..]),
a.slice(s![..1, ..]),
a.slice(s![..5, ..]),
a.slice(s![..10, ..]),
a.slice(s![..11, ..]),
a.slice(s![.., ..0]),
] {
let chunks_iter = source.axis_chunks_iter(Axis(0), 5);
let all_chunks: Vec<_> = chunks_iter.clone().collect();
let n_chunks = chunks_iter.len();
assert_eq!(n_chunks, all_chunks.len());
for index in 0..=n_chunks {
let (left, right) = chunks_iter.clone().split_at(index);
assert_eq!(&all_chunks[..index], &left.collect::<Vec<_>>()[..]);
assert_eq!(&all_chunks[index..], &right.collect::<Vec<_>>()[..]);
}
assert_panics!({
chunks_iter.split_at(n_chunks + 1);
});
}
}
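
As a usage sketch (not part of this diff): the same .split_at() is generated for AxisChunksIterMut by the chunk_iter_impl macro, so the two halves act as disjoint mutable iterators:

use ndarray::prelude::*;

fn main() {
    let mut a = Array2::<usize>::zeros((11, 3));
    // Chunks of 5 rows along Axis(0): two full chunks plus one partial chunk.
    let iter = a.axis_chunks_iter_mut(Axis(0), 5);
    // One chunk on the left; the remaining two (including the partial one) on the right.
    let (mut left, right) = iter.split_at(1);
    for elt in left.next().unwrap().iter_mut() {
        *elt = 1;
    }
    for mut chunk in right {
        for elt in chunk.iter_mut() {
            *elt = 2;
        }
    }
    assert_eq!(a[[0, 0]], 1);
    assert_eq!(a[[10, 0]], 2);
}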

#[test]
fn axis_chunks_iter_mut() {
let a = ArcArray::from_iter(0..24);