From 6c13c8924ab6f35feb72443a611d9b27122396e6 Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Thu, 27 Oct 2022 20:03:57 +1300
Subject: [PATCH] Specialize interleave dictionary (#2944)

Route `DataType::Dictionary` inputs of `interleave` to a new
`interleave_dictionary::<K>` kernel, dispatched on the key type with
`downcast_integer!`, instead of the `MutableArrayData` fallback.

The kernel interns each selected `(source array, key)` pair to a new
output key, writes `K::Native::default()` for null slots, and then
interleaves only the dictionary values that were actually selected.
The resulting dictionary may still contain duplicate values when the
sources do.

Adds unit tests for string dictionaries and for null slots whose key is
out of range, plus dictionary cases in the interleave benchmark.
---
 arrow-select/src/interleave.rs      | 148 +++++++++++++++++++++++++++-
 arrow/benches/interleave_kernels.rs |   6 ++
 2 files changed, 150 insertions(+), 4 deletions(-)

diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs
index 18f834fa184..2738b430974 100644
--- a/arrow-select/src/interleave.rs
+++ b/arrow-select/src/interleave.rs
@@ -16,14 +16,18 @@
 // under the License.
 
 use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder};
+use arrow_array::types::ArrowDictionaryKeyType;
 use arrow_array::{
-    downcast_primitive, make_array, new_empty_array, Array, ArrayRef, ArrowPrimitiveType,
-    GenericStringArray, OffsetSizeTrait, PrimitiveArray,
+    downcast_integer, downcast_primitive, make_array, new_empty_array, Array, ArrayRef,
+    ArrowPrimitiveType, DictionaryArray, GenericStringArray, OffsetSizeTrait,
+    PrimitiveArray,
 };
-use arrow_buffer::{Buffer, MutableBuffer};
+use arrow_buffer::bit_util::get_bit;
+use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
 use arrow_data::transform::MutableArrayData;
 use arrow_data::ArrayDataBuilder;
 use arrow_schema::{ArrowError, DataType};
+use std::collections::HashMap;
 use std::sync::Arc;
 
 macro_rules! primitive_helper {
@@ -32,6 +36,12 @@ macro_rules! primitive_helper {
     };
 }
 
+macro_rules! dictionary_helper {
+    ($t:ty, $values:ident, $indices:ident, $data_type:ident) => {
+        interleave_dictionary::<$t>($values, $indices, $data_type)
+    };
+}
+
 ///
 /// Takes elements by index from a list of [`Array`], creating a new [`Array`] from those values.
 ///
@@ -87,6 +97,10 @@ pub fn interleave(
         data_type => (primitive_helper, values, indices, data_type),
         DataType::Utf8 => interleave_string::<i32>(values, indices, data_type),
         DataType::LargeUtf8 => interleave_string::<i64>(values, indices, data_type),
+        DataType::Dictionary(k, _) => downcast_integer! {
+            k.as_ref() => (dictionary_helper, values, indices, data_type),
+            _ => unreachable!(),
+        }
         _ => interleave_fallback(values, indices)
     }
 }
@@ -184,6 +198,74 @@ fn interleave_string<O: OffsetSizeTrait>(
     Ok(Arc::new(GenericStringArray::<O>::from(data)))
 }
 
