diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 59f27d55c..df48ea1cf 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -6,7 +6,6 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::cell::Cell; use std::ptr as std_ptr; use std::slice; @@ -21,6 +20,7 @@ use crate::dimension::{ abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes, }; use crate::error::{self, ErrorKind, ShapeError}; +use crate::math_cell::MathCell; use crate::itertools::zip; use crate::zip::Zip; @@ -159,7 +159,7 @@ where /// /// The view acts "as if" the elements are temporarily in cells, and elements /// can be changed through shared references using the regular cell methods. - pub fn cell_view(&mut self) -> ArrayView<'_, Cell, D> + pub fn cell_view(&mut self) -> ArrayView<'_, MathCell, D> where S: DataMut, { diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index 4265f7616..8c4230108 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -6,7 +6,6 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::cell::Cell; use std::slice; use crate::imp_prelude::*; @@ -14,6 +13,7 @@ use crate::imp_prelude::*; use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut}; use crate::iter::{self, AxisIter, AxisIterMut}; +use crate::math_cell::MathCell; use crate::IndexLonger; /// Methods for read-only array views. @@ -125,12 +125,12 @@ where /// /// The view acts "as if" the elements are temporarily in cells, and elements /// can be changed through shared references using the regular cell methods. - pub fn into_cell_view(self) -> ArrayView<'a, Cell, D> { + pub fn into_cell_view(self) -> ArrayView<'a, MathCell, D> { // safety: valid because - // A and Cell have the same representation - // &'a mut T is interchangeable with &'a Cell -- see method Cell::from_mut + // A and MathCell have the same representation + // &'a mut T is interchangeable with &'a Cell -- see method Cell::from_mut in std unsafe { - self.into_raw_view_mut().cast::>().deref_into_view() + self.into_raw_view_mut().cast::>().deref_into_view() } } } diff --git a/src/lib.rs b/src/lib.rs index c0a3b2e85..89862c78c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -139,6 +139,7 @@ pub use crate::linalg_traits::LinalgScalar; pub use crate::stacking::{concatenate, stack, stack_new_axis}; +pub use crate::math_cell::MathCell; pub use crate::impl_views::IndexLonger; pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape}; @@ -180,6 +181,7 @@ mod layout; mod linalg_traits; mod linspace; mod logspace; +mod math_cell; mod numeric_util; mod partial; mod shape_builder; diff --git a/src/math_cell.rs b/src/math_cell.rs new file mode 100644 index 000000000..f0f8da40b --- /dev/null +++ b/src/math_cell.rs @@ -0,0 +1,102 @@ + +use std::cell::Cell; +use std::cmp::Ordering; +use std::fmt; + +use std::ops::{Deref, DerefMut}; + +/// A transparent wrapper of [`Cell`](std::cell::Cell) which is identical in every way, except +/// it will implement arithmetic operators as well. +/// +/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view). +/// The `MathCell` derefs to `Cell`, so all the cell's methods are available. +#[repr(transparent)] +#[derive(Default)] +pub struct MathCell(Cell); + +impl MathCell { + /// Create a new cell with the given value + #[inline(always)] + pub const fn new(value: T) -> Self { MathCell(Cell::new(value)) } + + /// Return the inner value + pub fn into_inner(self) -> T { Cell::into_inner(self.0) } + + /// Swap value with another cell + pub fn swap(&self, other: &Self) { + Cell::swap(&self.0, &other.0) + } +} + +impl Deref for MathCell { + type Target = Cell; + #[inline(always)] + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl DerefMut for MathCell { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } +} + +impl Clone for MathCell + where T: Copy +{ + fn clone(&self) -> Self { + MathCell::new(self.get()) + } +} + +impl PartialEq for MathCell + where T: Copy + PartialEq +{ + fn eq(&self, rhs: &Self) -> bool { + self.get() == rhs.get() + } +} + +impl Eq for MathCell + where T: Copy + Eq +{ } + +impl PartialOrd for MathCell + where T: Copy + PartialOrd +{ + fn partial_cmp(&self, rhs: &Self) -> Option { + self.get().partial_cmp(&rhs.get()) + } + + fn lt(&self, rhs: &Self) -> bool { self.get().lt(&rhs.get()) } + fn le(&self, rhs: &Self) -> bool { self.get().le(&rhs.get()) } + fn gt(&self, rhs: &Self) -> bool { self.get().gt(&rhs.get()) } + fn ge(&self, rhs: &Self) -> bool { self.get().ge(&rhs.get()) } +} + +impl Ord for MathCell + where T: Copy + Ord +{ + fn cmp(&self, rhs: &Self) -> Ordering { + self.get().cmp(&rhs.get()) + } +} + +impl fmt::Debug for MathCell + where T: Copy + fmt::Debug +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.get().fmt(f) + } +} + + +#[cfg(test)] +mod tests { + use super::MathCell; + + #[test] + fn test_basic() { + let c = &MathCell::new(0); + c.set(1); + assert_eq!(c.get(), 1); + } +}