Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add methods .cell_view() and .into_cell_view() #877

Merged
merged 2 commits into from Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/argument_traits.rs
@@ -1,6 +1,7 @@
use std::cell::Cell;
use std::mem::MaybeUninit;

use crate::math_cell::MathCell;

/// A producer element that can be assigned to once
pub trait AssignElem<T> {
Expand All @@ -22,6 +23,13 @@ impl<'a, T> AssignElem<T> for &'a Cell<T> {
}
}

/// Assignable element, simply `self.set(input)`.
impl<'a, T> AssignElem<T> for &'a MathCell<T> {
fn assign_elem(self, input: T) {
self.set(input);
}
}

/// Assignable element, the item in the MaybeUninit is overwritten (prior value, if any, is not
/// read or dropped).
impl<'a, T> AssignElem<T> for &'a mut MaybeUninit<T> {
Expand Down
15 changes: 15 additions & 0 deletions src/impl_methods.rs
Expand Up @@ -20,6 +20,7 @@ use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes,
};
use crate::error::{self, ErrorKind, ShapeError};
use crate::math_cell::MathCell;
use crate::itertools::zip;
use crate::zip::Zip;

Expand Down Expand Up @@ -151,6 +152,20 @@ where
unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) }
}

/// Return a shared view of the array with elements as if they were embedded in cells.
///
/// The cell view requires a mutable borrow of the array. Once borrowed the
/// cell view itself can be copied and accessed without exclusivity.
///
/// The view acts "as if" the elements are temporarily in cells, and elements
/// can be changed through shared references using the regular cell methods.
pub fn cell_view(&mut self) -> ArrayView<'_, MathCell<A>, D>
where
S: DataMut,
{
self.view_mut().into_cell_view()
}

/// Return an uniquely owned copy of the array.
///
/// If the input array is contiguous and its strides are positive, then the
Expand Down
16 changes: 16 additions & 0 deletions src/impl_views/conversions.rs
Expand Up @@ -13,6 +13,7 @@ use crate::imp_prelude::*;
use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};

use crate::iter::{self, AxisIter, AxisIterMut};
use crate::math_cell::MathCell;
use crate::IndexLonger;

/// Methods for read-only array views.
Expand Down Expand Up @@ -117,6 +118,21 @@ where
pub fn into_slice(self) -> Option<&'a mut [A]> {
self.try_into_slice().ok()
}

/// Return a shared view of the array with elements as if they were embedded in cells.
///
/// The cell view itself can be copied and accessed without exclusivity.
///
/// The view acts "as if" the elements are temporarily in cells, and elements
/// can be changed through shared references using the regular cell methods.
pub fn into_cell_view(self) -> ArrayView<'a, MathCell<A>, D> {
// safety: valid because
// A and MathCell<A> have the same representation
// &'a mut T is interchangeable with &'a Cell<T> -- see method Cell::from_mut in std
unsafe {
self.into_raw_view_mut().cast::<MathCell<A>>().deref_into_view()
}
}
}

/// Private array view methods
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Expand Up @@ -139,6 +139,7 @@ pub use crate::linalg_traits::LinalgScalar;

pub use crate::stacking::{concatenate, stack, stack_new_axis};

pub use crate::math_cell::MathCell;
pub use crate::impl_views::IndexLonger;
pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape};

Expand Down Expand Up @@ -180,6 +181,7 @@ mod layout;
mod linalg_traits;
mod linspace;
mod logspace;
mod math_cell;
mod numeric_util;
mod partial;
mod shape_builder;
Expand Down
102 changes: 102 additions & 0 deletions src/math_cell.rs
@@ -0,0 +1,102 @@

use std::cell::Cell;
use std::cmp::Ordering;
use std::fmt;

use std::ops::{Deref, DerefMut};

/// A transparent wrapper of [`Cell<T>`](std::cell::Cell) which is identical in every way, except
/// it will implement arithmetic operators as well.
///
/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view).
/// The `MathCell` derefs to `Cell`, so all the cell's methods are available.
#[repr(transparent)]
#[derive(Default)]
pub struct MathCell<T>(Cell<T>);

impl<T> MathCell<T> {
/// Create a new cell with the given value
#[inline(always)]
pub const fn new(value: T) -> Self { MathCell(Cell::new(value)) }

/// Return the inner value
pub fn into_inner(self) -> T { Cell::into_inner(self.0) }

/// Swap value with another cell
pub fn swap(&self, other: &Self) {
Cell::swap(&self.0, &other.0)
}
}

impl<T> Deref for MathCell<T> {
type Target = Cell<T>;
#[inline(always)]
fn deref(&self) -> &Self::Target { &self.0 }
}

impl<T> DerefMut for MathCell<T> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
}

impl<T> Clone for MathCell<T>
where T: Copy
{
fn clone(&self) -> Self {
MathCell::new(self.get())
}
}

impl<T> PartialEq for MathCell<T>
where T: Copy + PartialEq
{
fn eq(&self, rhs: &Self) -> bool {
self.get() == rhs.get()
}
}

impl<T> Eq for MathCell<T>
where T: Copy + Eq
{ }

impl<T> PartialOrd for MathCell<T>
where T: Copy + PartialOrd
{
fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
self.get().partial_cmp(&rhs.get())
}

fn lt(&self, rhs: &Self) -> bool { self.get().lt(&rhs.get()) }
fn le(&self, rhs: &Self) -> bool { self.get().le(&rhs.get()) }
fn gt(&self, rhs: &Self) -> bool { self.get().gt(&rhs.get()) }
fn ge(&self, rhs: &Self) -> bool { self.get().ge(&rhs.get()) }
}

impl<T> Ord for MathCell<T>
where T: Copy + Ord
{
fn cmp(&self, rhs: &Self) -> Ordering {
self.get().cmp(&rhs.get())
}
}

impl<T> fmt::Debug for MathCell<T>
where T: Copy + fmt::Debug
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.get().fmt(f)
}
}


#[cfg(test)]
mod tests {
use super::MathCell;

#[test]
fn test_basic() {
let c = &MathCell::new(0);
c.set(1);
assert_eq!(c.get(), 1);
}
}
16 changes: 16 additions & 0 deletions tests/views.rs
@@ -0,0 +1,16 @@
use ndarray::prelude::*;
use ndarray::Zip;

#[test]
fn cell_view() {
let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32);
let answer = &a + 1.;

{
let cv1 = a.cell_view();
let cv2 = cv1;

Zip::from(cv1).and(cv2).apply(|a, b| a.set(b.get() + 1.));
}
assert_eq!(a, answer);
}