Skip to content

Commit

Permalink
rebase and use map_collect_owned in impl_ops.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
SparrowLii committed Feb 1, 2021
1 parent 1055ebf commit 23342b1
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 68 deletions.
21 changes: 2 additions & 19 deletions src/data_traits.rs
Expand Up @@ -10,7 +10,8 @@

use rawpointer::PointerExt;

use std::mem::{self, size_of};use std::mem::MaybeUninit;
use std::mem::{self, size_of};
use std::mem::MaybeUninit;
use std::ptr::NonNull;
use alloc::sync::Arc;
use alloc::vec::Vec;
Expand Down Expand Up @@ -620,21 +621,3 @@ impl<'a, A: 'a, B: 'a> RawDataSubst<B> for ViewRepr<&'a mut A> {
}
}

/// Array representation trait.
///
/// The MaybeUninitSubst trait maps the MaybeUninit type of element, while
/// mapping the MaybeUninit type back to origin element type.
///
/// For example, `MaybeUninitSubst` can map the type `OwnedRepr<A>` to `OwnedRepr<MaybeUninit<A>>`,
/// and use `Output as RawDataSubst` to map `OwnedRepr<MaybeUninit<A>>` back to `OwnedRepr<A>`.
pub trait MaybeUninitSubst<A>: DataOwned<Elem = A> {
type Output: DataOwned<Elem = MaybeUninit<A>> + RawDataSubst<A, Output=Self, Elem = MaybeUninit<A>>;
}

impl<A> MaybeUninitSubst<A> for OwnedRepr<A> {
type Output = OwnedRepr<MaybeUninit<A>>;
}

impl<A> MaybeUninitSubst<A> for OwnedArcRepr<A> {
type Output = OwnedArcRepr<MaybeUninit<A>>;
}
2 changes: 1 addition & 1 deletion src/impl_methods.rs
Expand Up @@ -1961,7 +1961,7 @@ where
self.unordered_foreach_mut(move |elt| *elt = x.clone());
}

fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
pub(crate) fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
where
S: DataMut,
S2: Data<Elem = B>,
Expand Down
76 changes: 29 additions & 47 deletions src/impl_ops.rs
Expand Up @@ -7,7 +7,6 @@
// except according to those terms.

use crate::dimension::BroadcastShape;
use crate::data_traits::MaybeUninitSubst;
use crate::Zip;
use num_complex::Complex;

Expand Down Expand Up @@ -68,8 +67,8 @@ impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut + MaybeUninitSubst<A>,
<S as MaybeUninitSubst<A>>::Output: DataMut,
S: DataOwned<Elem=A> + DataMut,
S::MaybeUninit: DataMut,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
E: Dimension,
Expand All @@ -96,38 +95,24 @@ impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut + MaybeUninitSubst<A>,
<S as MaybeUninitSubst<A>>::Output: DataMut,
S: DataOwned<Elem=A> + DataMut,
S::MaybeUninit: DataMut,
S2: Data<Elem=B>,
D: Dimension + BroadcastShape<E>,
E: Dimension,
{
type Output = ArrayBase<S, <D as BroadcastShape<E>>::Output>;
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
{
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
if shape.slice() == self.dim.slice() {
if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let mut out = self.into_dimensionality::<<D as BroadcastShape<E>>::Output>().unwrap();
out.zip_mut_with(rhs, |x, y| {
*x = x.clone() $operator y.clone();
});
out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
out
} else {
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape.clone()).unwrap();
// SAFETY: Overwrite all the elements in the array after
// it is created via `raw_view_mut`.
unsafe {
let mut out =ArrayBase::<<S as MaybeUninitSubst<A>>::Output, <D as BroadcastShape<E>>::Output>::maybe_uninit(shape.into_pattern());
let output_view = out.raw_view_mut().cast::<A>();
Zip::from(&lhs).and(&rhs)
.and(output_view)
.collect_with_partial(|x, y| {
x.clone() $operator y.clone()
})
.release_ownership();
out.assume_init()
}
let rhs = rhs.broadcast(shape).unwrap();
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
}
}
}
Expand All @@ -148,38 +133,24 @@ where
A: Clone + $trt<B, Output=B>,
B: Clone,
S: Data<Elem=A>,
S2: DataOwned<Elem=B> + DataMut + MaybeUninitSubst<B>,
<S2 as MaybeUninitSubst<B>>::Output: DataMut,
S2: DataOwned<Elem=B> + DataMut,
S2::MaybeUninit: DataMut,
D: Dimension,
E: Dimension + BroadcastShape<D>,
{
type Output = ArrayBase<S2, <E as BroadcastShape<D>>::Output>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
where
{
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
if shape.slice() == rhs.dim.slice() {
if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let mut out = rhs.into_dimensionality::<<E as BroadcastShape<D>>::Output>().unwrap();
out.zip_mut_with(self, |x, y| {
*x = y.clone() $operator x.clone();
});
out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
out
} else {
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape.clone()).unwrap();
// SAFETY: Overwrite all the elements in the array after
// it is created via `raw_view_mut`.
unsafe {
let mut out =ArrayBase::<<S2 as MaybeUninitSubst<B>>::Output, <E as BroadcastShape<D>>::Output>::maybe_uninit(shape.into_pattern());
let output_view = out.raw_view_mut().cast::<B>();
Zip::from(&lhs).and(&rhs)
.and(output_view)
.collect_with_partial(|x, y| {
x.clone() $operator y.clone()
})
.release_ownership();
out.assume_init()
}
let rhs = rhs.broadcast(shape).unwrap();
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
}
}
}
Expand Down Expand Up @@ -207,8 +178,7 @@ where
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
let lhs = self.broadcast(shape.clone()).unwrap();
let rhs = rhs.broadcast(shape).unwrap();
let out = Zip::from(&lhs).and(&rhs).map_collect(|x, y| x.clone() $operator y.clone());
out
Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth))
}
}

Expand Down Expand Up @@ -313,6 +283,18 @@ mod arithmetic_ops {
use num_complex::Complex;
use std::ops::*;

fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C {
move |x, y| f(x.clone(), y.clone())
}

fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) {
move |x, y| *x = f(x.clone(), y.clone())
}

fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) {
move |x, y| *x = f(y.clone(), x.clone())
}

impl_binary_op!(Add, +, add, +=, "addition");
impl_binary_op!(Sub, -, sub, -=, "subtraction");
impl_binary_op!(Mul, *, mul, *=, "multiplication");
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Expand Up @@ -179,7 +179,7 @@ pub use crate::aliases::*;

pub use crate::data_traits::{
Data, DataMut, DataOwned, DataShared, RawData, RawDataClone, RawDataMut,
RawDataSubst, MaybeUninitSubst,
RawDataSubst,
};

mod free_functions;
Expand Down

0 comments on commit 23342b1

Please sign in to comment.