Skip to content

Commit

Permalink
Merge pull request #877 from rust-ndarray/cell-view
Browse files Browse the repository at this point in the history
Add methods .cell_view() and .into_cell_view()
  • Loading branch information
bluss committed Jan 8, 2021
2 parents 4d9641d + 894d981 commit 1f53dce
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/argument_traits.rs
@@ -1,6 +1,7 @@
use std::cell::Cell;
use std::mem::MaybeUninit;

use crate::math_cell::MathCell;

/// A producer element that can be assigned to once
pub trait AssignElem<T> {
Expand All @@ -22,6 +23,13 @@ impl<'a, T> AssignElem<T> for &'a Cell<T> {
}
}

/// Assignable element, simply `self.set(input)`.
impl<'a, T> AssignElem<T> for &'a MathCell<T> {
fn assign_elem(self, input: T) {
self.set(input);
}
}

/// Assignable element, the item in the MaybeUninit is overwritten (prior value, if any, is not
/// read or dropped).
impl<'a, T> AssignElem<T> for &'a mut MaybeUninit<T> {
Expand Down
15 changes: 15 additions & 0 deletions src/impl_methods.rs
Expand Up @@ -20,6 +20,7 @@ use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes,
};
use crate::error::{self, ErrorKind, ShapeError};
use crate::math_cell::MathCell;
use crate::itertools::zip;
use crate::zip::Zip;

Expand Down Expand Up @@ -151,6 +152,20 @@ where
unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) }
}

/// Return a shared view of the array with elements as if they were embedded in cells.
///
/// The cell view requires a mutable borrow of the array. Once borrowed the
/// cell view itself can be copied and accessed without exclusivity.
///
/// The view acts "as if" the elements are temporarily in cells, and elements
/// can be changed through shared references using the regular cell methods.
pub fn cell_view(&mut self) -> ArrayView<'_, MathCell<A>, D>
where
S: DataMut,
{
self.view_mut().into_cell_view()
}

/// Return an uniquely owned copy of the array.
///
/// If the input array is contiguous and its strides are positive, then the
Expand Down
16 changes: 16 additions & 0 deletions src/impl_views/conversions.rs
Expand Up @@ -13,6 +13,7 @@ use crate::imp_prelude::*;
use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};

use crate::iter::{self, AxisIter, AxisIterMut};
use crate::math_cell::MathCell;
use crate::IndexLonger;

/// Methods for read-only array views.
Expand Down Expand Up @@ -117,6 +118,21 @@ where
pub fn into_slice(self) -> Option<&'a mut [A]> {
self.try_into_slice().ok()
}

/// Return a shared view of the array with elements as if they were embedded in cells.
///
/// The cell view itself can be copied and accessed without exclusivity.
///
/// The view acts "as if" the elements are temporarily in cells, and elements
/// can be changed through shared references using the regular cell methods.
pub fn into_cell_view(self) -> ArrayView<'a, MathCell<A>, D> {
// safety: valid because
// A and MathCell<A> have the same representation
// &'a mut T is interchangeable with &'a Cell<T> -- see method Cell::from_mut in std
unsafe {
self.into_raw_view_mut().cast::<MathCell<A>>().deref_into_view()
}
}
}

/// Private array view methods
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Expand Up @@ -139,6 +139,7 @@ pub use crate::linalg_traits::LinalgScalar;

pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::math_cell::MathCell;
pub use crate::impl_views::IndexLonger;
pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape};

Expand Down Expand Up @@ -180,6 +181,7 @@ mod layout;
mod linalg_traits;
mod linspace;
mod logspace;
mod math_cell;
mod numeric_util;
mod partial;
mod shape_builder;
Expand Down
102 changes: 102 additions & 0 deletions src/math_cell.rs
@@ -0,0 +1,102 @@

use std::cell::Cell;
use std::cmp::Ordering;
use std::fmt;

use std::ops::{Deref, DerefMut};

/// A transparent wrapper of [`Cell<T>`](std::cell::Cell) which is identical in every way, except
/// it will implement arithmetic operators as well.
///
/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view).
/// The `MathCell` derefs to `Cell`, so all the cell's methods are available.
#[repr(transparent)]
#[derive(Default)]
pub struct MathCell<T>(Cell<T>);

impl<T> MathCell<T> {
/// Create a new cell with the given value
#[inline(always)]
pub const fn new(value: T) -> Self { MathCell(Cell::new(value)) }

/// Return the inner value
pub fn into_inner(self) -> T { Cell::into_inner(self.0) }

/// Swap value with another cell
pub fn swap(&self, other: &Self) {
Cell::swap(&self.0, &other.0)
}
}

impl<T> Deref for MathCell<T> {
type Target = Cell<T>;
#[inline(always)]
fn deref(&self) -> &Self::Target { &self.0 }
}

impl<T> DerefMut for MathCell<T> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
}

impl<T> Clone for MathCell<T>
where T: Copy
{
fn clone(&self) -> Self {
MathCell::new(self.get())
}
}

impl<T> PartialEq for MathCell<T>
where T: Copy + PartialEq
{
fn eq(&self, rhs: &Self) -> bool {
self.get() == rhs.get()
}
}

impl<T> Eq for MathCell<T>
where T: Copy + Eq
{ }

impl<T> PartialOrd for MathCell<T>
where T: Copy + PartialOrd
{
fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
self.get().partial_cmp(&rhs.get())
}

fn lt(&self, rhs: &Self) -> bool { self.get().lt(&rhs.get()) }
fn le(&self, rhs: &Self) -> bool { self.get().le(&rhs.get()) }
fn gt(&self, rhs: &Self) -> bool { self.get().gt(&rhs.get()) }
fn ge(&self, rhs: &Self) -> bool { self.get().ge(&rhs.get()) }
}

impl<T> Ord for MathCell<T>
where T: Copy + Ord
{
fn cmp(&self, rhs: &Self) -> Ordering {
self.get().cmp(&rhs.get())
}
}

impl<T> fmt::Debug for MathCell<T>
where T: Copy + fmt::Debug
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.get().fmt(f)
}
}


#[cfg(test)]
mod tests {
use super::MathCell;

#[test]
fn test_basic() {
let c = &MathCell::new(0);
c.set(1);
assert_eq!(c.get(), 1);
}
}
16 changes: 16 additions & 0 deletions tests/views.rs
@@ -0,0 +1,16 @@
use ndarray::prelude::*;
use ndarray::Zip;

#[test]
fn cell_view() {
let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32);
let answer = &a + 1.;

{
let cv1 = a.cell_view();
let cv2 = cv1;

Zip::from(cv1).and(cv2).apply(|a, b| a.set(b.get() + 1.));
}
assert_eq!(a, answer);
}

0 comments on commit 1f53dce

Please sign in to comment.