Skip to content

Commit

Permalink
Remove duplication in from_shape_ptr and split_at
Browse files Browse the repository at this point in the history
  • Loading branch information
jturner314 committed Nov 19, 2018
1 parent c711bf4 commit 9960d06
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 78 deletions.
49 changes: 48 additions & 1 deletion src/impl_raw_views.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use dimension;
use dimension::{self, stride_offset};
use imp_prelude::*;
use {is_aligned, StrideShape};

Expand Down Expand Up @@ -87,6 +87,33 @@ where
pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> {
ArrayView::new_(self.ptr, self.dim, self.strides)
}

/// Split the array view along `axis` and return one array pointer strictly
/// before the split and one array pointer after the split.
///
/// **Panics** if `axis` or `index` is out of bounds.
pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
assert!(index <= self.len_of(axis));
let left_ptr = self.ptr;
let right_ptr = if index == self.len_of(axis) {
self.ptr
} else {
let offset = stride_offset(index, self.strides.axis(axis));
// The `.offset()` is safe due to the guarantees of `DataRaw`.
unsafe { self.ptr.offset(offset) }
};

let mut dim_left = self.dim.clone();
dim_left.set_axis(axis, index);
let left = unsafe { Self::new_(left_ptr, dim_left, self.strides.clone()) };

let mut dim_right = self.dim;
let right_len = dim_right.axis(axis) - index;
dim_right.set_axis(axis, right_len);
let right = unsafe { Self::new_(right_ptr, dim_right, self.strides) };

(left, right)
}
}

impl<A, D> RawArrayViewMut<A, D>
Expand Down Expand Up @@ -155,6 +182,12 @@ where
RawArrayViewMut::new_(ptr, dim, strides)
}

/// Converts to a non-mutable `RawArrayView`.
#[inline]
pub(crate) fn into_raw_view(self) -> RawArrayView<A, D> {
unsafe { RawArrayView::new_(self.ptr, self.dim, self.strides) }
}

/// Return a read-only view of the array
///
/// **Warning** from a safety standpoint, this is equivalent to
Expand Down Expand Up @@ -194,4 +227,18 @@ where
pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> {
ArrayViewMut::new_(self.ptr, self.dim, self.strides)
}

/// Split the array view along `axis` and return one array pointer strictly
/// before the split and one array pointer after the split.
///
/// **Panics** if `axis` or `index` is out of bounds.
pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
let (left, right) = self.into_raw_view().split_at(axis, index);
unsafe {
(
Self::new_(left.ptr, left.dim, left.strides),
Self::new_(right.ptr, right.dim, right.strides),
)
}
}
}
100 changes: 23 additions & 77 deletions src/impl_views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use std::slice;

use imp_prelude::*;
use dimension::{self, stride_offset};
use dimension;
use error::ShapeError;
use arraytraits::array_out_of_bounds;
use {is_aligned, NdIndex, StrideShape};
Expand Down Expand Up @@ -111,15 +111,7 @@ impl<'a, A, D> ArrayView<'a, A, D>
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *const A) -> Self
where Sh: Into<StrideShape<D>>
{
let shape = shape.into();
let dim = shape.dim;
let strides = shape.strides;
if cfg!(debug_assertions) {
assert!(!ptr.is_null(), "The pointer must be non-null.");
assert!(is_aligned(ptr), "The pointer must be aligned.");
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
}
ArrayView::new_(ptr, dim, strides)
RawArrayView::from_shape_ptr(shape, ptr).deref_into_view()
}

/// Convert the view into an `ArrayView<'b, A, D>` where `'b` is a lifetime
Expand All @@ -141,35 +133,11 @@ impl<'a, A, D> ArrayView<'a, A, D>
/// an array with shape 3 × 5 × 5.
///
/// <img src="https://rust-ndarray.github.io/ndarray/images/split_at.svg" width="300px" height="271px">
pub fn split_at(self, axis: Axis, index: Ix)
-> (Self, Self)
{
// NOTE: Keep this in sync with the ArrayViewMut version
assert!(index <= self.len_of(axis));
let left_ptr = self.ptr;
let right_ptr = if index == self.len_of(axis) {
self.ptr
} else {
let offset = stride_offset(index, self.strides.axis(axis));
unsafe {
self.ptr.offset(offset)
}
};

let mut dim_left = self.dim.clone();
dim_left.set_axis(axis, index);
let left = unsafe {
Self::new_(left_ptr, dim_left, self.strides.clone())
};

let mut dim_right = self.dim;
let right_len = dim_right.axis(axis) - index;
dim_right.set_axis(axis, right_len);
let right = unsafe {
Self::new_(right_ptr, dim_right, self.strides)
};

(left, right)
pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
unsafe {
let (left, right) = self.into_raw_view().split_at(axis, index);
(left.deref_into_view(), right.deref_into_view())
}
}

/// Return the array’s data as a slice, if it is contiguous and in standard order.
Expand All @@ -183,6 +151,11 @@ impl<'a, A, D> ArrayView<'a, A, D>
None
}
}

/// Converts to a raw array view.
pub(crate) fn into_raw_view(self) -> RawArrayView<A, D> {
unsafe { RawArrayView::new_(self.ptr, self.dim, self.strides) }
}
}


Expand Down Expand Up @@ -408,15 +381,7 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *mut A) -> Self
where Sh: Into<StrideShape<D>>
{
let shape = shape.into();
let dim = shape.dim;
let strides = shape.strides;
if cfg!(debug_assertions) {
assert!(!ptr.is_null(), "The pointer must be non-null.");
assert!(is_aligned(ptr), "The pointer must be aligned.");
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
}
ArrayViewMut::new_(ptr, dim, strides)
RawArrayViewMut::from_shape_ptr(shape, ptr).deref_into_view_mut()
}

/// Convert the view into an `ArrayViewMut<'b, A, D>` where `'b` is a lifetime
Expand All @@ -433,35 +398,11 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
/// before the split and one mutable view after the split.
///
/// **Panics** if `axis` or `index` is out of bounds.
pub fn split_at(self, axis: Axis, index: Ix)
-> (Self, Self)
{
// NOTE: Keep this in sync with the ArrayView version
assert!(index <= self.len_of(axis));
let left_ptr = self.ptr;
let right_ptr = if index == self.len_of(axis) {
self.ptr
} else {
let offset = stride_offset(index, self.strides.axis(axis));
unsafe {
self.ptr.offset(offset)
}
};

let mut dim_left = self.dim.clone();
dim_left.set_axis(axis, index);
let left = unsafe {
Self::new_(left_ptr, dim_left, self.strides.clone())
};

let mut dim_right = self.dim;
let right_len = dim_right.axis(axis) - index;
dim_right.set_axis(axis, right_len);
let right = unsafe {
Self::new_(right_ptr, dim_right, self.strides)
};

(left, right)
pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
unsafe {
let (left, right) = self.into_raw_view_mut().split_at(axis, index);
(left.deref_into_view_mut(), right.deref_into_view_mut())
}
}

/// Return the array’s data as a slice, if it is contiguous and in standard order.
Expand Down Expand Up @@ -605,6 +546,11 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
}
}

/// Converts to a mutable raw array view.
pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut<A, D> {
unsafe { RawArrayViewMut::new_(self.ptr, self.dim, self.strides) }
}

#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
unsafe {
Expand Down

0 comments on commit 9960d06

Please sign in to comment.