Skip to content

Commit

Permalink
Merge pull request #639 from nitsky/par-iter-axis-chunks
Browse files Browse the repository at this point in the history
Parallel Iterator for AxisChunksIter
  • Loading branch information
bluss committed Sep 15, 2019
2 parents f489851 + f9ac9d4 commit f607ff6
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 0 deletions.
23 changes: 23 additions & 0 deletions parallel/src/lib.rs
Expand Up @@ -64,6 +64,29 @@
//! }
//! ```
//!
//! ## Axis chunks iterators
//!
//! Use the parallel `.axis_chunks_iter()` to process your data in chunks.
//!
//! ```
//! extern crate ndarray;
//!
//! use ndarray::Array;
//! use ndarray::Axis;
//! use ndarray_parallel::prelude::*;
//!
//! fn main() {
//! let a = Array::linspace(0., 63., 64).into_shape((4, 16)).unwrap();
//! let mut shapes = Vec::new();
//! a.axis_chunks_iter(Axis(0), 3)
//! .into_par_iter()
//! .map(|chunk| chunk.shape().to_owned())
//! .collect_into_vec(&mut shapes);
//!
//! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]);
//! }
//! ```
//!
//! ## Zip
//!
//! Use zip for lock step function application across several arrays
Expand Down
4 changes: 4 additions & 0 deletions parallel/src/par.rs
Expand Up @@ -8,6 +8,8 @@ use rayon::iter::plumbing::{Consumer, UnindexedConsumer};
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;

