Skip to content

Commit

Permalink
Use MaybeUninitSubst and Zip to avoid uninitialized(), rename BroadCa…
Browse files Browse the repository at this point in the history
…stOutput to Output, remove zip_mut_from_pair()
  • Loading branch information
SparrowLii committed Jan 26, 2021
1 parent 7c64bfb commit 5eec58d
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 123 deletions.
27 changes: 23 additions & 4 deletions src/data_traits.rs
Expand Up @@ -9,14 +9,12 @@
//! The data (inner representation) traits for ndarray

use rawpointer::PointerExt;
use std::mem::{self, size_of};
use std::mem::{self, size_of, MaybeUninit};
use std::ptr::NonNull;
use alloc::sync::Arc;
use alloc::vec::Vec;

use crate::{
ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr,
};
use crate::{ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr};

/// Array representation trait.
///
Expand Down Expand Up @@ -401,6 +399,7 @@ unsafe impl<'a, A> DataMut for ViewRepr<&'a mut A> {}
///
/// ***Internal trait, see `Data`.***
pub unsafe trait DataOwned: Data {

#[doc(hidden)]
fn new(elements: Vec<Self::Elem>) -> Self;

Expand All @@ -424,6 +423,7 @@ unsafe impl<A> DataOwned for OwnedRepr<A> {
fn new(elements: Vec<A>) -> Self {
OwnedRepr::from(elements)
}

fn into_shared(self) -> OwnedArcRepr<A> {
OwnedArcRepr(Arc::new(self))
}
Expand Down Expand Up @@ -605,3 +605,22 @@ impl<'a, A: 'a, B: 'a> RawDataSubst<B> for ViewRepr<&'a mut A> {
ViewRepr::new()
}
}

/// Array representation trait.
///
/// The MaybeUninitSubst trait maps the MaybeUninit type of element, while
/// mapping the MaybeUninit type back to origin element type.
///
/// For example, `MaybeUninitSubst` can map the type `OwnedRepr<A>` to `OwnedRepr<MaybeUninit<A>>`,
/// and use `Output as RawDataSubst` to map `OwnedRepr<MaybeUninit<A>>` back to `OwnedRepr<A>`.
pub trait MaybeUninitSubst<A>: DataOwned<Elem = A> {
type Output: DataOwned<Elem = MaybeUninit<A>> + RawDataSubst<A, Output=Self, Elem = MaybeUninit<A>>;
}

impl<A> MaybeUninitSubst<A> for OwnedRepr<A> {
type Output = OwnedRepr<MaybeUninit<A>>;
}

impl<A> MaybeUninitSubst<A> for OwnedArcRepr<A> {
type Output = OwnedArcRepr<MaybeUninit<A>>;
}
22 changes: 11 additions & 11 deletions src/dimension/broadcast.rs
Expand Up @@ -43,43 +43,43 @@ where

pub trait BroadcastShape<Other: Dimension> {
/// The resulting dimension type after broadcasting.
type BroadcastOutput: Dimension;
type Output: Dimension;

/// Determines the shape after broadcasting the dimensions together.
///
/// If the dimensions are not compatible, returns `Err`.
///
/// Uses the [NumPy broadcasting rules]
/// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
fn broadcast_shape(&self, other: &Other) -> Result<Self::BroadcastOutput, ShapeError>;
fn broadcast_shape(&self, other: &Other) -> Result<Self::Output, ShapeError>;
}

/// Dimensions of the same type remain unchanged when co_broadcast.
/// So you can directly use D as the resulting type.
/// (Instead of <D as BroadcastShape<D>>::BroadcastOutput)
impl<D: Dimension> BroadcastShape<D> for D {
type BroadcastOutput = D;
type Output = D;

fn broadcast_shape(&self, other: &D) -> Result<Self::BroadcastOutput, ShapeError> {
broadcast_shape::<D, D, Self::BroadcastOutput>(self, other)
fn broadcast_shape(&self, other: &D) -> Result<Self::Output, ShapeError> {
broadcast_shape::<D, D, Self::Output>(self, other)
}
}

macro_rules! impl_broadcast_distinct_fixed {
($smaller:ty, $larger:ty) => {
impl BroadcastShape<$larger> for $smaller {
type BroadcastOutput = $larger;
type Output = $larger;

fn broadcast_shape(&self, other: &$larger) -> Result<Self::BroadcastOutput, ShapeError> {
broadcast_shape::<Self, $larger, Self::BroadcastOutput>(self, other)
fn broadcast_shape(&self, other: &$larger) -> Result<Self::Output, ShapeError> {
broadcast_shape::<Self, $larger, Self::Output>(self, other)
}
}

impl BroadcastShape<$smaller> for $larger {
type BroadcastOutput = $larger;
type Output = $larger;

fn broadcast_shape(&self, other: &$smaller) -> Result<Self::BroadcastOutput, ShapeError> {
broadcast_shape::<Self, $smaller, Self::BroadcastOutput>(self, other)
fn broadcast_shape(&self, other: &$smaller) -> Result<Self::Output, ShapeError> {
broadcast_shape::<Self, $smaller, Self::Output>(self, other)
}
}
};
Expand Down
10 changes: 5 additions & 5 deletions src/dimension/dimension_trait.rs
Expand Up @@ -46,11 +46,11 @@ pub trait Dimension:
+ MulAssign
+ for<'x> MulAssign<&'x Self>
+ MulAssign<usize>
+ BroadcastShape<Ix0, BroadcastOutput=Self>
+ BroadcastShape<Self, BroadcastOutput=Self>
+ BroadcastShape<IxDyn, BroadcastOutput=IxDyn>
+ BroadcastShape<<Self as Dimension>::Smaller, BroadcastOutput=Self>
+ BroadcastShape<<Self as Dimension>::Larger, BroadcastOutput=<Self as Dimension>::Larger>
+ BroadcastShape<Ix0, Output=Self>
+ BroadcastShape<Self, Output=Self>
+ BroadcastShape<IxDyn, Output=IxDyn>
+ BroadcastShape<<Self as Dimension>::Smaller, Output=Self>
+ BroadcastShape<<Self as Dimension>::Larger, Output=<Self as Dimension>::Larger>
{
/// For fixed-size dimension representations (e.g. `Ix2`), this should be
/// `Some(ndim)`, and for variable-size dimension representations (e.g.
Expand Down
40 changes: 0 additions & 40 deletions src/impl_methods.rs
Expand Up @@ -1968,46 +1968,6 @@ where
self.zip_mut_with_by_rows(rhs, f);
}

/// Traverse two arrays in unspecified order, in lock step,
/// calling the closure `f` on each element pair, and put
/// the result into the corresponding element of self.
pub fn zip_mut_from_pair<B, C, S1, S2, F>(&mut self, lhs: &ArrayBase<S1, D>, rhs: &ArrayBase<S2, D>, f: F, )
where
S: DataMut,
S1: Data<Elem = B>,
S2: Data<Elem = C>,
F: Fn(&B, &C) -> A,
{
debug_assert_eq!(self.shape(), lhs.shape());
debug_assert_eq!(self.shape(), rhs.shape());

if self.dim.strides_equivalent(&self.strides, &lhs.strides)
&& self.dim.strides_equivalent(&self.strides, &rhs.strides)
{
if let Some(self_s) = self.as_slice_memory_order_mut() {
if let Some(lhs_s) = lhs.as_slice_memory_order() {
if let Some(rhs_s) = rhs.as_slice_memory_order() {
for (s, (l, r)) in
self_s.iter_mut().zip(lhs_s.iter().zip(rhs_s)) {
*s = f(&l, &r);
}
return;
}
}
}
}

// Otherwise, fall back to the outer iter
let n = self.ndim();
let dim = self.raw_dim();
Zip::from(LanesMut::new(self.view_mut(), Axis(n - 1)))
.and(Lanes::new(lhs.broadcast_assume(dim.clone()), Axis(n - 1)))
.and(Lanes::new(rhs.broadcast_assume(dim), Axis(n - 1)))
.for_each(move |s_row, l_row, r_row| {
Zip::from(s_row).and(l_row).and(r_row).for_each(|s, a, b| *s = f(a, b))
});
}

// zip two arrays where they have different layout or strides
#[inline(always)]
fn zip_mut_with_by_rows<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
Expand Down
129 changes: 68 additions & 61 deletions src/impl_ops.rs
Expand Up @@ -7,6 +7,8 @@
// except according to those terms.

use crate::dimension::BroadcastShape;
use crate::data_traits::MaybeUninitSubst;
use crate::Zip;
use num_complex::Complex;

/// Elements that can be used as direct operands in arithmetic with arrays.
Expand Down Expand Up @@ -64,14 +66,15 @@ macro_rules! impl_binary_op(
/// **Panics** if broadcasting isn’t possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Copy + $trt<B, Output=A>,
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S: DataOwned<Elem=A> + DataMut + MaybeUninitSubst<A>,
<S as MaybeUninitSubst<A>>::Output: DataMut,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
type Output = ArrayBase<S, <D as BroadcastShape<E>>::BroadcastOutput>;
type Output = ArrayBase<S, <D as BroadcastShape<E>>::Output>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
{
self.$mth(&rhs)
Expand All @@ -80,7 +83,7 @@ where

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and `rhs`,
/// between `self` and reference `rhs`,
/// and return the result.
///
/// `rhs` must be an `Array` or `ArcArray`.
Expand All @@ -89,44 +92,49 @@ where
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=B>,
B: Copy,
S: Data<Elem=A>,
S2: DataOwned<Elem=B> + DataMut,
D: Dimension,
E: Dimension + BroadcastShape<D>,
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut + MaybeUninitSubst<A>,
<S as MaybeUninitSubst<A>>::Output: DataMut,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
type Output = ArrayBase<S2, <E as BroadcastShape<D>>::BroadcastOutput>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
type Output = ArrayBase<S, <D as BroadcastShape<E>>::Output>;
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
{
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
if shape.slice() == rhs.dim.slice() {
let mut out = rhs.into_dimensionality::<<E as BroadcastShape<D>>::BroadcastOutput>().unwrap();
out.zip_mut_with(self, |x, y| {
*x = y.clone() $operator x.clone();
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
if shape.slice() == self.dim.slice() {
let mut out = self.into_dimensionality::<<D as BroadcastShape<E>>::Output>().unwrap();
out.zip_mut_with(rhs, |x, y| {
*x = x.clone() $operator y.clone();
});
out
} else {
// SAFETY: Overwrite all the elements in the array after
// it is created via `zip_mut_from_pair`.
let mut out = unsafe {
Self::Output::uninitialized(shape.clone().into_pattern())
};
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape).unwrap();
out.zip_mut_from_pair(&lhs, &rhs, |x, y| {
x.clone() $operator y.clone()
});
out
let rhs = rhs.broadcast(shape.clone()).unwrap();
// SAFETY: Overwrite all the elements in the array after
// it is created via `raw_view_mut`.
unsafe {
let mut out =ArrayBase::<<S as MaybeUninitSubst<A>>::Output, <D as BroadcastShape<E>>::Output>::maybe_uninit(shape.into_pattern());
let output_view = out.raw_view_mut().cast::<A>();
Zip::from(&lhs).and(&rhs)
.and(output_view)
.collect_with_partial(|x, y| {
x.clone() $operator y.clone()
})
.release_ownership();
out.assume_init()
}
}
}
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// between reference `self` and `rhs`,
/// and return the result.
///
/// `rhs` must be an `Array` or `ArcArray`.
Expand All @@ -135,37 +143,43 @@ where
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
A: Copy + $trt<B, Output=A>,
A: Clone + $trt<B, Output=B>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
E: Dimension,
S: Data<Elem=A>,
S2: DataOwned<Elem=B> + DataMut + MaybeUninitSubst<B>,
<S2 as MaybeUninitSubst<B>>::Output: DataMut,
D: Dimension,
E: Dimension + BroadcastShape<D>,
{
type Output = ArrayBase<S, <D as BroadcastShape<E>>::BroadcastOutput>;
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
type Output = ArrayBase<S2, <E as BroadcastShape<D>>::Output>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
where
{
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
if shape.slice() == self.dim.slice() {
let mut out = self.into_dimensionality::<<D as BroadcastShape<E>>::BroadcastOutput>().unwrap();
out.zip_mut_with(rhs, |x, y| {
*x = x.clone() $operator y.clone();
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
if shape.slice() == rhs.dim.slice() {
let mut out = rhs.into_dimensionality::<<E as BroadcastShape<D>>::Output>().unwrap();
out.zip_mut_with(self, |x, y| {
*x = y.clone() $operator x.clone();
});
out
} else {
// SAFETY: Overwrite all the elements in the array after
// it is created via `zip_mut_from_pair`.
let mut out = unsafe {
Self::Output::uninitialized(shape.clone().into_pattern())
};
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape).unwrap();
out.zip_mut_from_pair(&lhs, &rhs, |x, y| {
x.clone() $operator y.clone()
});
out
let rhs = rhs.broadcast(shape.clone()).unwrap();
// SAFETY: Overwrite all the elements in the array after
// it is created via `raw_view_mut`.
unsafe {
let mut out =ArrayBase::<<S2 as MaybeUninitSubst<B>>::Output, <E as BroadcastShape<D>>::Output>::maybe_uninit(shape.into_pattern());
let output_view = out.raw_view_mut().cast::<B>();
Zip::from(&lhs).and(&rhs)
.and(output_view)
.collect_with_partial(|x, y| {
x.clone() $operator y.clone()
})
.release_ownership();
out.assume_init()
}
}
}
}
Expand All @@ -188,19 +202,12 @@ where
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
type Output = Array<A, <D as BroadcastShape<E>>::Output>;
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
// SAFETY: Overwrite all the elements in the array after
// it is created via `zip_mut_from_pair`.
let mut out = unsafe {
Self::Output::uninitialized(shape.clone().into_pattern())
};
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape).unwrap();
out.zip_mut_from_pair(&lhs, &rhs, |x, y| {
x.clone() $operator y.clone()
});
let out = Zip::from(&lhs).and(&rhs).map_collect(|x, y| x.clone() $operator y.clone());
out
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Expand Up @@ -179,7 +179,7 @@ pub use crate::aliases::*;

pub use crate::data_traits::{
Data, DataMut, DataOwned, DataShared, RawData, RawDataClone, RawDataMut,
RawDataSubst,
RawDataSubst, MaybeUninitSubst,
};

mod free_functions;
Expand Down
2 changes: 1 addition & 1 deletion tests/array.rs
Expand Up @@ -1563,7 +1563,7 @@ fn test_broadcast_shape() {
fn test_co<D1, D2>(
d1: &D1,
d2: &D2,
r: Result<<D1 as BroadcastShape<D2>>::BroadcastOutput, ShapeError>,
r: Result<<D1 as BroadcastShape<D2>>::Output, ShapeError>,
) where
D1: Dimension + BroadcastShape<D2>,
D2: Dimension,
Expand Down

0 comments on commit 5eec58d

Please sign in to comment.