Skip to content

Commit

Permalink
Merge pull request #855 from rust-ndarray/checked-shape-strides
Browse files Browse the repository at this point in the history
Error-check array shape before computing strides
  • Loading branch information
bluss committed Dec 3, 2020
2 parents 143d5b4 + c195930 commit 3a2040d
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 77 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Expand Up @@ -49,7 +49,7 @@ jobs:
with:
profile: minimal
toolchain: ${{ matrix.rust }}
taret: ${{ matrix.target }}
target: ${{ matrix.target }}
override: true
- name: Cache cargo plugins
uses: actions/cache@v1
Expand Down
23 changes: 19 additions & 4 deletions src/dimension/mod.rs
Expand Up @@ -19,6 +19,8 @@ pub use self::dynindeximpl::IxDynImpl;
pub use self::ndindex::NdIndex;
pub use self::remove_axis::RemoveAxis;

use crate::shape_builder::Strides;

use std::isize;
use std::mem;

Expand Down Expand Up @@ -114,11 +116,24 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
/// conditions 1 and 2 are sufficient to guarantee that the offset in units of
/// `A` and in units of bytes between the least address and greatest address
/// accessible by moving along all axes does not exceed `isize::MAX`.
pub fn can_index_slice_not_custom<A, D: Dimension>(data: &[A], dim: &D) -> Result<(), ShapeError> {
pub(crate) fn can_index_slice_with_strides<A, D: Dimension>(data: &[A], dim: &D,
strides: &Strides<D>)
-> Result<(), ShapeError>
{
if let Strides::Custom(strides) = strides {
can_index_slice(data, dim, strides)
} else {
can_index_slice_not_custom(data.len(), dim)
}
}

pub(crate) fn can_index_slice_not_custom<D: Dimension>(data_len: usize, dim: &D)
-> Result<(), ShapeError>
{
// Condition 1.
let len = size_of_shape_checked(dim)?;
// Condition 2.
if len > data.len() {
if len > data_len {
return Err(from_kind(ErrorKind::OutOfBounds));
}
Ok(())
Expand Down Expand Up @@ -217,7 +232,7 @@ where
/// condition 4 is sufficient to guarantee that the absolute difference in
/// units of `A` and in units of bytes between the least address and greatest
/// address accessible by moving along all axes does not exceed `isize::MAX`.
pub fn can_index_slice<A, D: Dimension>(
pub(crate) fn can_index_slice<A, D: Dimension>(
data: &[A],
dim: &D,
strides: &D,
Expand Down Expand Up @@ -771,7 +786,7 @@ mod test {
quickcheck! {
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
let dim = IxDyn(&dim);
let result = can_index_slice_not_custom(&data, &dim);
let result = can_index_slice_not_custom(data.len(), &dim);
if dim.size_checked().is_none() {
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
result.is_err()
Expand Down
19 changes: 9 additions & 10 deletions src/impl_constructors.rs
Expand Up @@ -358,7 +358,7 @@ where
{
let shape = shape.into_shape();
let _ = size_of_shape_checked_unwrap!(&shape.dim);
if shape.is_c {
if shape.is_c() {
let v = to_vec_mapped(indices(shape.dim.clone()).into_iter(), f);
unsafe { Self::from_shape_vec_unchecked(shape, v) }
} else {
Expand Down Expand Up @@ -411,15 +411,12 @@ where

fn from_shape_vec_impl(shape: StrideShape<D>, v: Vec<A>) -> Result<Self, ShapeError> {
let dim = shape.dim;
let strides = shape.strides;
if shape.custom {
dimension::can_index_slice(&v, &dim, &strides)?;
} else {
dimension::can_index_slice_not_custom::<A, _>(&v, &dim)?;
if dim.size() != v.len() {
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
}
let is_custom = shape.strides.is_custom();
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?;
if !is_custom && dim.size() != v.len() {
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
}
let strides = shape.strides.strides_for_dim(&dim);
unsafe { Ok(Self::from_vec_dim_stride_unchecked(dim, strides, v)) }
}

Expand Down Expand Up @@ -451,7 +448,9 @@ where
Sh: Into<StrideShape<D>>,
{
let shape = shape.into();
Self::from_vec_dim_stride_unchecked(shape.dim, shape.strides, v)
let dim = shape.dim;
let strides = shape.strides.strides_for_dim(&dim);
Self::from_vec_dim_stride_unchecked(dim, strides, v)
}

unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self {
Expand Down
19 changes: 14 additions & 5 deletions src/impl_raw_views.rs
Expand Up @@ -4,7 +4,8 @@ use std::ptr::NonNull;
use crate::dimension::{self, stride_offset};
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
use crate::imp_prelude::*;
use crate::{is_aligned, StrideShape};
use crate::is_aligned;
use crate::shape_builder::{Strides, StrideShape};

impl<A, D> RawArrayView<A, D>
where
Expand Down Expand Up @@ -69,11 +70,15 @@ where
{
let shape = shape.into();
let dim = shape.dim;
let strides = shape.strides;
if cfg!(debug_assertions) {
assert!(!ptr.is_null(), "The pointer must be non-null.");
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
if let Strides::Custom(strides) = &shape.strides {
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
} else {
dimension::size_of_shape_checked(&dim).unwrap();
}
}
let strides = shape.strides.strides_for_dim(&dim);
RawArrayView::new_(ptr, dim, strides)
}

Expand Down Expand Up @@ -205,11 +210,15 @@ where
{
let shape = shape.into();
let dim = shape.dim;
let strides = shape.strides;
if cfg!(debug_assertions) {
assert!(!ptr.is_null(), "The pointer must be non-null.");
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
if let Strides::Custom(strides) = &shape.strides {
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
} else {
dimension::size_of_shape_checked(&dim).unwrap();
}
}
let strides = shape.strides.strides_for_dim(&dim);
RawArrayViewMut::new_(ptr, dim, strides)
}

Expand Down
16 changes: 4 additions & 12 deletions src/impl_views/constructors.rs
Expand Up @@ -53,12 +53,8 @@ where

fn from_shape_impl(shape: StrideShape<D>, xs: &'a [A]) -> Result<Self, ShapeError> {
let dim = shape.dim;
let strides = shape.strides;
if shape.custom {
dimension::can_index_slice(xs, &dim, &strides)?;
} else {
dimension::can_index_slice_not_custom::<A, _>(xs, &dim)?;
}
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
let strides = shape.strides.strides_for_dim(&dim);
unsafe { Ok(Self::new_(xs.as_ptr(), dim, strides)) }
}

Expand Down Expand Up @@ -149,12 +145,8 @@ where

fn from_shape_impl(shape: StrideShape<D>, xs: &'a mut [A]) -> Result<Self, ShapeError> {
let dim = shape.dim;
let strides = shape.strides;
if shape.custom {
dimension::can_index_slice(xs, &dim, &strides)?;
} else {
dimension::can_index_slice_not_custom::<A, _>(xs, &dim)?;
}
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
let strides = shape.strides.strides_for_dim(&dim);
unsafe { Ok(Self::new_(xs.as_mut_ptr(), dim, strides)) }
}

Expand Down
20 changes: 2 additions & 18 deletions src/lib.rs
Expand Up @@ -138,7 +138,7 @@ pub use crate::linalg_traits::{LinalgScalar, NdFloat};
pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::impl_views::IndexLonger;
pub use crate::shape_builder::ShapeBuilder;
pub use crate::shape_builder::{Shape, StrideShape, ShapeBuilder};

#[macro_use]
mod macro_utils;
Expand Down Expand Up @@ -1595,24 +1595,8 @@ mod impl_raw_views;
// Copy-on-write array methods
mod impl_cow;

/// A contiguous array shape of n dimensions.
///
/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
#[derive(Copy, Clone, Debug)]
pub struct Shape<D> {
dim: D,
is_c: bool,
}

/// An array shape of n dimensions in c-order, f-order or custom strides.
#[derive(Copy, Clone, Debug)]
pub struct StrideShape<D> {
dim: D,
strides: D,
custom: bool,
}

/// Returns `true` if the pointer is aligned.
pub(crate) fn is_aligned<T>(ptr: *const T) -> bool {
(ptr as usize) % ::std::mem::align_of::<T>() == 0
}

99 changes: 72 additions & 27 deletions src/shape_builder.rs
@@ -1,6 +1,66 @@
use crate::dimension::IntoDimension;
use crate::Dimension;
use crate::{Shape, StrideShape};

/// A contiguous array shape of n dimensions.
///
/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
#[derive(Copy, Clone, Debug)]
pub struct Shape<D> {
/// Shape (axis lengths)
pub(crate) dim: D,
/// Strides can only be C or F here
pub(crate) strides: Strides<Contiguous>,
}

#[derive(Copy, Clone, Debug)]
pub(crate) enum Contiguous { }

impl<D> Shape<D> {
pub(crate) fn is_c(&self) -> bool {
matches!(self.strides, Strides::C)
}
}


/// An array shape of n dimensions in c-order, f-order or custom strides.
#[derive(Copy, Clone, Debug)]
pub struct StrideShape<D> {
pub(crate) dim: D,
pub(crate) strides: Strides<D>,
}

/// Stride description
#[derive(Copy, Clone, Debug)]
pub(crate) enum Strides<D> {
/// Row-major ("C"-order)
C,
/// Column-major ("F"-order)
F,
/// Custom strides
Custom(D)
}

impl<D> Strides<D> {
/// Return strides for `dim` (computed from dimension if c/f, else return the custom stride)
pub(crate) fn strides_for_dim(self, dim: &D) -> D
where D: Dimension
{
match self {
Strides::C => dim.default_strides(),
Strides::F => dim.fortran_strides(),
Strides::Custom(c) => {
debug_assert_eq!(c.ndim(), dim.ndim(),
"Custom strides given with {} dimensions, expected {}",
c.ndim(), dim.ndim());
c
}
}
}

pub(crate) fn is_custom(&self) -> bool {
matches!(*self, Strides::Custom(_))
}
}

/// A trait for `Shape` and `D where D: Dimension` that allows
/// customizing the memory layout (strides) of an array shape.
Expand Down Expand Up @@ -34,36 +94,18 @@ where
{
fn from(value: T) -> Self {
let shape = value.into_shape();
let d = shape.dim;
let st = if shape.is_c {
d.default_strides()
let st = if shape.is_c() {
Strides::C
} else {
d.fortran_strides()
Strides::F
};
StrideShape {
strides: st,
dim: d,
custom: false,
dim: shape.dim,
}
}
}

/*
impl<D> From<Shape<D>> for StrideShape<D>
where D: Dimension
{
fn from(shape: Shape<D>) -> Self {
let d = shape.dim;
let st = if shape.is_c { d.default_strides() } else { d.fortran_strides() };
StrideShape {
strides: st,
dim: d,
custom: false,
}
}
}
*/

impl<T> ShapeBuilder for T
where
T: IntoDimension,
Expand All @@ -73,7 +115,7 @@ where
fn into_shape(self) -> Shape<Self::Dim> {
Shape {
dim: self.into_dimension(),
is_c: true,
strides: Strides::C,
}
}
fn f(self) -> Shape<Self::Dim> {
Expand All @@ -93,21 +135,24 @@ where
{
type Dim = D;
type Strides = D;

fn into_shape(self) -> Shape<D> {
self
}

fn f(self) -> Self {
self.set_f(true)
}

fn set_f(mut self, is_f: bool) -> Self {
self.is_c = !is_f;
self.strides = if !is_f { Strides::C } else { Strides::F };
self
}

fn strides(self, st: D) -> StrideShape<D> {
StrideShape {
dim: self.dim,
strides: st,
custom: true,
strides: Strides::Custom(st),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/array-construct.rs
Expand Up @@ -148,6 +148,7 @@ fn test_from_fn_f3() {
fn deny_wraparound_from_vec() {
let five = vec![0; 5];
let five_large = Array::from_shape_vec((3, 7, 29, 36760123, 823996703), five.clone());
println!("{:?}", five_large);
assert!(five_large.is_err());
let six = Array::from_shape_vec(6, five.clone());
assert!(six.is_err());
Expand Down

0 comments on commit 3a2040d

Please sign in to comment.