Skip to content

Commit

Permalink
Re-encode dictionaries in selection kernels (#3558)
Browse files Browse the repository at this point in the history
* Re-encode dictionaries in selection kernels

* More benchmarks

* Best-effort hashing

* More benchmarks

* Add fallback to concatenating dictionaries

* Fix nulls

* Format

* Cleanup

* RAT

* Clippy

* Split out heuristic

* Add support to interleave kernel

* Clippy

* More clippy

* Clippy

* Cleanup

* Optimize concat

* Review feedback

* Clippy

* Improved null handling

* Further tests

* Faster ptr_eq
  • Loading branch information
tustvold committed Sep 5, 2023
1 parent db5314c commit 65edbb1
Show file tree
Hide file tree
Showing 11 changed files with 698 additions and 91 deletions.
8 changes: 8 additions & 0 deletions arrow-buffer/src/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,14 @@ impl Buffer {
length,
})
}

/// Returns true if this [`Buffer`] is equal to `other`, using pointer comparisons
/// to determine buffer equality. This is cheaper than `PartialEq::eq` but may
/// return false when the arrays are logically equal
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
self.ptr == other.ptr && self.length == other.length
}
}

/// Creating a `Buffer` instance by copying the memory from a `AsRef<[u8]>` into a newly
Expand Down
8 changes: 8 additions & 0 deletions arrow-buffer/src/buffer/offset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ impl<O: ArrowNativeType> OffsetBuffer<O> {
pub fn slice(&self, offset: usize, len: usize) -> Self {
Self(self.0.slice(offset, len.saturating_add(1)))
}

/// Returns true if this [`OffsetBuffer`] is equal to `other`, using pointer comparisons
/// to determine buffer equality. This is cheaper than `PartialEq::eq` but may
/// return false when the arrays are logically equal
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
self.0.ptr_eq(&other.0)
}
}

impl<T: ArrowNativeType> Deref for OffsetBuffer<T> {
Expand Down
8 changes: 8 additions & 0 deletions arrow-buffer/src/buffer/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ impl<T: ArrowNativeType> ScalarBuffer<T> {
pub fn into_inner(self) -> Buffer {
self.buffer
}

/// Returns true if this [`ScalarBuffer`] is equal to `other`, using pointer comparisons
/// to determine buffer equality. This is cheaper than `PartialEq::eq` but may
/// return false when the arrays are logically equal
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
self.buffer.ptr_eq(&other.buffer)
}
}

impl<T: ArrowNativeType> Deref for ScalarBuffer<T> {
Expand Down
1 change: 1 addition & 0 deletions arrow-select/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ arrow-data = { workspace = true }
arrow-schema = { workspace = true }
arrow-array = { workspace = true }
num = { version = "0.4", default-features = false, features = ["std"] }
ahash = { version = "0.8", default-features = false}

[features]
default = []
Expand Down
186 changes: 140 additions & 46 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@
//! assert_eq!(arr.len(), 3);
//! ```

use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::ArrowNativeType;
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use std::sync::Arc;

fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
let mut item_capacity = 0;
let mut bytes_capacity = 0;
for array in arrays {
let a = array
.as_any()
.downcast_ref::<GenericByteArray<T>>()
.unwrap();
let a = array.as_bytes::<T>();

// Guaranteed to always have at least one element
let offsets = a.value_offsets();
Expand All @@ -54,6 +54,59 @@ fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
Capacities::Binary(item_capacity, Some(bytes_capacity))
}

fn concat_dictionaries<K: ArrowDictionaryKeyType>(
arrays: &[&dyn Array],
) -> Result<ArrayRef, ArrowError> {
let mut output_len = 0;
let dictionaries: Vec<_> = arrays
.iter()
.map(|x| x.as_dictionary::<K>())
.inspect(|d| output_len += d.len())
.collect();

if !should_merge_dictionary_values::<K>(&dictionaries, output_len) {
return concat_fallback(arrays, Capacities::Array(output_len));
}

let merged = merge_dictionary_values(&dictionaries, None)?;

// Recompute keys
let mut key_values = Vec::with_capacity(output_len);

let mut has_nulls = false;
for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) {
has_nulls |= d.null_count() != 0;
for key in d.keys().values() {
// Use get to safely handle nulls
key_values.push(mapping.get(key.as_usize()).copied().unwrap_or_default())
}
}

let nulls = has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(output_len);
for d in &dictionaries {
match d.nulls() {
Some(n) => nulls.append_buffer(n.inner()),
None => nulls.append_n(d.len(), true),
}
}
NullBuffer::new(nulls.finish())
});

let keys = PrimitiveArray::<K>::new(key_values.into(), nulls);
// Sanity check
assert_eq!(keys.len(), output_len);

let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) };
Ok(Arc::new(array))
}

macro_rules! dict_helper {
($t:ty, $arrays:expr) => {
return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _)
};
}

/// Concatenate multiple [Array] of the same type into a single [ArrayRef].
pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
if arrays.is_empty() {
Expand All @@ -78,9 +131,23 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
DataType::LargeUtf8 => binary_capacity::<LargeUtf8Type>(arrays),
DataType::Binary => binary_capacity::<BinaryType>(arrays),
DataType::LargeBinary => binary_capacity::<LargeBinaryType>(arrays),
DataType::Dictionary(k, _) => downcast_integer! {
k.as_ref() => (dict_helper, arrays),
_ => unreachable!("illegal dictionary key type {k}")
},
_ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()),
};

