diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 212197fed..9ee1f44d6 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -19,6 +19,8 @@ pub use self::dynindeximpl::IxDynImpl; pub use self::ndindex::NdIndex; pub use self::remove_axis::RemoveAxis; +use crate::shape_builder::Strides; + use std::isize; use std::mem; @@ -114,11 +116,24 @@ pub fn size_of_shape_checked(dim: &D) -> Result /// conditions 1 and 2 are sufficient to guarantee that the offset in units of /// `A` and in units of bytes between the least address and greatest address /// accessible by moving along all axes does not exceed `isize::MAX`. -pub fn can_index_slice_not_custom(data: &[A], dim: &D) -> Result<(), ShapeError> { +pub(crate) fn can_index_slice_with_strides(data: &[A], dim: &D, + strides: &Strides) + -> Result<(), ShapeError> +{ + if let Strides::Custom(strides) = strides { + can_index_slice(data, dim, strides) + } else { + can_index_slice_not_custom(data.len(), dim) + } +} + +pub(crate) fn can_index_slice_not_custom(data_len: usize, dim: &D) + -> Result<(), ShapeError> +{ // Condition 1. let len = size_of_shape_checked(dim)?; // Condition 2. - if len > data.len() { + if len > data_len { return Err(from_kind(ErrorKind::OutOfBounds)); } Ok(()) @@ -217,7 +232,7 @@ where /// condition 4 is sufficient to guarantee that the absolute difference in /// units of `A` and in units of bytes between the least address and greatest /// address accessible by moving along all axes does not exceed `isize::MAX`. -pub fn can_index_slice( +pub(crate) fn can_index_slice( data: &[A], dim: &D, strides: &D, @@ -771,7 +786,7 @@ mod test { quickcheck! { fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec, dim: Vec) -> bool { let dim = IxDyn(&dim); - let result = can_index_slice_not_custom(&data, &dim); + let result = can_index_slice_not_custom(data.len(), &dim); if dim.size_checked().is_none() { // Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`. result.is_err() diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index 8d1471e78..4353a827e 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -358,7 +358,7 @@ where { let shape = shape.into_shape(); let _ = size_of_shape_checked_unwrap!(&shape.dim); - if shape.is_c { + if shape.is_c() { let v = to_vec_mapped(indices(shape.dim.clone()).into_iter(), f); unsafe { Self::from_shape_vec_unchecked(shape, v) } } else { @@ -411,15 +411,12 @@ where fn from_shape_vec_impl(shape: StrideShape, v: Vec) -> Result { let dim = shape.dim; - let strides = shape.strides; - if shape.custom { - dimension::can_index_slice(&v, &dim, &strides)?; - } else { - dimension::can_index_slice_not_custom::(&v, &dim)?; - if dim.size() != v.len() { - return Err(error::incompatible_shapes(&Ix1(v.len()), &dim)); - } + let is_custom = shape.strides.is_custom(); + dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?; + if !is_custom && dim.size() != v.len() { + return Err(error::incompatible_shapes(&Ix1(v.len()), &dim)); } + let strides = shape.strides.strides_for_dim(&dim); unsafe { Ok(Self::from_vec_dim_stride_unchecked(dim, strides, v)) } } @@ -451,7 +448,9 @@ where Sh: Into>, { let shape = shape.into(); - Self::from_vec_dim_stride_unchecked(shape.dim, shape.strides, v) + let dim = shape.dim; + let strides = shape.strides.strides_for_dim(&dim); + Self::from_vec_dim_stride_unchecked(dim, strides, v) } unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec) -> Self { diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index 22e76277a..8377dc93c 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -4,7 +4,8 @@ use std::ptr::NonNull; use crate::dimension::{self, stride_offset}; use crate::extension::nonnull::nonnull_debug_checked_from_ptr; use crate::imp_prelude::*; -use crate::{is_aligned, StrideShape}; +use crate::is_aligned; +use crate::shape_builder::{Strides, StrideShape}; impl RawArrayView where @@ -69,11 +70,15 @@ where { let shape = shape.into(); let dim = shape.dim; - let strides = shape.strides; if cfg!(debug_assertions) { assert!(!ptr.is_null(), "The pointer must be non-null."); - dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); + if let Strides::Custom(strides) = &shape.strides { + dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); + } else { + dimension::size_of_shape_checked(&dim).unwrap(); + } } + let strides = shape.strides.strides_for_dim(&dim); RawArrayView::new_(ptr, dim, strides) } @@ -205,11 +210,15 @@ where { let shape = shape.into(); let dim = shape.dim; - let strides = shape.strides; if cfg!(debug_assertions) { assert!(!ptr.is_null(), "The pointer must be non-null."); - dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); + if let Strides::Custom(strides) = &shape.strides { + dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); + } else { + dimension::size_of_shape_checked(&dim).unwrap(); + } } + let strides = shape.strides.strides_for_dim(&dim); RawArrayViewMut::new_(ptr, dim, strides) } diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index e2244dd09..bc1602af4 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -53,12 +53,8 @@ where fn from_shape_impl(shape: StrideShape, xs: &'a [A]) -> Result { let dim = shape.dim; - let strides = shape.strides; - if shape.custom { - dimension::can_index_slice(xs, &dim, &strides)?; - } else { - dimension::can_index_slice_not_custom::(xs, &dim)?; - } + dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?; + let strides = shape.strides.strides_for_dim(&dim); unsafe { Ok(Self::new_(xs.as_ptr(), dim, strides)) } } @@ -149,12 +145,8 @@ where fn from_shape_impl(shape: StrideShape, xs: &'a mut [A]) -> Result { let dim = shape.dim; - let strides = shape.strides; - if shape.custom { - dimension::can_index_slice(xs, &dim, &strides)?; - } else { - dimension::can_index_slice_not_custom::(xs, &dim)?; - } + dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?; + let strides = shape.strides.strides_for_dim(&dim); unsafe { Ok(Self::new_(xs.as_mut_ptr(), dim, strides)) } } diff --git a/src/lib.rs b/src/lib.rs index 7e79b4cbf..f394413db 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,7 +138,7 @@ pub use crate::linalg_traits::{LinalgScalar, NdFloat}; pub use crate::stacking::{concatenate, stack, stack_new_axis}; pub use crate::impl_views::IndexLonger; -pub use crate::shape_builder::ShapeBuilder; +pub use crate::shape_builder::{Shape, StrideShape, ShapeBuilder}; #[macro_use] mod macro_utils; @@ -1595,24 +1595,8 @@ mod impl_raw_views; // Copy-on-write array methods mod impl_cow; -/// A contiguous array shape of n dimensions. -/// -/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default). -#[derive(Copy, Clone, Debug)] -pub struct Shape { - dim: D, - is_c: bool, -} - -/// An array shape of n dimensions in c-order, f-order or custom strides. -#[derive(Copy, Clone, Debug)] -pub struct StrideShape { - dim: D, - strides: D, - custom: bool, -} - /// Returns `true` if the pointer is aligned. pub(crate) fn is_aligned(ptr: *const T) -> bool { (ptr as usize) % ::std::mem::align_of::() == 0 } + diff --git a/src/shape_builder.rs b/src/shape_builder.rs index bb5a949ab..6fc99d0b2 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -1,6 +1,66 @@ use crate::dimension::IntoDimension; use crate::Dimension; -use crate::{Shape, StrideShape}; + +/// A contiguous array shape of n dimensions. +/// +/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default). +#[derive(Copy, Clone, Debug)] +pub struct Shape { + /// Shape (axis lengths) + pub(crate) dim: D, + /// Strides can only be C or F here + pub(crate) strides: Strides, +} + +#[derive(Copy, Clone, Debug)] +pub(crate) enum Contiguous { } + +impl Shape { + pub(crate) fn is_c(&self) -> bool { + matches!(self.strides, Strides::C) + } +} + + +/// An array shape of n dimensions in c-order, f-order or custom strides. +#[derive(Copy, Clone, Debug)] +pub struct StrideShape { + pub(crate) dim: D, + pub(crate) strides: Strides, +} + +/// Stride description +#[derive(Copy, Clone, Debug)] +pub(crate) enum Strides { + /// Row-major ("C"-order) + C, + /// Column-major ("F"-order) + F, + /// Custom strides + Custom(D) +} + +impl Strides { + /// Return strides for `dim` (computed from dimension if c/f, else return the custom stride) + pub(crate) fn strides_for_dim(self, dim: &D) -> D + where D: Dimension + { + match self { + Strides::C => dim.default_strides(), + Strides::F => dim.fortran_strides(), + Strides::Custom(c) => { + debug_assert_eq!(c.ndim(), dim.ndim(), + "Custom strides given with {} dimensions, expected {}", + c.ndim(), dim.ndim()); + c + } + } + } + + pub(crate) fn is_custom(&self) -> bool { + matches!(*self, Strides::Custom(_)) + } +} /// A trait for `Shape` and `D where D: Dimension` that allows /// customizing the memory layout (strides) of an array shape. @@ -34,36 +94,18 @@ where { fn from(value: T) -> Self { let shape = value.into_shape(); - let d = shape.dim; - let st = if shape.is_c { - d.default_strides() + let st = if shape.is_c() { + Strides::C } else { - d.fortran_strides() + Strides::F }; StrideShape { strides: st, - dim: d, - custom: false, + dim: shape.dim, } } } -/* -impl From> for StrideShape - where D: Dimension -{ - fn from(shape: Shape) -> Self { - let d = shape.dim; - let st = if shape.is_c { d.default_strides() } else { d.fortran_strides() }; - StrideShape { - strides: st, - dim: d, - custom: false, - } - } -} -*/ - impl ShapeBuilder for T where T: IntoDimension, @@ -73,7 +115,7 @@ where fn into_shape(self) -> Shape { Shape { dim: self.into_dimension(), - is_c: true, + strides: Strides::C, } } fn f(self) -> Shape { @@ -93,21 +135,24 @@ where { type Dim = D; type Strides = D; + fn into_shape(self) -> Shape { self } + fn f(self) -> Self { self.set_f(true) } + fn set_f(mut self, is_f: bool) -> Self { - self.is_c = !is_f; + self.strides = if !is_f { Strides::C } else { Strides::F }; self } + fn strides(self, st: D) -> StrideShape { StrideShape { dim: self.dim, - strides: st, - custom: true, + strides: Strides::Custom(st), } } } diff --git a/tests/array-construct.rs b/tests/array-construct.rs index 97e4ef491..738e3b1fc 100644 --- a/tests/array-construct.rs +++ b/tests/array-construct.rs @@ -148,6 +148,7 @@ fn test_from_fn_f3() { fn deny_wraparound_from_vec() { let five = vec![0; 5]; let five_large = Array::from_shape_vec((3, 7, 29, 36760123, 823996703), five.clone()); + println!("{:?}", five_large); assert!(five_large.is_err()); let six = Array::from_shape_vec(6, five.clone()); assert!(six.is_err());