From 23342b1827dab1bfa44e8962999540b8627eb6cf Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Mon, 1 Feb 2021 12:20:05 +0800 Subject: [PATCH] rebase and use map_collect_owned in impl_ops.rs --- src/data_traits.rs | 21 ++----------- src/impl_methods.rs | 2 +- src/impl_ops.rs | 76 +++++++++++++++++---------------------------- src/lib.rs | 2 +- 4 files changed, 33 insertions(+), 68 deletions(-) diff --git a/src/data_traits.rs b/src/data_traits.rs index 8df059c64..384647efe 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -10,7 +10,8 @@ use rawpointer::PointerExt; -use std::mem::{self, size_of};use std::mem::MaybeUninit; +use std::mem::{self, size_of}; +use std::mem::MaybeUninit; use std::ptr::NonNull; use alloc::sync::Arc; use alloc::vec::Vec; @@ -620,21 +621,3 @@ impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a mut A> { } } -/// Array representation trait. -/// -/// The MaybeUninitSubst trait maps the MaybeUninit type of element, while -/// mapping the MaybeUninit type back to origin element type. -/// -/// For example, `MaybeUninitSubst` can map the type `OwnedRepr` to `OwnedRepr>`, -/// and use `Output as RawDataSubst` to map `OwnedRepr>` back to `OwnedRepr`. -pub trait MaybeUninitSubst: DataOwned { - type Output: DataOwned> + RawDataSubst>; -} - -impl MaybeUninitSubst for OwnedRepr { - type Output = OwnedRepr>; -} - -impl MaybeUninitSubst for OwnedArcRepr { - type Output = OwnedArcRepr>; -} diff --git a/src/impl_methods.rs b/src/impl_methods.rs index f8c0ee919..2e1904a91 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1961,7 +1961,7 @@ where self.unordered_foreach_mut(move |elt| *elt = x.clone()); } - fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) + pub(crate) fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) where S: DataMut, S2: Data, diff --git a/src/impl_ops.rs b/src/impl_ops.rs index e0a6d7f88..68e34322d 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -7,7 +7,6 @@ // except according to those terms. use crate::dimension::BroadcastShape; -use crate::data_traits::MaybeUninitSubst; use crate::Zip; use num_complex::Complex; @@ -68,8 +67,8 @@ impl $trt> for ArrayBase where A: Clone + $trt, B: Clone, - S: DataOwned + DataMut + MaybeUninitSubst, - >::Output: DataMut, + S: DataOwned + DataMut, + S::MaybeUninit: DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, @@ -96,8 +95,8 @@ impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase where A: Clone + $trt, B: Clone, - S: DataOwned + DataMut + MaybeUninitSubst, - >::Output: DataMut, + S: DataOwned + DataMut, + S::MaybeUninit: DataMut, S2: Data, D: Dimension + BroadcastShape, E: Dimension, @@ -105,29 +104,15 @@ where type Output = ArrayBase>::Output>; fn $mth(self, rhs: &ArrayBase) -> Self::Output { - let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); - if shape.slice() == self.dim.slice() { + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let mut out = self.into_dimensionality::<>::Output>().unwrap(); - out.zip_mut_with(rhs, |x, y| { - *x = x.clone() $operator y.clone(); - }); + out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth)); out } else { + let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape.clone()).unwrap(); - // SAFETY: Overwrite all the elements in the array after - // it is created via `raw_view_mut`. - unsafe { - let mut out =ArrayBase::<>::Output, >::Output>::maybe_uninit(shape.into_pattern()); - let output_view = out.raw_view_mut().cast::(); - Zip::from(&lhs).and(&rhs) - .and(output_view) - .collect_with_partial(|x, y| { - x.clone() $operator y.clone() - }) - .release_ownership(); - out.assume_init() - } + let rhs = rhs.broadcast(shape).unwrap(); + Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) } } } @@ -148,8 +133,8 @@ where A: Clone + $trt, B: Clone, S: Data, - S2: DataOwned + DataMut + MaybeUninitSubst, - >::Output: DataMut, + S2: DataOwned + DataMut, + S2::MaybeUninit: DataMut, D: Dimension, E: Dimension + BroadcastShape, { @@ -157,29 +142,15 @@ where fn $mth(self, rhs: ArrayBase) -> Self::Output where { - let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); - if shape.slice() == rhs.dim.slice() { + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); - out.zip_mut_with(self, |x, y| { - *x = y.clone() $operator x.clone(); - }); + out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth)); out } else { + let shape = rhs.dim.broadcast_shape(&self.dim).unwrap(); let lhs = self.broadcast(shape.clone()).unwrap(); - let rhs = rhs.broadcast(shape.clone()).unwrap(); - // SAFETY: Overwrite all the elements in the array after - // it is created via `raw_view_mut`. - unsafe { - let mut out =ArrayBase::<>::Output, >::Output>::maybe_uninit(shape.into_pattern()); - let output_view = out.raw_view_mut().cast::(); - Zip::from(&lhs).and(&rhs) - .and(output_view) - .collect_with_partial(|x, y| { - x.clone() $operator y.clone() - }) - .release_ownership(); - out.assume_init() - } + let rhs = rhs.broadcast(shape).unwrap(); + Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth)) } } } @@ -207,8 +178,7 @@ where let shape = self.dim.broadcast_shape(&rhs.dim).unwrap(); let lhs = self.broadcast(shape.clone()).unwrap(); let rhs = rhs.broadcast(shape).unwrap(); - let out = Zip::from(&lhs).and(&rhs).map_collect(|x, y| x.clone() $operator y.clone()); - out + Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth)) } } @@ -313,6 +283,18 @@ mod arithmetic_ops { use num_complex::Complex; use std::ops::*; + fn clone_opf(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C { + move |x, y| f(x.clone(), y.clone()) + } + + fn clone_iopf(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) { + move |x, y| *x = f(x.clone(), y.clone()) + } + + fn clone_iopf_rev(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) { + move |x, y| *x = f(y.clone(), x.clone()) + } + impl_binary_op!(Add, +, add, +=, "addition"); impl_binary_op!(Sub, -, sub, -=, "subtraction"); impl_binary_op!(Mul, *, mul, *=, "multiplication"); diff --git a/src/lib.rs b/src/lib.rs index 6b2ea1dbc..5fe893da1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -179,7 +179,7 @@ pub use crate::aliases::*; pub use crate::data_traits::{ Data, DataMut, DataOwned, DataShared, RawData, RawDataClone, RawDataMut, - RawDataSubst, MaybeUninitSubst, + RawDataSubst, }; mod free_functions;