diff --git a/src/dimension/axes.rs b/src/dimension/axes.rs index f40ffa48c..f087eb5c5 100644 --- a/src/dimension/axes.rs +++ b/src/dimension/axes.rs @@ -75,6 +75,15 @@ impl<'a, D> Iterator for Axes<'a, D> } } + fn fold(self, init: B, f: F) -> B + where + F: FnMut(B, AxisDescription) -> B, + { + (self.start..self.end) + .map(move |i| AxisDescription(Axis(i), self.dim[i], self.strides[i] as isize)) + .fold(init, f) + } + fn size_hint(&self) -> (usize, Option) { let len = self.end - self.start; (len, Some(len)) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index d7b2547a0..de74af795 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1854,13 +1854,25 @@ where } else { let mut v = self.view(); // put the narrowest axis at the last position - if v.ndim() > 1 { - let last = v.ndim() - 1; - let narrow_axis = v.axes() - .filter(|ax| ax.len() > 1) - .min_by_key(|ax| ax.stride().abs()) - .map_or(last, |ax| ax.axis().index()); - v.swap_axes(last, narrow_axis); + match v.ndim() { + 0 | 1 => {} + 2 => { + if self.len_of(Axis(1)) <= 1 + || self.len_of(Axis(0)) > 1 + && self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs() + { + v.swap_axes(0, 1); + } + } + n => { + let last = n - 1; + let narrow_axis = v + .axes() + .filter(|ax| ax.len() > 1) + .min_by_key(|ax| ax.stride().abs()) + .map_or(last, |ax| ax.axis().index()); + v.swap_axes(last, narrow_axis); + } } v.into_elements_base().fold(init, f) }