diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs
index 28d2e9b2c..212197fed 100644
--- a/src/dimension/mod.rs
+++ b/src/dimension/mod.rs
@@ -139,6 +139,14 @@ pub fn can_index_slice_not_custom(data: &[A], dim: &D) -> Resul
/// also implies that the length of any individual axis does not exceed
/// `isize::MAX`.)
pub fn max_abs_offset_check_overflow(dim: &D, strides: &D) -> Result
+where
+ D: Dimension,
+{
+ max_abs_offset_check_overflow_impl(mem::size_of::(), dim, strides)
+}
+
+fn max_abs_offset_check_overflow_impl(elem_size: usize, dim: &D, strides: &D)
+ -> Result
where
D: Dimension,
{
@@ -168,7 +176,7 @@ where
// Determine absolute difference in units of bytes between least and
// greatest address accessible by moving along all axes
let max_offset_bytes = max_offset
- .checked_mul(mem::size_of::())
+ .checked_mul(elem_size)
.ok_or_else(|| from_kind(ErrorKind::Overflow))?;
// Condition 2b.
if max_offset_bytes > isize::MAX as usize {
@@ -216,13 +224,21 @@ pub fn can_index_slice(
) -> Result<(), ShapeError> {
// Check conditions 1 and 2 and calculate `max_offset`.
let max_offset = max_abs_offset_check_overflow::(dim, strides)?;
+ can_index_slice_impl(max_offset, data.len(), dim, strides)
+}
+fn can_index_slice_impl(
+ max_offset: usize,
+ data_len: usize,
+ dim: &D,
+ strides: &D,
+) -> Result<(), ShapeError> {
// Check condition 4.
let is_empty = dim.slice().iter().any(|&d| d == 0);
- if is_empty && max_offset > data.len() {
+ if is_empty && max_offset > data_len {
return Err(from_kind(ErrorKind::OutOfBounds));
}
- if !is_empty && max_offset >= data.len() {
+ if !is_empty && max_offset >= data_len {
return Err(from_kind(ErrorKind::OutOfBounds));
}
diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs
index 091119361..621c141ff 100644
--- a/src/iterators/mod.rs
+++ b/src/iterators/mod.rs
@@ -79,15 +79,18 @@ impl Iterator for Baseiter {
let ndim = self.dim.ndim();
debug_assert_ne!(ndim, 0);
let mut accum = init;
- while let Some(mut index) = self.index.clone() {
+ while let Some(mut index) = self.index {
let stride = self.strides.last_elem() as isize;
let elem_index = index.last_elem();
let len = self.dim.last_elem();
let offset = D::stride_offset(&index, &self.strides);
unsafe {
let row_ptr = self.ptr.offset(offset);
- for i in 0..(len - elem_index) {
+ let mut i = 0;
+ let i_end = len - elem_index;
+ while i < i_end {
accum = g(accum, row_ptr.offset(i as isize * stride));
+ i += 1;
}
}
index.set_last_elem(len - 1);
diff --git a/src/zip/mod.rs b/src/zip/mod.rs
index 2b6612b54..c23791f62 100644
--- a/src/zip/mod.rs
+++ b/src/zip/mod.rs
@@ -723,7 +723,7 @@ where
}
}
- fn apply_core_contiguous(&mut self, mut acc: Acc, mut function: F) -> FoldWhile
+ fn apply_core_contiguous(&mut self, acc: Acc, mut function: F) -> FoldWhile
where
F: FnMut(Acc, P::Item) -> FoldWhile,
P: ZippableTuple,
@@ -732,15 +732,35 @@ where
let size = self.dimension.size();
let ptrs = self.parts.as_ptr();
let inner_strides = self.parts.contiguous_stride();
- for i in 0..size {
- unsafe {
- let ptr_i = ptrs.stride_offset(inner_strides, i);
- acc = fold_while![function(acc, self.parts.as_ref(ptr_i))];
- }
+ unsafe {
+ self.inner(acc, ptrs, inner_strides, size, &mut function)
+ }
+ }
+
+ /// The innermost loop of the Zip apply methods
+ ///
+ /// Run the fold while operation on a stretch of elements with constant strides
+ ///
+ /// `ptr`: base pointer for the first element in this stretch
+ /// `strides`: strides for the elements in this stretch
+ /// `len`: number of elements
+ /// `function`: closure
+ unsafe fn inner(&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride,
+ len: usize, function: &mut F) -> FoldWhile
+ where
+ F: FnMut(Acc, P::Item) -> FoldWhile,
+ P: ZippableTuple
+ {
+ let mut i = 0;
+ while i < len {
+ let p = ptr.stride_offset(strides, i);
+ acc = fold_while!(function(acc, self.parts.as_ref(p)));
+ i += 1;
}
FoldWhile::Continue(acc)
}
+
fn apply_core_strided(&mut self, acc: Acc, function: F) -> FoldWhile
where
F: FnMut(Acc, P::Item) -> FoldWhile,
@@ -773,15 +793,11 @@ where
while let Some(index) = index_ {
unsafe {
let ptr = self.parts.uget_ptr(&index);
- for i in 0..inner_len {
- let p = ptr.stride_offset(inner_strides, i);
- acc = fold_while!(function(acc, self.parts.as_ref(p)));
- }
+ acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
}
index_ = self.dimension.next_for(index);
}
- self.dimension[unroll_axis] = inner_len;
FoldWhile::Continue(acc)
}
@@ -801,10 +817,7 @@ where
loop {
unsafe {
let ptr = self.parts.uget_ptr(&index);
- for i in 0..inner_len {
- let p = ptr.stride_offset(inner_strides, i);
- acc = fold_while!(function(acc, self.parts.as_ref(p)));
- }
+ acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
}
if !self.dimension.next_for_f(&mut index) {
@@ -812,7 +825,6 @@ where
}
}
}
- self.dimension[unroll_axis] = inner_len;
FoldWhile::Continue(acc)
}