concat_fallback(arrays, capacity)
}

/// Concatenates arrays using MutableArrayData
///
/// This will naively concatenate dictionaries
fn concat_fallback(
arrays: &[&dyn Array],
capacity: Capacities,
) -> Result<ArrayRef, ArrowError> {
let array_data: Vec<_> = arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
let array_data = array_data.iter().collect();
let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity);
Expand Down Expand Up @@ -140,6 +207,7 @@ pub fn concat_batches<'a>(
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::cast::AsArray;
use arrow_schema::{Field, Schema};
use std::sync::Arc;
Expand Down Expand Up @@ -468,29 +536,10 @@ mod tests {
}

fn collect_string_dictionary(
dictionary: &DictionaryArray<Int32Type>,
) -> Vec<Option<String>> {
let values = dictionary.values();
let values = values.as_any().downcast_ref::<StringArray>().unwrap();

dictionary
.keys()
.iter()
.map(|key| key.map(|key| values.value(key as _).to_string()))
.collect()
}

fn concat_dictionary(
input_1: DictionaryArray<Int32Type>,
input_2: DictionaryArray<Int32Type>,
) -> Vec<Option<String>> {
let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let concat = concat
.as_any()
.downcast_ref::<DictionaryArray<Int32Type>>()
.unwrap();

collect_string_dictionary(concat)
array: &DictionaryArray<Int32Type>,
) -> Vec<Option<&str>> {
let concrete = array.downcast_dict::<StringArray>().unwrap();
concrete.into_iter().collect()
}

#[test]
Expand All @@ -509,11 +558,19 @@ mod tests {
"E",
]
.into_iter()
.map(|x| Some(x.to_string()))
.map(Some)
.collect();

let concat = concat_dictionary(input_1, input_2);
assert_eq!(concat, expected);
let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let dictionary = concat.as_dictionary::<Int32Type>();
let actual = collect_string_dictionary(dictionary);
assert_eq!(actual, expected);

// Should have concatenated inputs together
assert_eq!(
dictionary.values().len(),
input_1.values().len() + input_2.values().len(),
)
}

#[test]
Expand All @@ -523,16 +580,45 @@ mod tests {
.into_iter()
.collect();
let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();
let expected = vec![
Some("foo".to_string()),
Some("bar".to_string()),
None,
Some("fiz".to_string()),
None,
];
let expected = vec![Some("foo"), Some("bar"), None, Some("fiz"), None];

let concat = concat_dictionary(input_1, input_2);
assert_eq!(concat, expected);
let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let dictionary = concat.as_dictionary::<Int32Type>();
let actual = collect_string_dictionary(dictionary);
assert_eq!(actual, expected);

// Should have concatenated inputs together
assert_eq!(
dictionary.values().len(),
input_1.values().len() + input_2.values().len(),
)
}

#[test]
fn test_string_dictionary_merge() {
let mut builder = StringDictionaryBuilder::<Int32Type>::new();
for i in 0..20 {
builder.append(&i.to_string()).unwrap();
}
let input_1 = builder.finish();

let mut builder = StringDictionaryBuilder::<Int32Type>::new();
for i in 0..30 {
builder.append(&i.to_string()).unwrap();
}
let input_2 = builder.finish();

let expected: Vec<_> = (0..20).chain(0..30).map(|x| x.to_string()).collect();
let expected: Vec<_> = expected.iter().map(|x| Some(x.as_str())).collect();

let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let dictionary = concat.as_dictionary::<Int32Type>();
let actual = collect_string_dictionary(dictionary);
assert_eq!(actual, expected);

// Should have merged inputs together
// Not 30 as this is done on a best-effort basis
assert_eq!(dictionary.values().len(), 33)
}

#[test]
Expand All @@ -556,7 +642,7 @@ mod tests {
fn test_dictionary_concat_reuse() {
let array: DictionaryArray<Int8Type> =
vec!["a", "a", "b", "c"].into_iter().collect();
let copy: DictionaryArray<Int8Type> = array.to_data().into();
let copy: DictionaryArray<Int8Type> = array.clone();

// dictionary is "a", "b", "c"
assert_eq!(
Expand All @@ -567,11 +653,7 @@ mod tests {

// concatenate it with itself
let combined = concat(&[&copy as _, &array as _]).unwrap();

let combined = combined
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap();
let combined = combined.as_dictionary::<Int8Type>();

assert_eq!(
combined.values(),
Expand Down Expand Up @@ -738,4 +820,16 @@ mod tests {
assert_eq!(data.buffers()[1].len(), 200);
assert_eq!(data.buffers()[1].capacity(), 256); // Nearest multiple of 64
}

#[test]
fn concat_sparse_nulls() {
let values = StringArray::from_iter_values((0..100).map(|x| x.to_string()));
let keys = Int32Array::from(vec![1; 10]);
let dict_a = DictionaryArray::new(keys, Arc::new(values));
let values = StringArray::new_null(0);
let keys = Int32Array::new_null(10);
let dict_b = DictionaryArray::new(keys, Arc::new(values));
let array = concat(&[&dict_a, &dict_b]).unwrap();
assert_eq!(array.null_count(), 10);
}
}

0 comments on commit 65edbb1

Please sign in to comment.