Add rayon-based parallel iteration over axis chunks.

Instantiates `par_iter_wrapper!` for `AxisChunksIter` and `AxisChunksIterMut`
in both the `parallel/` crate (`ndarray_parallel`) and the in-tree
`src/parallel` module, documents the new `.into_par_iter()` support with a
doc example, and adds tests for the shared and the mutable chunk iterator.
The second mutable-iterator test is only built with the `approx` feature.

diff --git a/parallel/src/lib.rs b/parallel/src/lib.rs
index 86b5da4c7..f16cec19b 100644
--- a/parallel/src/lib.rs
+++ b/parallel/src/lib.rs
@@ -64,6 +64,29 @@
 //! }
 //! ```
 //!
+//! ## Axis chunks iterators
+//!
+//! Use the parallel `.axis_chunks_iter()` to process your data in chunks.
+//!
+//! ```
+//! extern crate ndarray;
+//!
+//! use ndarray::Array;
+//! use ndarray::Axis;
+//! use ndarray_parallel::prelude::*;
+//!
+//! fn main() {
+//!     let a = Array::linspace(0., 63., 64).into_shape((4, 16)).unwrap();
+//!     let mut shapes = Vec::new();
+//!     a.axis_chunks_iter(Axis(0), 3)
+//!         .into_par_iter()
+//!         .map(|chunk| chunk.shape().to_owned())
+//!         .collect_into_vec(&mut shapes);
+//!
+//!     assert_eq!(shapes, [vec![3, 16], vec![1, 16]]);
+//! }
+//! ```
+//!
 //! ## Zip
 //!
 //! Use zip for lock step function application across several arrays
diff --git a/parallel/src/par.rs b/parallel/src/par.rs
index 04bfdd07e..f0bc50de7 100644
--- a/parallel/src/par.rs
+++ b/parallel/src/par.rs
@@ -8,6 +8,8 @@ use rayon::iter::plumbing::{Consumer, UnindexedConsumer};
 use rayon::iter::IndexedParallelIterator;
 use rayon::iter::ParallelIterator;
 
+use ndarray::iter::AxisChunksIter;
+use ndarray::iter::AxisChunksIterMut;
 use ndarray::iter::AxisIter;
 use ndarray::iter::AxisIterMut;
 use ndarray::Dimension;
