diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 28d2e9b2c..212197fed 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -139,6 +139,14 @@ pub fn can_index_slice_not_custom(data: &[A], dim: &D) -> Resul /// also implies that the length of any individual axis does not exceed /// `isize::MAX`.) pub fn max_abs_offset_check_overflow(dim: &D, strides: &D) -> Result +where + D: Dimension, +{ + max_abs_offset_check_overflow_impl(mem::size_of::(), dim, strides) +} + +fn max_abs_offset_check_overflow_impl(elem_size: usize, dim: &D, strides: &D) + -> Result where D: Dimension, { @@ -168,7 +176,7 @@ where // Determine absolute difference in units of bytes between least and // greatest address accessible by moving along all axes let max_offset_bytes = max_offset - .checked_mul(mem::size_of::()) + .checked_mul(elem_size) .ok_or_else(|| from_kind(ErrorKind::Overflow))?; // Condition 2b. if max_offset_bytes > isize::MAX as usize { @@ -216,13 +224,21 @@ pub fn can_index_slice( ) -> Result<(), ShapeError> { // Check conditions 1 and 2 and calculate `max_offset`. let max_offset = max_abs_offset_check_overflow::(dim, strides)?; + can_index_slice_impl(max_offset, data.len(), dim, strides) +} +fn can_index_slice_impl( + max_offset: usize, + data_len: usize, + dim: &D, + strides: &D, +) -> Result<(), ShapeError> { // Check condition 4. let is_empty = dim.slice().iter().any(|&d| d == 0); - if is_empty && max_offset > data.len() { + if is_empty && max_offset > data_len { return Err(from_kind(ErrorKind::OutOfBounds)); } - if !is_empty && max_offset >= data.len() { + if !is_empty && max_offset >= data_len { return Err(from_kind(ErrorKind::OutOfBounds)); } diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 091119361..621c141ff 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -79,15 +79,18 @@ impl Iterator for Baseiter { let ndim = self.dim.ndim(); debug_assert_ne!(ndim, 0); let mut accum = init; - while let Some(mut index) = self.index.clone() { + while let Some(mut index) = self.index { let stride = self.strides.last_elem() as isize; let elem_index = index.last_elem(); let len = self.dim.last_elem(); let offset = D::stride_offset(&index, &self.strides); unsafe { let row_ptr = self.ptr.offset(offset); - for i in 0..(len - elem_index) { + let mut i = 0; + let i_end = len - elem_index; + while i < i_end { accum = g(accum, row_ptr.offset(i as isize * stride)); + i += 1; } } index.set_last_elem(len - 1); diff --git a/src/zip/mod.rs b/src/zip/mod.rs index 2b6612b54..c23791f62 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -723,7 +723,7 @@ where } } - fn apply_core_contiguous(&mut self, mut acc: Acc, mut function: F) -> FoldWhile + fn apply_core_contiguous(&mut self, acc: Acc, mut function: F) -> FoldWhile where F: FnMut(Acc, P::Item) -> FoldWhile, P: ZippableTuple, @@ -732,15 +732,35 @@ where let size = self.dimension.size(); let ptrs = self.parts.as_ptr(); let inner_strides = self.parts.contiguous_stride(); - for i in 0..size { - unsafe { - let ptr_i = ptrs.stride_offset(inner_strides, i); - acc = fold_while![function(acc, self.parts.as_ref(ptr_i))]; - } + unsafe { + self.inner(acc, ptrs, inner_strides, size, &mut function) + } + } + + /// The innermost loop of the Zip apply methods + /// + /// Run the fold while operation on a stretch of elements with constant strides + /// + /// `ptr`: base pointer for the first element in this stretch + /// `strides`: strides for the elements in this stretch + /// `len`: number of elements + /// `function`: closure + unsafe fn inner(&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride, + len: usize, function: &mut F) -> FoldWhile + where + F: FnMut(Acc, P::Item) -> FoldWhile, + P: ZippableTuple + { + let mut i = 0; + while i < len { + let p = ptr.stride_offset(strides, i); + acc = fold_while!(function(acc, self.parts.as_ref(p))); + i += 1; } FoldWhile::Continue(acc) } + fn apply_core_strided(&mut self, acc: Acc, function: F) -> FoldWhile where F: FnMut(Acc, P::Item) -> FoldWhile, @@ -773,15 +793,11 @@ where while let Some(index) = index_ { unsafe { let ptr = self.parts.uget_ptr(&index); - for i in 0..inner_len { - let p = ptr.stride_offset(inner_strides, i); - acc = fold_while!(function(acc, self.parts.as_ref(p))); - } + acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)]; } index_ = self.dimension.next_for(index); } - self.dimension[unroll_axis] = inner_len; FoldWhile::Continue(acc) } @@ -801,10 +817,7 @@ where loop { unsafe { let ptr = self.parts.uget_ptr(&index); - for i in 0..inner_len { - let p = ptr.stride_offset(inner_strides, i); - acc = fold_while!(function(acc, self.parts.as_ref(p))); - } + acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)]; } if !self.dimension.next_for_f(&mut index) { @@ -812,7 +825,6 @@ where } } } - self.dimension[unroll_axis] = inner_len; FoldWhile::Continue(acc) }