Merge pull request #817 from rust-ndarray/par-collect
Implement parallel collect to array for non-Copy elements
bluss committed May 28, 2020
2 parents 79d99c7 + e472612 commit f69248e
Showing 14 changed files with 478 additions and 189 deletions.
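For context, a minimal usage sketch of what this change enables. It assumes a build of this branch with the `rayon` feature enabled; the shapes and values are illustrative, not part of the commit:

use ndarray::{Array2, Zip};

fn main() {
    // Map a 2-D array of &str into owned Strings, in parallel.
    // String is not Copy, which the previous `R: Copy` bound ruled out.
    let a = Array2::from_elem((64, 64), "hello");
    let owned: Array2<String> = Zip::from(&a).par_apply_collect(|&s| s.to_owned());
    assert_eq!(owned[[0, 0]], "hello");
}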
36 changes: 36 additions & 0 deletions benches/par_rayon.rs
@@ -136,3 +136,39 @@ fn rayon_add(bench: &mut Bencher) {
});
});
}

const COLL_STRING_N: usize = 64;
const COLL_F64_N: usize = 128;
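
// Compare a sequential collect into a Vec against the new parallel
// collect into an Array, once mapping with to_owned and once with a
// plain f64 addition.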

#[bench]
fn vec_string_collect(bench: &mut test::Bencher) {
let v = vec![""; COLL_STRING_N * COLL_STRING_N];
bench.iter(|| {
v.iter().map(|s| s.to_owned()).collect::<Vec<_>>()
});
}

#[bench]
fn array_string_collect(bench: &mut test::Bencher) {
let v = Array::from_elem((COLL_STRING_N, COLL_STRING_N), "");
bench.iter(|| {
Zip::from(&v).par_apply_collect(|s| s.to_owned())
});
}

#[bench]
fn vec_f64_collect(bench: &mut test::Bencher) {
let v = vec![1.; COLL_F64_N * COLL_F64_N];
bench.iter(|| {
v.iter().map(|s| s + 1.).collect::<Vec<_>>()
});
}

#[bench]
fn array_f64_collect(bench: &mut test::Bencher) {
let v = Array::from_elem((COLL_F64_N, COLL_F64_N), 1.);
bench.iter(|| {
Zip::from(&v).par_apply_collect(|s| s + 1.)
});
}
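
// Assumption on how to run these: #[bench] needs a nightly toolchain,
// e.g. `cargo bench --bench par_rayon --features rayon`.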

8 changes: 8 additions & 0 deletions src/dimension/dimension_trait.rs
@@ -540,6 +540,14 @@ impl Dimension for Dim<[Ix; 1]> {
fn try_remove_axis(&self, axis: Axis) -> Self::Smaller {
self.remove_axis(axis)
}

fn from_dimension<D2: Dimension>(d: &D2) -> Option<Self> {
if 1 == d.ndim() {
Some(Ix1(d[0]))
} else {
None
}
}
private_impl! {}
}

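For orientation, a sketch of the semantics of the new `from_dimension` conversion. The method lives on the sealed `Dimension` trait and is mainly for internal use, so treat the direct calls below as illustrative only:

use ndarray::{Dimension, Ix1, IxDyn};

fn main() {
    // The conversion succeeds only when the number of axes matches.
    let one_d = IxDyn(&[7]);
    assert_eq!(Ix1::from_dimension(&one_d), Some(Ix1(7)));

    let two_d = IxDyn(&[2, 3]);
    assert_eq!(Ix1::from_dimension(&two_d), None);
}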
6 changes: 5 additions & 1 deletion src/impl_methods.rs
@@ -1293,7 +1293,11 @@ where
is_standard_layout(&self.dim, &self.strides)
}

fn is_contiguous(&self) -> bool {
/// Return true if the array is known to be contiguous.
///
/// Will detect c- and f-contig arrays correctly, but otherwise
/// there are some false negatives (for example, a view with reversed
/// axes covers one contiguous block of memory but may not be recognized
/// as such).
pub(crate) fn is_contiguous(&self) -> bool {
D::is_contiguous(&self.dim, &self.strides)
}

3 changes: 2 additions & 1 deletion src/indexes.rs
@@ -7,7 +7,8 @@
// except according to those terms.
use super::Dimension;
use crate::dimension::IntoDimension;
use crate::zip::{Offset, Splittable};
use crate::zip::Offset;
use crate::split_at::SplitAt;
use crate::Axis;
use crate::Layout;
use crate::NdProducer;
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -176,9 +176,11 @@ mod linalg_traits;
mod linspace;
mod logspace;
mod numeric_util;
mod partial;
mod shape_builder;
#[macro_use]
mod slice;
mod split_at;
mod stacking;
#[macro_use]
mod zip;
50 changes: 44 additions & 6 deletions src/parallel/impl_par_methods.rs
@@ -2,6 +2,10 @@ use crate::{Array, ArrayBase, DataMut, Dimension, IntoNdProducer, NdProducer, Zi
use crate::AssignElem;

use crate::parallel::prelude::*;
use crate::parallel::par::ParallelSplits;
use super::send_producer::SendProducer;

use crate::partial::Partial;

/// # Parallel methods
///
@@ -43,6 +47,8 @@ where

// Zip

const COLLECT_MAX_SPLITS: usize = 10;
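// Each split halves the Zip, so COLLECT_MAX_SPLITS bounds the number of
// chunks at 2^10 = 1024: plenty of parallelism without tiny work items.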

macro_rules! zip_impl {
($([$notlast:ident $($p:ident)*],)+) => {
$(
@@ -71,14 +77,46 @@
/// inputs.
///
/// If all inputs are c- or f-order respectively, that is preserved in the output.
///
/// Restricted to functions that produce copyable results for technical reasons; other
/// cases are not yet implemented.
pub fn par_apply_collect<R>(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send) -> Array<R, D>
where R: Copy + Send
pub fn par_apply_collect<R>(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send)
-> Array<R, D>
where R: Send
{
let mut output = self.uninitalized_for_current_layout::<R>();
self.par_apply_assign_into(&mut output, f);
let total_len = output.len();

// Create a parallel iterator that produces chunks of the zip with the output
// array. It's crucial that both parts split in the same way, and in a way
// that keeps the chunks of the output contiguous.
//
// Use a raw view so that we can alias the output data here and in the partial
// result.
let splits = unsafe {
ParallelSplits {
iter: self.and(SendProducer::new(output.raw_view_mut().cast::<R>())),
// Keep it from splitting the Zip down too small
max_splits: COLLECT_MAX_SPLITS,
}
};

let collect_result = splits.map(move |zip| {
// Apply the mapping function on this chunk of the zip
// Create a partial result for the contiguous slice of data being written to
unsafe {
zip.collect_with_partial(&f)
}
})
.reduce(Partial::stub, Partial::try_merge);

if std::mem::needs_drop::<R>() {
debug_assert_eq!(total_len, collect_result.len,
"collect len is not correct, expected {}", total_len);
assert!(collect_result.len == total_len,
"Collect: Expected number of writes not completed");
}

// Here the collect result is complete, and we release its ownership and transfer
// it to the output array.
collect_result.release_ownership();
unsafe {
output.assume_init()
}
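The `Partial` type used above lives in the new `src/partial.rs`, which this view does not expand. As a reading aid, here is a plausible minimal sketch consistent with the calls made in this method (`stub`, `try_merge`, the `len` field, `release_ownership`); the exact field layout, the `new` constructor, and the adjacency check are assumptions, not the committed file:

use std::ptr;

/// A partially written, contiguous run of elements. If dropped before
/// `release_ownership`, it drops the elements written so far.
pub(crate) struct Partial<T> {
    /// Start of the owned region (null for the stub).
    ptr: *mut T,
    /// Number of elements written so far.
    pub(crate) len: usize,
}

impl<T> Partial<T> {
    /// Begin an empty partial result at `ptr`; the caller guarantees
    /// `ptr` is valid for as many writes as will be recorded in `len`.
    pub(crate) unsafe fn new(ptr: *mut T) -> Self {
        Partial { ptr, len: 0 }
    }

    /// Empty placeholder; the identity value for `reduce`.
    pub(crate) fn stub() -> Self {
        Partial { ptr: ptr::null_mut(), len: 0 }
    }

    /// Merge two partial results describing adjacent regions;
    /// `front` must end exactly where `back` begins.
    pub(crate) fn try_merge(mut front: Self, back: Self) -> Self {
        if back.ptr.is_null() {
            front
        } else if front.ptr.is_null() {
            back
        } else {
            debug_assert_eq!(unsafe { front.ptr.add(front.len) }, back.ptr,
                "Partial::try_merge: regions are not adjacent");
            front.len += back.len;
            // The merged result owns back's elements now; avoid a double drop.
            std::mem::forget(back);
            front
        }
    }

    /// Hand the written elements over to the caller (here: the output
    /// array), returning how many elements were written.
    pub(crate) fn release_ownership(mut self) -> usize {
        let len = self.len;
        self.len = 0;
        len
    }
}

impl<T> Drop for Partial<T> {
    fn drop(&mut self) {
        if !self.ptr.is_null() && self.len > 0 {
            unsafe {
                ptr::drop_in_place(ptr::slice_from_raw_parts_mut(self.ptr, self.len));
            }
        }
    }
}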
1 change: 1 addition & 0 deletions src/parallel/mod.rs
@@ -155,4 +155,5 @@ pub use crate::par_azip;
mod impl_par_methods;
mod into_impls;
mod par;
mod send_producer;
mod zipmacro;
62 changes: 60 additions & 2 deletions src/parallel/par.rs
@@ -15,6 +15,7 @@ use crate::iter::AxisIter;
use crate::iter::AxisIterMut;
use crate::Dimension;
use crate::{ArrayView, ArrayViewMut};
use crate::split_at::SplitPreference;

/// Parallel iterator wrapper.
#[derive(Copy, Clone, Debug)]
@@ -170,7 +171,14 @@ macro_rules! par_iter_view_wrapper {
fn fold_with<F>(self, folder: F) -> F
where F: Folder<Self::Item>,
{
self.into_iter().fold(folder, move |f, elt| f.consume(elt))
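// Fold through Zip::fold_while so the fold can stop early once the rayon
// Folder reports it is full (for short-circuiting consumers like find_any).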
Zip::from(self.0).fold_while(folder, |mut folder, elt| {
folder = folder.consume(elt);
if folder.full() {
FoldWhile::Done(folder)
} else {
FoldWhile::Continue(folder)
}
}).into_inner()
}
}

@@ -243,7 +251,7 @@
type Item = ($($p::Item ,)*);

fn split(self) -> (Self, Option<Self>) {
if self.0.size() <= 1 {
if !self.0.can_split() {
return (self, None)
}
let (a, b) = self.0.split();
@@ -275,3 +283,53 @@
[P1 P2 P3 P4 P5],
[P1 P2 P3 P4 P5 P6],
}

/// A parallel iterator (unindexed) that produces the splits of the array
/// or producer `P`.
pub(crate) struct ParallelSplits<P> {
pub(crate) iter: P,
pub(crate) max_splits: usize,
}

impl<P> ParallelIterator for ParallelSplits<P>
where P: SplitPreference + Send,
{
type Item = P;

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where C: UnindexedConsumer<Self::Item>
{
bridge_unindexed(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
None
}
}

impl<P> UnindexedProducer for ParallelSplits<P>
where P: SplitPreference + Send,
{
type Item = P;

fn split(self) -> (Self, Option<Self>) {
if self.max_splits == 0 || !self.iter.can_split() {
return (self, None)
}
let (a, b) = self.iter.split();
(ParallelSplits {
iter: a,
max_splits: self.max_splits - 1,
},
Some(ParallelSplits {
iter: b,
max_splits: self.max_splits - 1,
}))
}

fn fold_with<Fold>(self, folder: Fold) -> Fold
where Fold: Folder<Self::Item>,
{
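// A leaf of the splitting: hand the remaining producer to the consumer
// as a single item instead of iterating its elements.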
folder.consume(self.iter)
}
}
83 changes: 83 additions & 0 deletions src/parallel/send_producer.rs
@@ -0,0 +1,83 @@

use crate::imp_prelude::*;
use crate::{Layout, NdProducer};
use std::ops::{Deref, DerefMut};

/// An NdProducer that is unconditionally `Send`.
#[repr(transparent)]
pub(crate) struct SendProducer<T> {
inner: T
}

impl<T> SendProducer<T> {
/// Create an unconditionally `Send` ndproducer from the producer
pub(crate) unsafe fn new(producer: T) -> Self { Self { inner: producer } }
}

unsafe impl<P> Send for SendProducer<P> { }
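// Safety: `SendProducer::new` is unsafe precisely because this impl is
// unconditional; the caller of `new` vouches that sending the wrapped
// producer to another thread is sound.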

impl<P> Deref for SendProducer<P> {
type Target = P;
fn deref(&self) -> &P { &self.inner }
}

impl<P> DerefMut for SendProducer<P> {
fn deref_mut(&mut self) -> &mut P { &mut self.inner }
}

impl<P: NdProducer> NdProducer for SendProducer<P>
where P: NdProducer,
{
type Item = P::Item;
type Dim = P::Dim;
type Ptr = P::Ptr;
type Stride = P::Stride;

private_impl! {}

#[inline(always)]
fn raw_dim(&self) -> Self::Dim {
self.inner.raw_dim()
}

#[inline(always)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.inner.equal_dim(dim)
}

#[inline(always)]
fn as_ptr(&self) -> Self::Ptr {
self.inner.as_ptr()
}

#[inline(always)]
fn layout(&self) -> Layout {
self.inner.layout()
}

#[inline(always)]
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
self.inner.as_ref(ptr)
}

#[inline(always)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
self.inner.uget_ptr(i)
}

#[inline(always)]
fn stride_of(&self, axis: Axis) -> Self::Stride {
self.inner.stride_of(axis)
}

#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
self.inner.contiguous_stride()
}

fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
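// Splitting preserves the wrapper, so both halves remain `Send` when
// rayon distributes them across threads.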
let (a, b) = self.inner.split_at(axis, index);
(Self { inner: a }, Self { inner: b })
}
}
