Skip to content

Commit

Permalink
Return expressions instead of new arrays for operators, depends on ru…
Browse files Browse the repository at this point in the history
  • Loading branch information
fre-hu committed Dec 2, 2023
1 parent 3df527b commit 0f38f36
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 79 deletions.
28 changes: 15 additions & 13 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::alloc::Global;
use crate::buffer::{Buffer, BufferMut, SizedBuffer, SizedBufferMut};
use crate::buffer::{GridBuffer, SpanBuffer, ViewBuffer, ViewBufferMut};
use crate::dim::Dim;
use crate::expr::{Expr, ExprMut};
use crate::expr::{Expr, ExprMut, Producer};
use crate::expression::Expression;
use crate::index::SpanIndex;
use crate::iter::Iter;
Expand Down Expand Up @@ -44,34 +44,36 @@ impl<B: BufferMut + ?Sized> Array<B> {
}

impl<'a, T, B: Buffer + ?Sized> Apply<T> for &'a Array<B> {
type Output = GridArray<T, B::Dim>;
type ZippedWith<I: IntoExpression> = GridArray<T, <B::Dim as Dim>::Max<I::Dim>>;
type Output<F: FnMut(Self::Item) -> T> = Expression<impl Producer<Item = T, Dim = Self::Dim>>;
type ZippedWith<I: IntoExpression, F: FnMut(Self::Item, I::Item) -> T> =
Expression<impl Producer<Item = T, Dim = <B::Dim as Dim>::Max<I::Dim>>>;

fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output {
self.as_span().expr().map(f).eval()
fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output<F> {
self.as_span().expr().map(f)
}

fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I>
fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I, F>
where
F: FnMut(Self::Item, I::Item) -> T,
{
self.as_span().expr().zip(expr).map(|(x, y)| f(x, y)).eval()
self.as_span().expr().zip(expr).map(move |(x, y)| f(x, y))
}
}

impl<'a, T, B: BufferMut + ?Sized> Apply<T> for &'a mut Array<B> {
type Output = GridArray<T, B::Dim>;
type ZippedWith<I: IntoExpression> = GridArray<T, <B::Dim as Dim>::Max<I::Dim>>;
type Output<F: FnMut(Self::Item) -> T> = Expression<impl Producer<Item = T, Dim = Self::Dim>>;
type ZippedWith<I: IntoExpression, F: FnMut(Self::Item, I::Item) -> T> =
Expression<impl Producer<Item = T, Dim = <B::Dim as Dim>::Max<I::Dim>>>;

fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output {
self.as_mut_span().expr_mut().map(f).eval()
fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output<F> {
self.as_mut_span().expr_mut().map(f)
}

fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I>
fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I, F>
where
F: FnMut(Self::Item, I::Item) -> T,
{
self.as_mut_span().expr_mut().zip(expr).map(|(x, y)| f(x, y)).eval()
self.as_mut_span().expr_mut().zip(expr).map(move |(x, y)| f(x, y))
}
}

Expand Down
13 changes: 7 additions & 6 deletions src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,19 @@ impl<P: Producer> Expression<P> {
}

impl<T, P: Producer> Apply<T> for Expression<P> {
type Output = GridArray<T, P::Dim>;
type ZippedWith<I: IntoExpression> = GridArray<T, <P::Dim as Dim>::Max<I::Dim>>;
type Output<F: FnMut(Self::Item) -> T> = Expression<impl Producer<Item = T, Dim = P::Dim>>;
type ZippedWith<I: IntoExpression, F: FnMut(Self::Item, I::Item) -> T> =
Expression<impl Producer<Item = T, Dim = <P::Dim as Dim>::Max<I::Dim>>>;

fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output {
self.map(f).eval()
fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output<F> {
self.map(f)
}

fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I>
fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I, F>
where
F: FnMut(Self::Item, I::Item) -> T,
{
self.zip(expr).map(|(x, y)| f(x, y)).eval()
self.zip(expr).map(move |(x, y)| f(x, y))
}
}

Expand Down
11 changes: 4 additions & 7 deletions src/grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,17 +461,14 @@ impl<T, D: Dim> GridArray<T, D> {
}

impl<T, D: Dim, A: Allocator> Apply<T> for GridArray<T, D, A> {
type Output = GridArray<T, D, A>;
type ZippedWith<I: IntoExpression> = GridArray<T, D, A>;
type Output<F: FnMut(T) -> T> = Self;
type ZippedWith<I: IntoExpression, F: FnMut(Self::Item, I::Item) -> T> = Self;

fn apply<F: FnMut(T) -> T>(self, mut f: F) -> GridArray<T, D, A> {
fn apply<F: FnMut(T) -> T>(self, mut f: F) -> Self {
self.zip_with(expr::fill(()), |x, ()| f(x))
}

fn zip_with<I: IntoExpression, F>(self, expr: I, f: F) -> GridArray<T, D, A>
where
F: FnMut(T, I::Item) -> T,
{
fn zip_with<I: IntoExpression, F: FnMut(T, I::Item) -> T>(self, expr: I, f: F) -> Self {
self.zip_with(expr, f)
}
}
Expand Down
13 changes: 6 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@
//! Arithmetic, logical, negation, comparison and compound assignment operators
//! are supported for arrays and expressions.
//!
//! If at least one of the inputs is an array that is passed by value, the input
//! buffer is reused for the result. Otherwise, if all input parameters are array
//! references or expressions, a new array is created for the result. In the
//! latter case, the result may have a different element type.
//! If at least one of the inputs is an array that is passed by value, the
//! operation is evaluated directly and the input array is reused for the result.
//! Otherwise, if all input parameters are array references or expressions, an
//! expression is returned. In the latter case, the result may have a different
//! element type.
//!
//! For comparison operators, the parameters must always be arrays that are passed
//! by reference. For compound assignment operators, the first parameter is always
Expand All @@ -131,9 +132,6 @@
//! an `Expression<Fill<T>>` expression. If a type does not implement the `Copy`
//! trait, the parameter must be passed by reference.
//!
//! Note that for complex calculations, it can be more efficient to use expressions
//! and element-wise operations to reduce memory accesses and allocations.
//!
//! ## Example
//!
//! This example implements matrix multiplication and addition `C = A * B + C`.
Expand Down Expand Up @@ -170,6 +168,7 @@
#![cfg_attr(feature = "nightly", feature(hasher_prefixfree_extras))]
#![cfg_attr(feature = "nightly", feature(int_roundings))]
#![cfg_attr(feature = "nightly", feature(slice_range))]
#![feature(impl_trait_in_assoc_type)]
#![warn(missing_docs)]
#![warn(unreachable_pub)]
#![warn(unused_results)]
Expand Down
16 changes: 8 additions & 8 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ where

macro_rules! impl_binary_op {
($trt:tt, $fn:tt) => {
impl<T, B: Buffer + ?Sized, I: Apply<T>> $trt<I> for &Array<B>
impl<'a, T, B: Buffer + ?Sized, I: Apply<T>> $trt<I> for &'a Array<B>
where
for<'a> &'a B::Item: $trt<I::Item, Output = T>,
for<'b> &'b B::Item: $trt<I::Item, Output = T>,
{
type Output = I::ZippedWith<Self>;
type Output = I::ZippedWith<Self, impl FnMut(I::Item, &'a B::Item) -> T>;

fn $fn(self, rhs: I) -> Self::Output {
rhs.zip_with(self, |x, y| y.$fn(x))
Expand All @@ -111,7 +111,7 @@ macro_rules! impl_binary_op {
where
P::Item: $trt<I::Item, Output = T>,
{
type Output = I::ZippedWith<Self>;
type Output = I::ZippedWith<Self, impl FnMut(I::Item, P::Item) -> T>;

fn $fn(self, rhs: I) -> Self::Output {
rhs.zip_with(self, |x, y| y.$fn(x))
Expand All @@ -133,7 +133,7 @@ macro_rules! impl_binary_op {
where
for<'b> &'b T: $trt<I::Item, Output = U>,
{
type Output = I::ZippedWith<Self>;
type Output = I::ZippedWith<Self, impl FnMut(I::Item, &'a T) -> U>;

fn $fn(self, rhs: I) -> Self::Output {
rhs.zip_with(self, |x, y| y.$fn(x))
Expand Down Expand Up @@ -183,7 +183,7 @@ macro_rules! impl_unary_op {
where
for<'a> &'a B::Item: $trt<Output = T>,
{
type Output = GridArray<T, B::Dim>;
type Output = Expression<impl Producer<Item = T, Dim = B::Dim>>;

fn $fn(self) -> Self::Output {
self.apply(|x| x.$fn())
Expand All @@ -194,7 +194,7 @@ macro_rules! impl_unary_op {
where
P::Item: $trt<Output = T>,
{
type Output = GridArray<T, P::Dim>;
type Output = Expression<impl Producer<Item = T, Dim = P::Dim>>;

fn $fn(self) -> Self::Output {
self.apply(|x| x.$fn())
Expand All @@ -216,7 +216,7 @@ macro_rules! impl_unary_op {
where
for<'b> &'b T: $trt<Output = U>,
{
type Output = GridArray<U, D>;
type Output = Expression<impl Producer<Item = U, Dim = D>>;

fn $fn(self) -> Self::Output {
self.apply(|x| x.$fn())
Expand Down
14 changes: 7 additions & 7 deletions src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ use crate::dim::Dim;
use crate::expr::Producer;
use crate::expression::Expression;

/// Trait for applying a closure and returning a new or an existing array.
/// Trait for applying a closure and returning an existing array or an expression.
pub trait Apply<T>: IntoExpression {
/// The resulting type after applying a closure.
type Output: IntoExpression<Item = T, Dim = Self::Dim>;
type Output<F: FnMut(Self::Item) -> T>: IntoExpression<Item = T, Dim = Self::Dim>;

/// The resulting type after zipping elements and applying a closure.
type ZippedWith<I: IntoExpression>: IntoExpression<Item = T>;
type ZippedWith<I: IntoExpression, F: FnMut(Self::Item, I::Item) -> T>: IntoExpression<Item = T>;

/// Returns a new or an existing array with the given closure applied to each element.
fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output;
/// Returns the array or an expression with the given closure applied to each element.
fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output<F>;

/// Returns a new or an existing array with the given closure applied to zipped element pairs.
fn zip_with<I: IntoExpression, F>(self, expr: I, f: F) -> Self::ZippedWith<I>
/// Returns the array or an expression with the given closure applied to zipped element pairs.
fn zip_with<I: IntoExpression, F>(self, expr: I, f: F) -> Self::ZippedWith<I, F>
where
F: FnMut(Self::Item, I::Item) -> T;
}
Expand Down
30 changes: 16 additions & 14 deletions src/view.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::slice;

use crate::array::{GridArray, ViewArray, ViewArrayMut};
use crate::array::{ViewArray, ViewArrayMut};
use crate::buffer::{ViewBuffer, ViewBufferMut};
use crate::dim::{Const, Dim, Shape};
use crate::expr::{Expr, ExprMut};
use crate::expr::{Expr, ExprMut, Producer};
use crate::expression::Expression;
use crate::index::{self, Axis, DimIndex, Permutation, ViewIndex};
use crate::iter::Iter;
Expand Down Expand Up @@ -263,34 +263,36 @@ impl_into_view!(5, (X, Y, Z, W, U), (x, y, z, w, u));
impl_into_view!(6, (X, Y, Z, W, U, V), (x, y, z, w, u, v));

impl<'a, T, U, D: Dim, L: Layout> Apply<U> for ViewArray<'a, T, D, L> {
type Output = GridArray<U, D>;
type ZippedWith<I: IntoExpression> = GridArray<U, D::Max<I::Dim>>;
type Output<F: FnMut(&'a T) -> U> = Expression<impl Producer<Item = U, Dim = D>>;
type ZippedWith<I: IntoExpression, F: FnMut(Self::Item, I::Item) -> U> =
Expression<impl Producer<Item = U, Dim = D::Max<I::Dim>>>;

fn apply<F: FnMut(&'a T) -> U>(self, f: F) -> Self::Output {
self.into_expr().map(f).eval()
fn apply<F: FnMut(&'a T) -> U>(self, f: F) -> Self::Output<F> {
self.into_expr().map(f)
}

fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I>
fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I, F>
where
F: FnMut(&'a T, I::Item) -> U,
{
self.into_expr().zip(expr).map(|(x, y)| f(x, y)).eval()
self.into_expr().zip(expr).map(move |(x, y)| f(x, y))
}
}

impl<'a, T, U, D: Dim, L: Layout> Apply<U> for ViewArrayMut<'a, T, D, L> {
type Output = GridArray<U, D>;
type ZippedWith<I: IntoExpression> = GridArray<U, D::Max<I::Dim>>;
type Output<F: FnMut(&'a mut T) -> U> = Expression<impl Producer<Item = U, Dim = D>>;
type ZippedWith<I: IntoExpression, F: FnMut(Self::Item, I::Item) -> U> =
Expression<impl Producer<Item = U, Dim = D::Max<I::Dim>>>;

fn apply<F: FnMut(&'a mut T) -> U>(self, f: F) -> Self::Output {
self.into_expr().map(f).eval()
fn apply<F: FnMut(&'a mut T) -> U>(self, f: F) -> Self::Output<F> {
self.into_expr().map(f)
}

fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I>
fn zip_with<I: IntoExpression, F>(self, expr: I, mut f: F) -> Self::ZippedWith<I, F>
where
F: FnMut(&'a mut T, I::Item) -> U,
{
self.into_expr().zip(expr).map(|(x, y)| f(x, y)).eval()
self.into_expr().zip(expr).map(move |(x, y)| f(x, y))
}
}

Expand Down
35 changes: 18 additions & 17 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#![cfg_attr(feature = "nightly", feature(hasher_prefixfree_extras))]
#![cfg_attr(feature = "nightly", feature(int_roundings))]
#![cfg_attr(feature = "nightly", feature(slice_range))]
#![feature(impl_trait_in_assoc_type)]
#![warn(missing_docs)]
#![warn(unreachable_pub)]
#![warn(unused_results)]
Expand Down Expand Up @@ -324,8 +325,8 @@ fn test_expr() {

assert_eq!(a.as_span().shape(), [3, 2]);

assert_eq!((&a + &view![1, 2, 3]).as_slice(), [2, 4, 6, 5, 7, 9]);
assert_eq!((&view![1, 2, 3] + &a).as_slice(), [2, 4, 6, 5, 7, 9]);
assert_eq!((&a + &view![1, 2, 3]).eval().as_slice(), [2, 4, 6, 5, 7, 9]);
assert_eq!((&view![1, 2, 3] + &a).eval().as_slice(), [2, 4, 6, 5, 7, 9]);

assert_eq!(format!("{:?}", a.axis_expr::<0>()), "AxisExpr([[1, 4], [2, 5], [3, 6]])");
assert_eq!(format!("{:?}", a.outer_expr_mut()), "AxisExprMut([[1, 2, 3], [4, 5, 6]])");
Expand Down Expand Up @@ -502,42 +503,42 @@ fn test_ops() {

a -= expr::fill(1);
a -= &b;
a -= b.as_span();
a -= b.to_view();

*a.as_mut_span() -= expr::fill(1);
*a.as_mut_span() -= &b;
*a.as_mut_span() -= b.as_span();
*a.as_mut_span() -= b.to_view();

assert_eq!(a, Grid::from([[-37, -32, -27], [-22, -17, -12]]));

a = a - expr::fill(1);
a = a - &b;
a = a - b.as_span();
a = a - b.to_view();

a = expr::fill(1) - a;
a = &b - a;
a = b.as_span() - a;
a = b.to_view() - a;

assert_eq!(a, Grid::from([[57, 50, 43], [36, 29, 22]]));

a = &a - &b;
a = &a - b.as_span();
a = a.as_span() - &b;
a = a.as_span() - b.as_span();
a = (&a - &b).eval();
a = (&a - b.to_view()).eval();
a = (a.to_view() - &b).eval();
a = (a.to_view() - b.to_view()).eval();

assert_eq!(a, Grid::from([[21, 18, 15], [12, 9, 6]]));

a = &a - expr::fill(1);
a = a.as_span() - expr::fill(1);
a = (&a - expr::fill(1)).eval();
a = (a.to_view() - expr::fill(1)).eval();

a = expr::fill(1) - &a;
a = expr::fill(1) - a.as_span();
a = (expr::fill(1) - &a).eval();
a = (expr::fill(1) - a.to_view()).eval();

assert_eq!(a, Grid::from([[19, 16, 13], [10, 7, 4]]));

a = -a;
a = -&a;
a = -a.as_span();
a = (-&a).eval();
a = (-a.to_view()).eval();

assert_eq!(a, Grid::from([[-19, -16, -13], [-10, -7, -4]]));

Expand All @@ -551,7 +552,7 @@ fn test_ops() {
let c = expr::fill_with(|| 1usize) + expr::from_elem([2, 3], 4);
let c = c + expr::from_fn([2, 3], |x| x[0] + x[1]);

assert_eq!(c, Grid::from([[5, 6], [6, 7], [7, 8]]));
assert_eq!(c.eval(), Grid::from([[5, 6], [6, 7], [7, 8]]));
}

#[cfg(feature = "serde")]
Expand Down

0 comments on commit 0f38f36

Please sign in to comment.