diff --git a/src/argument_traits.rs b/src/argument_traits.rs index a93e33a12..82d4869a9 100644 --- a/src/argument_traits.rs +++ b/src/argument_traits.rs @@ -1,6 +1,7 @@ use std::cell::Cell; use std::mem::MaybeUninit; +use crate::math_cell::MathCell; /// A producer element that can be assigned to once pub trait AssignElem { @@ -22,6 +23,13 @@ impl<'a, T> AssignElem for &'a Cell { } } +/// Assignable element, simply `self.set(input)`. +impl<'a, T> AssignElem for &'a MathCell { + fn assign_elem(self, input: T) { + self.set(input); + } +} + /// Assignable element, the item in the MaybeUninit is overwritten (prior value, if any, is not /// read or dropped). impl<'a, T> AssignElem for &'a mut MaybeUninit { diff --git a/src/impl_methods.rs b/src/impl_methods.rs index cccc195a1..df48ea1cf 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -20,6 +20,7 @@ use crate::dimension::{ abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes, }; use crate::error::{self, ErrorKind, ShapeError}; +use crate::math_cell::MathCell; use crate::itertools::zip; use crate::zip::Zip; @@ -151,6 +152,20 @@ where unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } } + /// Return a shared view of the array with elements as if they were embedded in cells. + /// + /// The cell view requires a mutable borrow of the array. Once borrowed the + /// cell view itself can be copied and accessed without exclusivity. + /// + /// The view acts "as if" the elements are temporarily in cells, and elements + /// can be changed through shared references using the regular cell methods. + pub fn cell_view(&mut self) -> ArrayView<'_, MathCell, D> + where + S: DataMut, + { + self.view_mut().into_cell_view() + } + /// Return an uniquely owned copy of the array. /// /// If the input array is contiguous and its strides are positive, then the diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index 863d26ddb..8c4230108 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -13,6 +13,7 @@ use crate::imp_prelude::*; use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut}; use crate::iter::{self, AxisIter, AxisIterMut}; +use crate::math_cell::MathCell; use crate::IndexLonger; /// Methods for read-only array views. @@ -117,6 +118,21 @@ where pub fn into_slice(self) -> Option<&'a mut [A]> { self.try_into_slice().ok() } + + /// Return a shared view of the array with elements as if they were embedded in cells. + /// + /// The cell view itself can be copied and accessed without exclusivity. + /// + /// The view acts "as if" the elements are temporarily in cells, and elements + /// can be changed through shared references using the regular cell methods. + pub fn into_cell_view(self) -> ArrayView<'a, MathCell, D> { + // safety: valid because + // A and MathCell have the same representation + // &'a mut T is interchangeable with &'a Cell -- see method Cell::from_mut in std + unsafe { + self.into_raw_view_mut().cast::>().deref_into_view() + } + } } /// Private array view methods diff --git a/src/lib.rs b/src/lib.rs index c0a3b2e85..89862c78c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -139,6 +139,7 @@ pub use crate::linalg_traits::LinalgScalar; pub use crate::stacking::{concatenate, stack, stack_new_axis}; +pub use crate::math_cell::MathCell; pub use crate::impl_views::IndexLonger; pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape}; @@ -180,6 +181,7 @@ mod layout; mod linalg_traits; mod linspace; mod logspace; +mod math_cell; mod numeric_util; mod partial; mod shape_builder; diff --git a/src/math_cell.rs b/src/math_cell.rs new file mode 100644 index 000000000..f0f8da40b --- /dev/null +++ b/src/math_cell.rs @@ -0,0 +1,102 @@ + +use std::cell::Cell; +use std::cmp::Ordering; +use std::fmt; + +use std::ops::{Deref, DerefMut}; + +/// A transparent wrapper of [`Cell`](std::cell::Cell) which is identical in every way, except +/// it will implement arithmetic operators as well. +/// +/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view). +/// The `MathCell` derefs to `Cell`, so all the cell's methods are available. +#[repr(transparent)] +#[derive(Default)] +pub struct MathCell(Cell); + +impl MathCell { + /// Create a new cell with the given value + #[inline(always)] + pub const fn new(value: T) -> Self { MathCell(Cell::new(value)) } + + /// Return the inner value + pub fn into_inner(self) -> T { Cell::into_inner(self.0) } + + /// Swap value with another cell + pub fn swap(&self, other: &Self) { + Cell::swap(&self.0, &other.0) + } +} + +impl Deref for MathCell { + type Target = Cell; + #[inline(always)] + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl DerefMut for MathCell { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } +} + +impl Clone for MathCell + where T: Copy +{ + fn clone(&self) -> Self { + MathCell::new(self.get()) + } +} + +impl PartialEq for MathCell + where T: Copy + PartialEq +{ + fn eq(&self, rhs: &Self) -> bool { + self.get() == rhs.get() + } +} + +impl Eq for MathCell + where T: Copy + Eq +{ } + +impl PartialOrd for MathCell + where T: Copy + PartialOrd +{ + fn partial_cmp(&self, rhs: &Self) -> Option { + self.get().partial_cmp(&rhs.get()) + } + + fn lt(&self, rhs: &Self) -> bool { self.get().lt(&rhs.get()) } + fn le(&self, rhs: &Self) -> bool { self.get().le(&rhs.get()) } + fn gt(&self, rhs: &Self) -> bool { self.get().gt(&rhs.get()) } + fn ge(&self, rhs: &Self) -> bool { self.get().ge(&rhs.get()) } +} + +impl Ord for MathCell + where T: Copy + Ord +{ + fn cmp(&self, rhs: &Self) -> Ordering { + self.get().cmp(&rhs.get()) + } +} + +impl fmt::Debug for MathCell + where T: Copy + fmt::Debug +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.get().fmt(f) + } +} + + +#[cfg(test)] +mod tests { + use super::MathCell; + + #[test] + fn test_basic() { + let c = &MathCell::new(0); + c.set(1); + assert_eq!(c.get(), 1); + } +} diff --git a/tests/views.rs b/tests/views.rs new file mode 100644 index 000000000..216e55402 --- /dev/null +++ b/tests/views.rs @@ -0,0 +1,16 @@ +use ndarray::prelude::*; +use ndarray::Zip; + +#[test] +fn cell_view() { + let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32); + let answer = &a + 1.; + + { + let cv1 = a.cell_view(); + let cv2 = cv1; + + Zip::from(cv1).and(cv2).apply(|a, b| a.set(b.get() + 1.)); + } + assert_eq!(a, answer); +}