@@ -112,6 +114,8 @@ macro_rules! par_iter_wrapper {
 
 par_iter_wrapper!(AxisIter, [Sync]);
 par_iter_wrapper!(AxisIterMut, [Send + Sync]);
+par_iter_wrapper!(AxisChunksIter, [Sync]);
+par_iter_wrapper!(AxisChunksIterMut, [Send + Sync]);
 
 macro_rules! par_iter_view_wrapper {
     // thread_bounds are either Sync or Send + Sync
diff --git a/parallel/tests/rayon.rs b/parallel/tests/rayon.rs
index 9432f9edf..ca0140193 100644
--- a/parallel/tests/rayon.rs
+++ b/parallel/tests/rayon.rs
@@ -7,6 +7,8 @@ use ndarray_parallel::prelude::*;
 
 const M: usize = 1024 * 10;
 const N: usize = 100;
+const CHUNK_SIZE: usize = 100;
+const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE;
 
 #[test]
 fn test_axis_iter() {
@@ -53,3 +55,32 @@ fn test_regular_iter_collect() {
     let v = a.view().into_par_iter().map(|&x| x).collect::<Vec<_>>();
     assert_eq!(v.len(), a.len());
 }
+
+#[test]
+fn test_axis_chunks_iter() {
+    let mut a = Array2::<f64>::zeros((M, N));
+    for (i, mut v) in a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE).enumerate() {
+        v.fill(i as _);
+    }
+    assert_eq!(a.axis_chunks_iter(Axis(0), CHUNK_SIZE).len(), N_CHUNKS);
+    let s: f64 = a
+        .axis_chunks_iter(Axis(0), CHUNK_SIZE)
+        .into_par_iter()
+        .map(|x| x.sum())
+        .sum();
+    println!("{:?}", a.slice(s![..10, ..5]));
+    assert_eq!(s, a.sum());
+}
+
+#[test]
+fn test_axis_chunks_iter_mut() {
+    let mut a = Array::linspace(0., 1.0f64, M * N)
+        .into_shape((M, N))
+        .unwrap();
+    let b = a.mapv(|x| x.exp());
+    a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE)
+        .into_par_iter()
+        .for_each(|mut v| v.mapv_inplace(|x| x.exp()));
+    println!("{:?}", a.slice(s![..10, ..5]));
+    assert!(a.all_close(&b, 0.001));
+}
diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs
index fb9738817..60dbe4662 100644
--- a/src/parallel/mod.rs
+++ b/src/parallel/mod.rs
@@ -14,6 +14,7 @@
 //! - [`ArrayView`](ArrayView): `.into_par_iter()`
 //! - [`ArrayViewMut`](ArrayViewMut): `.into_par_iter()`
 //! - [`AxisIter`](iter::AxisIter), [`AxisIterMut`](iter::AxisIterMut): `.into_par_iter()`
+//! - [`AxisChunksIter`](iter::AxisChunksIter), [`AxisChunksIterMut`](iter::AxisChunksIterMut): `.into_par_iter()`
 //! - [`Zip`] `.into_par_iter()`
 //!
 //! The following other parallelized methods exist:
@@ -76,6 +77,29 @@
 //! }
 //! ```
 //!
+//! ## Axis chunks iterators
+//!
+//! Use the parallel `.axis_chunks_iter()` to process your data in chunks.
+//!
+//! ```
+//! extern crate ndarray;
+//!
+//! use ndarray::Array;
+//! use ndarray::Axis;
+//! use ndarray::parallel::prelude::*;
+//!
+//! fn main() {
+//!     let a = Array::linspace(0., 63., 64).into_shape((4, 16)).unwrap();
+//!     let mut shapes = Vec::new();
+//!     a.axis_chunks_iter(Axis(0), 3)
+//!         .into_par_iter()
+//!         .map(|chunk| chunk.shape().to_owned())
+//!         .collect_into_vec(&mut shapes);
+//!
+//!     assert_eq!(shapes, [vec![3, 16], vec![1, 16]]);
+//! }
+//! ```
+//!
 //! ## Zip
 //!
 //! Use zip for lock step function application across several arrays
diff --git a/src/parallel/par.rs b/src/parallel/par.rs
index bfa7522ad..efd761acf 100644
--- a/src/parallel/par.rs
+++ b/src/parallel/par.rs
@@ -9,6 +9,8 @@ use rayon::iter::IndexedParallelIterator;
 use rayon::iter::ParallelIterator;
 use rayon::prelude::IntoParallelIterator;
 
+use crate::iter::AxisChunksIter;
+use crate::iter::AxisChunksIterMut;
 use crate::iter::AxisIter;
 use crate::iter::AxisIterMut;
 use crate::Dimension;
@@ -112,6 +114,8 @@ macro_rules! par_iter_wrapper {
 
 par_iter_wrapper!(AxisIter, [Sync]);
 par_iter_wrapper!(AxisIterMut, [Send + Sync]);
+par_iter_wrapper!(AxisChunksIter, [Sync]);
+par_iter_wrapper!(AxisChunksIterMut, [Send + Sync]);
 
 macro_rules! par_iter_view_wrapper {
     // thread_bounds are either Sync or Send + Sync
diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs
index 24a636275..4d5a8f1a9 100644
--- a/tests/par_rayon.rs
+++ b/tests/par_rayon.rs
@@ -5,6 +5,8 @@ use ndarray::prelude::*;
 
 const M: usize = 1024 * 10;
 const N: usize = 100;
+const CHUNK_SIZE: usize = 100;
+const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE;
 
 #[test]
 fn test_axis_iter() {
@@ -53,3 +55,34 @@ fn test_regular_iter_collect() {
     let v = a.view().into_par_iter().map(|&x| x).collect::<Vec<_>>();
     assert_eq!(v.len(), a.len());
 }
+
+#[test]
+fn test_axis_chunks_iter() {
+    let mut a = Array2::<f64>::zeros((M, N));
+    for (i, mut v) in a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE).enumerate() {
+        v.fill(i as _);
+    }
+    assert_eq!(a.axis_chunks_iter(Axis(0), CHUNK_SIZE).len(), N_CHUNKS);
+    let s: f64 = a
+        .axis_chunks_iter(Axis(0), CHUNK_SIZE)
+        .into_par_iter()
+        .map(|x| x.sum())
+        .sum();
+    println!("{:?}", a.slice(s![..10, ..5]));
+    assert_eq!(s, a.sum());
+}
+
+#[test]
+#[cfg(feature = "approx")]
+fn test_axis_chunks_iter_mut() {
+    use approx::assert_abs_diff_eq;
+    let mut a = Array::linspace(0., 1.0f64, M * N)
+        .into_shape((M, N))
+        .unwrap();
+    let b = a.mapv(|x| x.exp());
+    a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE)
+        .into_par_iter()
+        .for_each(|mut v| v.mapv_inplace(|x| x.exp()));
+    println!("{:?}", a.slice(s![..10, ..5]));
+    assert_abs_diff_eq!(a, b, epsilon = 0.001);
+}