use ndarray::iter::AxisChunksIter;
use ndarray::iter::AxisChunksIterMut;
use ndarray::iter::AxisIter;
use ndarray::iter::AxisIterMut;
use ndarray::Dimension;
Expand Down Expand Up @@ -112,6 +114,8 @@ macro_rules! par_iter_wrapper {

par_iter_wrapper!(AxisIter, [Sync]);
par_iter_wrapper!(AxisIterMut, [Send + Sync]);
par_iter_wrapper!(AxisChunksIter, [Sync]);
par_iter_wrapper!(AxisChunksIterMut, [Send + Sync]);

macro_rules! par_iter_view_wrapper {
// thread_bounds are either Sync or Send + Sync
Expand Down
31 changes: 31 additions & 0 deletions parallel/tests/rayon.rs
Expand Up @@ -7,6 +7,8 @@ use ndarray_parallel::prelude::*;

const M: usize = 1024 * 10;
const N: usize = 100;
const CHUNK_SIZE: usize = 100;
const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE;

#[test]
fn test_axis_iter() {
Expand Down Expand Up @@ -53,3 +55,32 @@ fn test_regular_iter_collect() {
let v = a.view().into_par_iter().map(|&x| x).collect::<Vec<_>>();
assert_eq!(v.len(), a.len());
}

#[test]
fn test_axis_chunks_iter() {
let mut a = Array2::<f64>::zeros((M, N));
for (i, mut v) in a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE).enumerate() {
v.fill(i as _);
}
assert_eq!(a.axis_chunks_iter(Axis(0), CHUNK_SIZE).len(), N_CHUNKS);
let s: f64 = a
.axis_chunks_iter(Axis(0), CHUNK_SIZE)
.into_par_iter()
.map(|x| x.sum())
.sum();
println!("{:?}", a.slice(s![..10, ..5]));
assert_eq!(s, a.sum());
}

#[test]
fn test_axis_chunks_iter_mut() {
let mut a = Array::linspace(0., 1.0f64, M * N)
.into_shape((M, N))
.unwrap();
let b = a.mapv(|x| x.exp());
a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE)
.into_par_iter()
.for_each(|mut v| v.mapv_inplace(|x| x.exp()));
println!("{:?}", a.slice(s![..10, ..5]));
assert!(a.all_close(&b, 0.001));
}
24 changes: 24 additions & 0 deletions src/parallel/mod.rs
Expand Up @@ -14,6 +14,7 @@
//! - [`ArrayView`](ArrayView): `.into_par_iter()`
//! - [`ArrayViewMut`](ArrayViewMut): `.into_par_iter()`
//! - [`AxisIter`](iter::AxisIter), [`AxisIterMut`](iter::AxisIterMut): `.into_par_iter()`
//! - [`AxisChunksIter`](iter::AxisChunksIter), [`AxisChunksIterMut`](iter::AxisChunksIterMut): `.into_par_iter()`
//! - [`Zip`] `.into_par_iter()`
//!
//! The following other parallelized methods exist:
Expand Down Expand Up @@ -76,6 +77,29 @@
//! }
//! ```
//!
//! ## Axis chunks iterators
//!
//! Use the parallel `.axis_chunks_iter()` to process your data in chunks.
//!
//! ```
//! extern crate ndarray;
//!
//! use ndarray::Array;
//! use ndarray::Axis;
//! use ndarray::parallel::prelude::*;
//!
//! fn main() {
//! let a = Array::linspace(0., 63., 64).into_shape((4, 16)).unwrap();
//! let mut shapes = Vec::new();
//! a.axis_chunks_iter(Axis(0), 3)
//! .into_par_iter()
//! .map(|chunk| chunk.shape().to_owned())
//! .collect_into_vec(&mut shapes);
//!
//! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]);
//! }
//! ```
//!
//! ## Zip
//!
//! Use zip for lock step function application across several arrays
Expand Down
4 changes: 4 additions & 0 deletions src/parallel/par.rs
Expand Up @@ -9,6 +9,8 @@ use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::IntoParallelIterator;

use crate::iter::AxisChunksIter;
use crate::iter::AxisChunksIterMut;
use crate::iter::AxisIter;
use crate::iter::AxisIterMut;
use crate::Dimension;
Expand Down Expand Up @@ -112,6 +114,8 @@ macro_rules! par_iter_wrapper {

par_iter_wrapper!(AxisIter, [Sync]);
par_iter_wrapper!(AxisIterMut, [Send + Sync]);
par_iter_wrapper!(AxisChunksIter, [Sync]);
par_iter_wrapper!(AxisChunksIterMut, [Send + Sync]);

macro_rules! par_iter_view_wrapper {
// thread_bounds are either Sync or Send + Sync
Expand Down
33 changes: 33 additions & 0 deletions tests/par_rayon.rs
Expand Up @@ -5,6 +5,8 @@ use ndarray::prelude::*;

const M: usize = 1024 * 10;
const N: usize = 100;
const CHUNK_SIZE: usize = 100;
const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE;

#[test]
fn test_axis_iter() {
Expand Down Expand Up @@ -53,3 +55,34 @@ fn test_regular_iter_collect() {
let v = a.view().into_par_iter().map(|&x| x).collect::<Vec<_>>();
assert_eq!(v.len(), a.len());
}

#[test]
fn test_axis_chunks_iter() {
let mut a = Array2::<f64>::zeros((M, N));
for (i, mut v) in a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE).enumerate() {
v.fill(i as _);
}
assert_eq!(a.axis_chunks_iter(Axis(0), CHUNK_SIZE).len(), N_CHUNKS);
let s: f64 = a
.axis_chunks_iter(Axis(0), CHUNK_SIZE)
.into_par_iter()
.map(|x| x.sum())
.sum();
println!("{:?}", a.slice(s![..10, ..5]));
assert_eq!(s, a.sum());
}

#[test]
#[cfg(feature = "approx")]
fn test_axis_chunks_iter_mut() {
use approx::assert_abs_diff_eq;
let mut a = Array::linspace(0., 1.0f64, M * N)
.into_shape((M, N))
.unwrap();
let b = a.mapv(|x| x.exp());
a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE)
.into_par_iter()
.for_each(|mut v| v.mapv_inplace(|x| x.exp()));
println!("{:?}", a.slice(s![..10, ..5]));
assert_abs_diff_eq!(a, b, epsilon = 0.001);
}

0 comments on commit f607ff6

Please sign in to comment.