+/// Interleaves dictionary arrays
+///
+///
+/// This will only copy dictionary values used by the output selection. However, the
+/// resulting dictionary may contain duplicates if the source dictionaries contain duplicates
+/// or the same value appears in multiple source arrays
+fn interleave_dictionary<K>(
+    values: &[&dyn Array],
+    indices: &[(usize, usize)],
+    data_type: &DataType,
+) -> Result<ArrayRef, ArrowError>
+where
+    K: ArrowDictionaryKeyType,
+    K::Native: std::hash::Hash + Eq,
+{
+    let interleaved = Interleave::<'_, DictionaryArray<K>>::new(values, indices);
+    let mut value_indices = Vec::with_capacity(indices.len());
+    let mut keys = MutableBuffer::new(indices.len());
+
+    // Map from (array,key) to output key
+    let mut mapping: HashMap<(usize, K::Native), K::Native> =
+        HashMap::with_capacity(indices.len());
+
+    // Given an array index and key, updates mapping and value_indices, returning the new key
+    let mut intern_key = |a: usize, k: K::Native| -> K::Native {
+        *mapping.entry((a, k)).or_insert_with(|| {
+            let new_key = K::Native::from_usize(value_indices.len()).expect("overflow");
+            value_indices.push((a, k.as_usize()));
+            new_key
+        })
+    };
+
+    // Iterate through identifying selected dictionary keys
+    match &interleaved.nulls {
+        Some(nulls) => {
+            for (idx, (a, b)) in indices.iter().enumerate() {
+                keys.push(match get_bit(nulls.as_ref(), idx) {
+                    true => intern_key(*a, interleaved.arrays[*a].keys().value(*b)),
+                    false => K::Native::default(),
+                })
+            }
+        }
+        None => {
+            for (a, b) in indices {
+                keys.push(intern_key(*a, interleaved.arrays[*a].keys().value(*b)))
+            }
+        }
+    }
+
+    // Copy across only values that were selected
+    let values: Vec<_> = interleaved
+        .arrays
+        .iter()
+        .map(|x| x.values().as_ref())
+        .collect();
+    let child_data = interleave(&values, &value_indices)?.data().clone();
+
+    let builder = ArrayDataBuilder::new(data_type.clone())
+        .len(indices.len())
+        .add_buffer(keys.into())
+        .add_child_data(child_data)
+        .null_bit_buffer(interleaved.nulls)
+        .null_count(interleaved.null_count);
+
+    let data = unsafe { builder.build_unchecked() };
+    Ok(Arc::new(DictionaryArray::<K>::from(data)))
+}
+
 /// Fallback implementation of interleave using [`MutableArrayData`]
 fn interleave_fallback(
     values: &[&dyn Array],
@@ -221,7 +303,7 @@ fn interleave_fallback(
 mod tests {
     use super::*;
     use arrow_array::builder::{Int32Builder, ListBuilder};
-    use arrow_array::cast::{as_primitive_array, as_string_array};
+    use arrow_array::cast::{as_dictionary_array, as_primitive_array, as_string_array};
     use arrow_array::types::Int32Type;
     use arrow_array::{Int32Array, ListArray, StringArray};
     use arrow_schema::DataType;
@@ -277,6 +359,64 @@ mod tests {
         )
     }
 
+    #[test]
+    fn test_string_dictionaries() {
+        let a = DictionaryArray::<Int32Type>::from_iter([
+            Some("a"),
+            Some("b"),
+            None,
+            Some("b"),
+            Some("b"),
+            Some("a"),
+        ]);
+
+        let b = DictionaryArray::<Int32Type>::from_iter([
+            Some("a"),
+            Some("c"),
+            None,
+            Some("c"),
+            Some("c"),
+            Some("d"),
+        ]);
+
+        let interleaved =
+            interleave(&[&a, &b], &[(0, 2), (0, 2), (1, 0), (1, 1), (0, 1), (0, 2)])
+                .unwrap();
+
+        let result = as_dictionary_array::<Int32Type>(interleaved.as_ref())
+            .downcast_dict::<StringArray>()
+            .unwrap();
+
+        let r: Vec<_> = result.into_iter().collect();
+        assert_eq!(r, vec![None, None, Some("a"), Some("c"), Some("b"), None]);
+    }
+
+    #[test]
+    fn test_dictionary_nulls() {
+        let child = Int32Array::from(vec![0]).into_data();
+        let dictionary = ArrayDataBuilder::new(DataType::Dictionary(
+            Box::new(DataType::Int32),
+            Box::new(DataType::Int32),
+        ))
+        .len(2)
+        .add_buffer(Buffer::from_iter([-1_i32, 0_i32]))
+        .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b00000010])))
+        .null_count(1)
+        .add_child_data(child)
+        .build()
+        .unwrap();
+
+        let dictionary = DictionaryArray::<Int32Type>::from(dictionary);
+        let interleaved = interleave(&[&dictionary], &[(0, 0), (0, 1)]).unwrap();
+
+        let result = as_dictionary_array::<Int32Type>(interleaved.as_ref())
+            .downcast_dict::<Int32Array>()
+            .unwrap();
+
+        let r: Vec<_> = result.into_iter().collect();
+        assert_eq!(r, vec![None, Some(0)]);
+    }
+
     #[test]
     fn test_lists() {
         // [[1, 2], null, [3]]
diff --git a/arrow/benches/interleave_kernels.rs b/arrow/benches/interleave_kernels.rs
index 0c3eec60c0c..5fa0ccefb02 100644
--- a/arrow/benches/interleave_kernels.rs
+++ b/arrow/benches/interleave_kernels.rs
@@ -65,11 +65,17 @@ fn add_benchmark(c: &mut Criterion) {
     let string = create_string_array_with_len::<i32>(1024, 0., 20);
     let string_opt = create_string_array_with_len::<i32>(1024, 0.5, 20);
 
+    let values = create_string_array_with_len::<i32>(10, 0., 20);
+    let dict = create_dict_from_values::<Int32Type>(1024, 0., &values);
+    let dict_opt = create_dict_from_values::<Int32Type>(1024, 0.5, &values);
+
     let cases: &[(&str, &dyn Array)] = &[
         ("i32(0.0)", &i32),
         ("i32(0.5)", &i32_opt),
         ("str(20, 0.0)", &string),
         ("str(20, 0.5)", &string_opt),
+        ("dict(10, 0.0, str(20, 0.0))", &dict),
+        ("dict(10, 0.5, str(20, 0.0))", &dict_opt),
     ];
 
     for (prefix, base) in cases {