Skip to content

Commit

Permalink
Merge pull request #819 from rust-ndarray/reduce-generated-code
Browse files Browse the repository at this point in the history
Combine common code / reduce codegen by factoring out common parts
  • Loading branch information
bluss committed May 16, 2020
2 parents 4bac2c1 + a76bb92 commit 79d99c7
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 21 deletions.
22 changes: 19 additions & 3 deletions src/dimension/mod.rs
Expand Up @@ -139,6 +139,14 @@ pub fn can_index_slice_not_custom<A, D: Dimension>(data: &[A], dim: &D) -> Resul
/// also implies that the length of any individual axis does not exceed
/// `isize::MAX`.)
pub fn max_abs_offset_check_overflow<A, D>(dim: &D, strides: &D) -> Result<usize, ShapeError>
where
    D: Dimension,
{
    // Only the size of `A` matters here, so forward to a helper that is
    // generic in `D` alone; the real work is then compiled once per
    // dimension type instead of once per (element, dimension) pair.
    let elem_size = mem::size_of::<A>();
    max_abs_offset_check_overflow_impl(elem_size, dim, strides)
}

fn max_abs_offset_check_overflow_impl<D>(elem_size: usize, dim: &D, strides: &D)
-> Result<usize, ShapeError>
where
D: Dimension,
{
Expand Down Expand Up @@ -168,7 +176,7 @@ where
// Determine absolute difference in units of bytes between least and
// greatest address accessible by moving along all axes
let max_offset_bytes = max_offset
.checked_mul(mem::size_of::<A>())
.checked_mul(elem_size)
.ok_or_else(|| from_kind(ErrorKind::Overflow))?;
// Condition 2b.
if max_offset_bytes > isize::MAX as usize {
Expand Down Expand Up @@ -216,13 +224,21 @@ pub fn can_index_slice<A, D: Dimension>(
) -> Result<(), ShapeError> {
// Check conditions 1 and 2 and calculate `max_offset`.
let max_offset = max_abs_offset_check_overflow::<A, _>(dim, strides)?;
can_index_slice_impl(max_offset, data.len(), dim, strides)
}

fn can_index_slice_impl<D: Dimension>(
max_offset: usize,
data_len: usize,
dim: &D,
strides: &D,
) -> Result<(), ShapeError> {
// Check condition 4.
let is_empty = dim.slice().iter().any(|&d| d == 0);
if is_empty && max_offset > data.len() {
if is_empty && max_offset > data_len {
return Err(from_kind(ErrorKind::OutOfBounds));
}
if !is_empty && max_offset >= data.len() {
if !is_empty && max_offset >= data_len {
return Err(from_kind(ErrorKind::OutOfBounds));
}

Expand Down
7 changes: 5 additions & 2 deletions src/iterators/mod.rs
Expand Up @@ -79,15 +79,18 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D> {
let ndim = self.dim.ndim();
debug_assert_ne!(ndim, 0);
let mut accum = init;
while let Some(mut index) = self.index.clone() {
while let Some(mut index) = self.index {
let stride = self.strides.last_elem() as isize;
let elem_index = index.last_elem();
let len = self.dim.last_elem();
let offset = D::stride_offset(&index, &self.strides);
unsafe {
let row_ptr = self.ptr.offset(offset);
for i in 0..(len - elem_index) {
let mut i = 0;
let i_end = len - elem_index;
while i < i_end {
accum = g(accum, row_ptr.offset(i as isize * stride));
i += 1;
}
}
index.set_last_elem(len - 1);
Expand Down
44 changes: 28 additions & 16 deletions src/zip/mod.rs
Expand Up @@ -723,7 +723,7 @@ where
}
}

fn apply_core_contiguous<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
fn apply_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim = D>,
Expand All @@ -732,15 +732,35 @@ where
let size = self.dimension.size();
let ptrs = self.parts.as_ptr();
let inner_strides = self.parts.contiguous_stride();
for i in 0..size {
unsafe {
let ptr_i = ptrs.stride_offset(inner_strides, i);
acc = fold_while![function(acc, self.parts.as_ref(ptr_i))];
}
unsafe {
self.inner(acc, ptrs, inner_strides, size, &mut function)
}
}

/// The innermost loop of the Zip apply methods.
///
/// Runs the fold-while operation over a stretch of elements that share a
/// constant stride pattern.
///
/// `ptr`: base pointer of the first element in the stretch
/// `strides`: per-part strides for stepping through the stretch
/// `len`: number of elements in the stretch
/// `function`: the fold-while closure to apply to each element
unsafe fn inner<F, Acc>(&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride,
    len: usize, function: &mut F) -> FoldWhile<Acc>
where
    F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
    P: ZippableTuple
{
    // Deliberately a manual counted loop (not `for _ in 0..len`): this
    // function exists to keep the generated code for the hot loop small.
    let mut idx = 0;
    while idx < len {
        let elem_ptr = ptr.stride_offset(strides, idx);
        acc = fold_while!(function(acc, self.parts.as_ref(elem_ptr)));
        idx += 1;
    }
    FoldWhile::Continue(acc)
}


fn apply_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
Expand Down Expand Up @@ -773,15 +793,11 @@ where
while let Some(index) = index_ {
unsafe {
let ptr = self.parts.uget_ptr(&index);
for i in 0..inner_len {
let p = ptr.stride_offset(inner_strides, i);
acc = fold_while!(function(acc, self.parts.as_ref(p)));
}
acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
}

index_ = self.dimension.next_for(index);
}
self.dimension[unroll_axis] = inner_len;
FoldWhile::Continue(acc)
}

Expand All @@ -801,18 +817,14 @@ where
loop {
unsafe {
let ptr = self.parts.uget_ptr(&index);
for i in 0..inner_len {
let p = ptr.stride_offset(inner_strides, i);
acc = fold_while!(function(acc, self.parts.as_ref(p)));
}
acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
}

if !self.dimension.next_for_f(&mut index) {
break;
}
}
}
self.dimension[unroll_axis] = inner_len;
FoldWhile::Continue(acc)
}

Expand Down

0 comments on commit 79d99c7

Please sign in to comment.