Skip to content

Commit

Permalink
Merge pull request #574 from jturner314/optimize-fold
Browse files Browse the repository at this point in the history
Improve performance of .fold()
  • Loading branch information
bluss committed Dec 15, 2018
2 parents 55aca3b + 1157763 commit 03552e2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
9 changes: 9 additions & 0 deletions src/dimension/axes.rs
Expand Up @@ -75,6 +75,15 @@ impl<'a, D> Iterator for Axes<'a, D>
}
}

fn fold<B, F>(self, init: B, f: F) -> B
where
F: FnMut(B, AxisDescription) -> B,
{
(self.start..self.end)
.map(move |i| AxisDescription(Axis(i), self.dim[i], self.strides[i] as isize))
.fold(init, f)
}

fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.end - self.start;
(len, Some(len))
Expand Down
26 changes: 19 additions & 7 deletions src/impl_methods.rs
Expand Up @@ -1854,13 +1854,25 @@ where
} else {
let mut v = self.view();
// put the narrowest axis at the last position
if v.ndim() > 1 {
let last = v.ndim() - 1;
let narrow_axis = v.axes()
.filter(|ax| ax.len() > 1)
.min_by_key(|ax| ax.stride().abs())
.map_or(last, |ax| ax.axis().index());
v.swap_axes(last, narrow_axis);
match v.ndim() {
0 | 1 => {}
2 => {
if self.len_of(Axis(1)) <= 1
|| self.len_of(Axis(0)) > 1
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
{
v.swap_axes(0, 1);
}
}
n => {
let last = n - 1;
let narrow_axis = v
.axes()
.filter(|ax| ax.len() > 1)
.min_by_key(|ax| ax.stride().abs())
.map_or(last, |ax| ax.axis().index());
v.swap_axes(last, narrow_axis);
}
}
v.into_elements_base().fold(init, f)
}
Expand Down

0 comments on commit 03552e2

Please sign in to comment.