From fdb780693a8737f2ac3652cd4c50b8c0b6d774cb Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Thu, 6 Oct 2022 18:18:15 +0100 Subject: [PATCH] Add interleave kernel (#1523) --- arrow/src/compute/kernels/interleave.rs | 201 ++++++++++++++++++++++++ arrow/src/compute/kernels/mod.rs | 1 + arrow/src/compute/kernels/take.rs | 9 +- arrow/src/compute/mod.rs | 1 + 4 files changed, 210 insertions(+), 2 deletions(-) create mode 100644 arrow/src/compute/kernels/interleave.rs diff --git a/arrow/src/compute/kernels/interleave.rs b/arrow/src/compute/kernels/interleave.rs new file mode 100644 index 00000000000..21dedc8e8ba --- /dev/null +++ b/arrow/src/compute/kernels/interleave.rs @@ -0,0 +1,201 @@ +use arrow_array::{make_array, new_empty_array, Array, ArrayRef}; +use arrow_data::transform::MutableArrayData; +use arrow_schema::ArrowError; + +/// +/// Takes elements by index from an array of [`Array`], creating a new [`Array`] from those values. +/// +/// Each element in `indices` is a pair of `usize` with the first identifying the index +/// of the [`Array`] in `values`, and the second the index of the value within that [`Array`] +/// +/// ```text +/// ┌─────────────────┐ ┌─────────┐ ┌─────────────────┐ +/// │ A │ │ (0, 0) │ interleave( │ A │ +/// ├─────────────────┤ ├─────────┤ [values0, values1], ├─────────────────┤ +/// │ D │ │ (1, 0) │ indices │ B │ +/// └─────────────────┘ ├─────────┤ ) ├─────────────────┤ +/// values array 0 │ (1, 1) │ ─────────────────────────▶ │ C │ +/// ├─────────┤ ├─────────────────┤ +/// │ (0, 1) │ │ D │ +/// └─────────┘ └─────────────────┘ +/// ┌─────────────────┐ indices +/// │ B │ array +/// ├─────────────────┤ result +/// │ C │ +/// ├─────────────────┤ +/// │ E │ +/// └─────────────────┘ +/// values array 1 +/// ``` +/// +/// For selecting values by index from a single array see [compute::take](crate::compute::take) +/// +/// # Panics +/// +/// Panics if the arrays do not have the same data type or `values` is empty +pub fn interleave( + values: &[&dyn Array], + indices: &[(usize, usize)], +) -> Result { + if values.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "interleave requires input of at least one array".to_string(), + )); + } + let data_type = values[0].data_type(); + + if values + .iter() + .skip(1) + .any(|array| array.data_type() != data_type) + { + return Err(ArrowError::InvalidArgumentError( + "It is not possible to interleave arrays of different data types." + .to_string(), + )); + } + + if indices.is_empty() { + return Ok(new_empty_array(data_type)); + } + + // TODO: Add specialized implementations (#1523) + + interleave_fallback(values, indices) +} + +/// Fallback implementation of interleave using [`MutableArrayData`] +fn interleave_fallback( + values: &[&dyn Array], + indices: &[(usize, usize)], +) -> Result { + let arrays: Vec<_> = values.iter().map(|x| x.data()).collect(); + let mut array_data = MutableArrayData::new(arrays, false, indices.len()); + + let mut cur_array = indices[0].0; + let mut start_row_idx = indices[0].1; + let mut end_row_idx = start_row_idx + 1; + + for (array, row) in indices.iter().skip(1).copied() { + if array == cur_array && row == end_row_idx { + // subsequent row in same batch + end_row_idx += 1; + continue; + } + + // emit current batch of rows for current buffer + array_data.extend(cur_array, start_row_idx, end_row_idx); + + // start new batch of rows + cur_array = array; + start_row_idx = row; + end_row_idx = start_row_idx + 1; + } + + // emit final batch of rows + array_data.extend(cur_array, start_row_idx, end_row_idx); + Ok(make_array(array_data.freeze())) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::builder::{Int32Builder, ListBuilder}; + use arrow_array::cast::{as_primitive_array, as_string_array}; + use arrow_array::types::Int32Type; + use arrow_array::{Int32Array, ListArray, StringArray}; + + #[test] + fn test_primitive() { + let a = Int32Array::from_iter_values([1, 2, 3, 4]); + let b = Int32Array::from_iter_values([5, 6, 7]); + let c = Int32Array::from_iter_values([8, 9, 10]); + let values = + interleave(&[&a, &b, &c], &[(0, 3), (0, 3), (2, 2), (2, 0), (1, 1)]).unwrap(); + let v = as_primitive_array::(&values); + assert_eq!(v.values(), &[4, 4, 10, 8, 6]); + } + + #[test] + fn test_primitive_nulls() { + let a = Int32Array::from_iter_values([1, 2, 3, 4]); + let b = Int32Array::from_iter([Some(1), Some(4), None]); + let values = + interleave(&[&a, &b], &[(0, 1), (1, 2), (1, 2), (0, 3), (0, 2)]).unwrap(); + let v: Vec<_> = as_primitive_array::(&values) + .into_iter() + .collect(); + assert_eq!(&v, &[Some(2), None, None, Some(4), Some(3)]) + } + + #[test] + fn test_primitive_empty() { + let a = Int32Array::from_iter_values([1, 2, 3, 4]); + let v = interleave(&[&a], &[]).unwrap(); + assert!(v.is_empty()); + } + + #[test] + fn test_strings() { + let a = StringArray::from_iter_values(["a", "b", "c"]); + let b = StringArray::from_iter_values(["hello", "world", "foo"]); + let values = + interleave(&[&a, &b], &[(0, 2), (0, 2), (1, 0), (1, 1), (0, 1)]).unwrap(); + let v = as_string_array(&values); + let values: Vec<_> = v.into_iter().collect(); + assert_eq!( + &values, + &[ + Some("c"), + Some("c"), + Some("hello"), + Some("world"), + Some("b") + ] + ) + } + + #[test] + fn test_lists() { + // [[1, 2], null, [3]] + let mut a = ListBuilder::new(Int32Builder::new()); + a.values().append_value(1); + a.values().append_value(2); + a.append(true); + a.append(false); + a.values().append_value(3); + a.append(true); + let a = a.finish(); + + // [[4], null, [5, 6, null]] + let mut b = ListBuilder::new(Int32Builder::new()); + b.values().append_value(4); + b.append(true); + b.append(false); + b.values().append_value(5); + b.values().append_value(6); + b.values().append_null(); + b.append(true); + let b = b.finish(); + + let values = + interleave(&[&a, &b], &[(0, 2), (0, 1), (1, 0), (1, 2), (1, 1)]).unwrap(); + let v = values.as_any().downcast_ref::().unwrap(); + + // [[3], null, [4], [5, 6, null], null] + let mut expected = ListBuilder::new(Int32Builder::new()); + expected.values().append_value(3); + expected.append(true); + expected.append(false); + expected.values().append_value(4); + expected.append(true); + expected.values().append_value(5); + expected.values().append_value(6); + expected.values().append_null(); + expected.append(true); + expected.append(false); + let expected = expected.finish(); + + assert_eq!(v, &expected); + } +} diff --git a/arrow/src/compute/kernels/mod.rs b/arrow/src/compute/kernels/mod.rs index 99cdcf460ce..8301f69bbf8 100644 --- a/arrow/src/compute/kernels/mod.rs +++ b/arrow/src/compute/kernels/mod.rs @@ -28,6 +28,7 @@ pub mod comparison; pub mod concat; pub mod concat_elements; pub mod filter; +pub mod interleave; pub mod length; pub mod limit; pub mod partition; diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 1aa4473c044..5be0b58e252 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -46,15 +46,20 @@ use num::{ToPrimitive, Zero}; /// ├─────────────────┤ └─────────┘ └─────────────────┘ /// │ E │ /// └─────────────────┘ -/// values array indicies array result +/// values array indices array result /// ``` /// +/// For selecting values by index from multiple arrays see [compute::interleave](crate::compute::interleave) +/// /// # Errors /// This function errors whenever: /// * An index cannot be casted to `usize` (typically 32 bit architectures) /// * An index is out of bounds and `options` is set to check bounds. +/// /// # Safety -/// When `options` is not set to check bounds (default), taking indexes after `len` is undefined behavior. +/// +/// When `options` is not set to check bounds, taking indexes after `len` will panic. +/// /// # Examples /// ``` /// use arrow::array::{StringArray, UInt32Array}; diff --git a/arrow/src/compute/mod.rs b/arrow/src/compute/mod.rs index 2b3b9a76873..28e5e6b520b 100644 --- a/arrow/src/compute/mod.rs +++ b/arrow/src/compute/mod.rs @@ -29,6 +29,7 @@ pub use self::kernels::cast::*; pub use self::kernels::comparison::*; pub use self::kernels::concat::*; pub use self::kernels::filter::*; +pub use self::kernels::interleave::*; pub use self::kernels::limit::*; pub use self::kernels::partition::*; pub use self::kernels::regexp::*;