Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine common code / reduce codegen by factoring out common parts #819

Merged
merged 4 commits into from May 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/dimension/mod.rs
Expand Up @@ -139,6 +139,14 @@ pub fn can_index_slice_not_custom<A, D: Dimension>(data: &[A], dim: &D) -> Resul
/// also implies that the length of any individual axis does not exceed
/// `isize::MAX`.)
pub fn max_abs_offset_check_overflow<A, D>(dim: &D, strides: &D) -> Result<usize, ShapeError>
where
    D: Dimension,
{
    // The element type `A` is only needed for its size in bytes. Resolve it
    // here and hand the plain `usize` to the implementation, which is generic
    // over the dimension type `D` alone.
    let elem_size = mem::size_of::<A>();
    max_abs_offset_check_overflow_impl(elem_size, dim, strides)
}

fn max_abs_offset_check_overflow_impl<D>(elem_size: usize, dim: &D, strides: &D)
-> Result<usize, ShapeError>
where
D: Dimension,
{
Expand Down Expand Up @@ -168,7 +176,7 @@ where
// Determine absolute difference in units of bytes between least and
// greatest address accessible by moving along all axes
let max_offset_bytes = max_offset
.checked_mul(mem::size_of::<A>())
.checked_mul(elem_size)
.ok_or_else(|| from_kind(ErrorKind::Overflow))?;
// Condition 2b.
if max_offset_bytes > isize::MAX as usize {
Expand Down Expand Up @@ -216,13 +224,21 @@ pub fn can_index_slice<A, D: Dimension>(
) -> Result<(), ShapeError> {
// Check conditions 1 and 2 and calculate `max_offset`.
let max_offset = max_abs_offset_check_overflow::<A, _>(dim, strides)?;
can_index_slice_impl(max_offset, data.len(), dim, strides)
}

fn can_index_slice_impl<D: Dimension>(
max_offset: usize,
data_len: usize,
dim: &D,
strides: &D,
) -> Result<(), ShapeError> {
// Check condition 4.
let is_empty = dim.slice().iter().any(|&d| d == 0);
if is_empty && max_offset > data.len() {
if is_empty && max_offset > data_len {
return Err(from_kind(ErrorKind::OutOfBounds));
}
if !is_empty && max_offset >= data.len() {
if !is_empty && max_offset >= data_len {
return Err(from_kind(ErrorKind::OutOfBounds));
}

Expand Down
7 changes: 5 additions & 2 deletions src/iterators/mod.rs
Expand Up @@ -79,15 +79,18 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D> {
let ndim = self.dim.ndim();
debug_assert_ne!(ndim, 0);
let mut accum = init;
while let Some(mut index) = self.index.clone() {
while let Some(mut index) = self.index {
let stride = self.strides.last_elem() as isize;
let elem_index = index.last_elem();
let len = self.dim.last_elem();
let offset = D::stride_offset(&index, &self.strides);
unsafe {
let row_ptr = self.ptr.offset(offset);
for i in 0..(len - elem_index) {
let mut i = 0;
let i_end = len - elem_index;
while i < i_end {
accum = g(accum, row_ptr.offset(i as isize * stride));
i += 1;
}
}
index.set_last_elem(len - 1);
Expand Down
44 changes: 28 additions & 16 deletions src/zip/mod.rs
Expand Up @@ -723,7 +723,7 @@ where
}
}

fn apply_core_contiguous<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
fn apply_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim = D>,
Expand All @@ -732,15 +732,35 @@ where
let size = self.dimension.size();
let ptrs = self.parts.as_ptr();
let inner_strides = self.parts.contiguous_stride();
for i in 0..size {
unsafe {
let ptr_i = ptrs.stride_offset(inner_strides, i);
acc = fold_while![function(acc, self.parts.as_ref(ptr_i))];
}
unsafe {
self.inner(acc, ptrs, inner_strides, size, &mut function)
}
}

/// The innermost loop of the Zip apply methods.
///
/// Runs the fold-while operation on a stretch of `len` elements with constant
/// strides, visiting them in index order (`0`, `1`, ..., `len - 1`) and
/// threading the accumulator through each call of `function`.
///
/// - `acc`: accumulator value passed to the first call of `function`
/// - `ptr`: base pointer for the first element in this stretch
/// - `strides`: strides for the elements in this stretch
/// - `len`: number of elements
/// - `function`: closure applied to the accumulator and each element
///
/// Returns `FoldWhile::Continue(acc)` with the final accumulator once the
/// loop has run over all `len` elements.
/// NOTE(review): `fold_while!` is defined outside this view; presumably it
/// returns early from this function when `function` yields a non-`Continue`
/// value. Confirm against the macro definition.
///
/// # Safety
///
/// NOTE(review): presumably the caller must guarantee that `ptr`, offset by
/// `strides` for every index in `0..len`, stays valid for `self.parts.as_ref`.
/// Confirm against the contracts of `stride_offset` and `as_ref`.
unsafe fn inner<F, Acc>(&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride,
len: usize, function: &mut F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple
{
// Hand-written counter loop instead of a range iterator.
// NOTE(review): this looks deliberate (callers' `for i in 0..n` loops were
// replaced by this helper); confirm before rewriting it as `for`.
let mut i = 0;
while i < len {
// Pointer(s) to the i-th element of the stretch.
let p = ptr.stride_offset(strides, i);
// Feed the element to the closure and carry the new accumulator forward.
acc = fold_while!(function(acc, self.parts.as_ref(p)));
i += 1;
}
FoldWhile::Continue(acc)
}


fn apply_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
Expand Down Expand Up @@ -773,15 +793,11 @@ where
while let Some(index) = index_ {
unsafe {
let ptr = self.parts.uget_ptr(&index);
for i in 0..inner_len {
let p = ptr.stride_offset(inner_strides, i);
acc = fold_while!(function(acc, self.parts.as_ref(p)));
}
acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
}

index_ = self.dimension.next_for(index);
}
self.dimension[unroll_axis] = inner_len;
FoldWhile::Continue(acc)
}

Expand All @@ -801,18 +817,14 @@ where
loop {
unsafe {
let ptr = self.parts.uget_ptr(&index);
for i in 0..inner_len {
let p = ptr.stride_offset(inner_strides, i);
acc = fold_while!(function(acc, self.parts.as_ref(p)));
}
acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
}

if !self.dimension.next_for_f(&mut index) {
break;
}
}
}
self.dimension[unroll_axis] = inner_len;
FoldWhile::Continue(acc)
}

Expand Down