-
Notifications
You must be signed in to change notification settings - Fork 657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support sorting dictionary encoded primitive integer arrays #2680
Changes from 1 commit
e66d133
3697abc
286d9a9
6a778cf
93708c0
2ed13a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ use crate::array::*; | |
use crate::buffer::MutableBuffer; | ||
use crate::compute::take; | ||
use crate::datatypes::*; | ||
use crate::downcast_dictionary_array; | ||
use crate::error::{ArrowError, Result}; | ||
use std::cmp::Ordering; | ||
use TimeUnit::*; | ||
|
@@ -311,41 +312,26 @@ pub fn sort_to_indices( | |
))); | ||
} | ||
}, | ||
DataType::Dictionary(key_type, value_type) | ||
if *value_type.as_ref() == DataType::Utf8 => | ||
{ | ||
match key_type.as_ref() { | ||
DataType::Int8 => { | ||
sort_string_dictionary::<Int8Type>(values, v, n, &options, limit) | ||
} | ||
DataType::Int16 => { | ||
sort_string_dictionary::<Int16Type>(values, v, n, &options, limit) | ||
} | ||
DataType::Int32 => { | ||
sort_string_dictionary::<Int32Type>(values, v, n, &options, limit) | ||
} | ||
DataType::Int64 => { | ||
sort_string_dictionary::<Int64Type>(values, v, n, &options, limit) | ||
} | ||
DataType::UInt8 => { | ||
sort_string_dictionary::<UInt8Type>(values, v, n, &options, limit) | ||
} | ||
DataType::UInt16 => { | ||
sort_string_dictionary::<UInt16Type>(values, v, n, &options, limit) | ||
} | ||
DataType::UInt32 => { | ||
sort_string_dictionary::<UInt32Type>(values, v, n, &options, limit) | ||
} | ||
DataType::UInt64 => { | ||
sort_string_dictionary::<UInt64Type>(values, v, n, &options, limit) | ||
} | ||
t => { | ||
return Err(ArrowError::ComputeError(format!( | ||
"Sort not supported for dictionary key type {:?}", | ||
t | ||
))); | ||
} | ||
} | ||
DataType::Dictionary(_, _) => { | ||
downcast_dictionary_array!( | ||
values => match values.values().data_type() { | ||
DataType::Int8 => sort_primitive_dictionary::<_, Int8Type, _>(values, v, n, &options, limit, cmp), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For other primitive types, we just need to add other patterns and call I will add them in separate PRs. |
||
DataType::Int16 => sort_primitive_dictionary::<_, Int16Type, _>(values, v, n, &options, limit, cmp), | ||
DataType::Int32 => sort_primitive_dictionary::<_, Int32Type, _>(values, v, n, &options, limit, cmp), | ||
DataType::Int64 => sort_primitive_dictionary::<_, Int64Type, _>(values, v, n, &options, limit, cmp), | ||
DataType::UInt8 => sort_primitive_dictionary::<_, UInt8Type, _>(values, v, n, &options, limit, cmp), | ||
DataType::UInt16 => sort_primitive_dictionary::<_, UInt16Type, _>(values, v, n, &options, limit, cmp), | ||
DataType::UInt32 => sort_primitive_dictionary::<_, UInt32Type, _>(values, v, n, &options, limit, cmp), | ||
DataType::UInt64 => sort_primitive_dictionary::<_, UInt64Type, _>(values, v, n, &options, limit, cmp), | ||
DataType::Utf8 => sort_string_dictionary::<_>(values, v, n, &options, limit), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
t => return Err(ArrowError::ComputeError(format!( | ||
"Unsupported dictionary value type {}", t | ||
))), | ||
}, | ||
t => return Err(ArrowError::ComputeError(format!( | ||
"Unsupported datatype {}", t | ||
))), | ||
) | ||
} | ||
DataType::Binary | DataType::FixedSizeBinary(_) => { | ||
sort_binary::<i32>(values, v, n, &options, limit) | ||
|
@@ -489,7 +475,14 @@ where | |
.into_iter() | ||
.map(|index| (index, decimal_array.value(index as usize).as_i128())) | ||
.collect::<Vec<(u32, i128)>>(); | ||
sort_primitive_inner(decimal_values, null_indices, cmp, options, limit, valids) | ||
sort_primitive_inner( | ||
decimal_values.len(), | ||
null_indices, | ||
cmp, | ||
options, | ||
limit, | ||
valids, | ||
) | ||
} | ||
|
||
/// Sort primitive values | ||
|
@@ -514,12 +507,55 @@ where | |
.map(|index| (index, values.value(index as usize))) | ||
.collect::<Vec<(u32, T::Native)>>() | ||
}; | ||
sort_primitive_inner(values, null_indices, cmp, options, limit, valids) | ||
sort_primitive_inner(values.len(), null_indices, cmp, options, limit, valids) | ||
} | ||
|
||
/// Sort dictionary encoded primitive values | ||
fn sort_primitive_dictionary<K, T, F>( | ||
values: &DictionaryArray<K>, | ||
value_indices: Vec<u32>, | ||
null_indices: Vec<u32>, | ||
options: &SortOptions, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW |
||
limit: Option<usize>, | ||
cmp: F, | ||
) -> UInt32Array | ||
where | ||
K: ArrowDictionaryKeyType, | ||
T: ArrowPrimitiveType, | ||
T::Native: std::cmp::PartialOrd, | ||
F: Fn(T::Native, T::Native) -> std::cmp::Ordering, | ||
{ | ||
let valids = valids_for_sort_primitive_dictionary::<K, T>(values, value_indices); | ||
|
||
let keys: &PrimitiveArray<K> = values.keys(); | ||
sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, options, limit, valids) | ||
} | ||
|
||
fn valids_for_sort_primitive_dictionary<K, T>( | ||
values: &DictionaryArray<K>, | ||
value_indices: Vec<u32>, | ||
) -> Vec<(u32, T::Native)> | ||
where | ||
K: ArrowDictionaryKeyType, | ||
T: ArrowPrimitiveType, | ||
T::Native: std::cmp::PartialOrd, | ||
{ | ||
let keys: &PrimitiveArray<K> = values.keys(); | ||
let dict = values.values(); | ||
let values = as_primitive_array::<T>(dict); | ||
|
||
value_indices | ||
.into_iter() | ||
.map(|index| { | ||
let key: K::Native = keys.value(index as usize); | ||
(index, values.value(key.to_usize().unwrap())) | ||
}) | ||
.collect::<Vec<(u32, T::Native)>>() | ||
} | ||
|
||
// sort is instantiated a lot so we only compile this inner version for each native type | ||
fn sort_primitive_inner<T, F>( | ||
values: &ArrayRef, | ||
value_len: usize, | ||
null_indices: Vec<u32>, | ||
cmp: F, | ||
options: &SortOptions, | ||
|
@@ -535,7 +571,7 @@ where | |
|
||
let valids_len = valids.len(); | ||
let nulls_len = nulls.len(); | ||
let mut len = values.len(); | ||
let mut len = value_len; | ||
|
||
if let Some(limit) = limit { | ||
len = limit.min(len); | ||
|
@@ -620,14 +656,12 @@ fn sort_string<Offset: OffsetSizeTrait>( | |
|
||
/// Sort dictionary encoded strings | ||
fn sort_string_dictionary<T: ArrowDictionaryKeyType>( | ||
values: &ArrayRef, | ||
values: &DictionaryArray<T>, | ||
value_indices: Vec<u32>, | ||
null_indices: Vec<u32>, | ||
options: &SortOptions, | ||
limit: Option<usize>, | ||
) -> UInt32Array { | ||
let values: &DictionaryArray<T> = as_dictionary_array::<T>(values); | ||
|
||
let keys: &PrimitiveArray<T> = values.keys(); | ||
|
||
let dict = values.values(); | ||
|
@@ -1239,6 +1273,58 @@ mod tests { | |
assert_eq!(sorted_strings, expected) | ||
} | ||
|
||
fn test_sort_primitive_dict_arrays<K: ArrowDictionaryKeyType, T: ArrowPrimitiveType>( | ||
keys: PrimitiveArray<K>, | ||
values: PrimitiveArray<T>, | ||
options: Option<SortOptions>, | ||
limit: Option<usize>, | ||
expected_data: Vec<Option<T::Native>>, | ||
) where | ||
PrimitiveArray<T>: From<Vec<Option<T::Native>>>, | ||
{ | ||
let array = DictionaryArray::<K>::try_new(&keys, &values).unwrap(); | ||
let array_values = array.values().clone(); | ||
let dict = array_values | ||
.as_any() | ||
.downcast_ref::<PrimitiveArray<T>>() | ||
.expect("Unable to get dictionary values"); | ||
|
||
let sorted = match limit { | ||
Some(_) => { | ||
sort_limit(&(Arc::new(array) as ArrayRef), options, limit).unwrap() | ||
} | ||
_ => sort(&(Arc::new(array) as ArrayRef), options).unwrap(), | ||
}; | ||
let sorted = sorted | ||
.as_any() | ||
.downcast_ref::<DictionaryArray<K>>() | ||
.unwrap(); | ||
let sorted_values = sorted.values(); | ||
let sorted_dict = sorted_values | ||
.as_any() | ||
.downcast_ref::<PrimitiveArray<T>>() | ||
.expect("Unable to get dictionary values"); | ||
let sorted_keys = sorted.keys(); | ||
|
||
assert_eq!(sorted_dict, dict); | ||
|
||
let sorted_values: PrimitiveArray<T> = From::<Vec<Option<T::Native>>>::from( | ||
(0..sorted.len()) | ||
.map(|i| { | ||
if sorted.is_valid(i) { | ||
Some(sorted_dict.value(sorted_keys.value(i).to_usize().unwrap())) | ||
} else { | ||
None | ||
} | ||
}) | ||
.collect::<Vec<Option<T::Native>>>(), | ||
); | ||
let expected: PrimitiveArray<T> = | ||
From::<Vec<Option<T::Native>>>::from(expected_data); | ||
|
||
assert_eq!(sorted_values, expected) | ||
} | ||
|
||
fn test_sort_list_arrays<T>( | ||
data: Vec<Option<Vec<Option<T::Native>>>>, | ||
options: Option<SortOptions>, | ||
|
@@ -3222,4 +3308,60 @@ mod tests { | |
partial_sort(&mut before, last, |a, b| a.cmp(b)); | ||
assert_eq!(&d[0..last], &before[0..last]); | ||
} | ||
|
||
#[test] | ||
fn test_sort_int8_dicts() { | ||
let keys = | ||
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); | ||
let values = Int8Array::from(vec![1, 3, 5]); | ||
test_sort_primitive_dict_arrays::<Int8Type, Int8Type>( | ||
keys, | ||
values, | ||
None, | ||
None, | ||
vec![None, None, Some(1), Some(3), Some(5), Some(5)], | ||
); | ||
|
||
let keys = | ||
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); | ||
let values = Int8Array::from(vec![1, 3, 5]); | ||
test_sort_primitive_dict_arrays::<Int8Type, Int8Type>( | ||
keys, | ||
values, | ||
Some(SortOptions { | ||
descending: true, | ||
nulls_first: false, | ||
}), | ||
None, | ||
vec![Some(5), Some(5), Some(3), Some(1), None, None], | ||
); | ||
|
||
let keys = | ||
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); | ||
let values = Int8Array::from(vec![1, 3, 5]); | ||
test_sort_primitive_dict_arrays::<Int8Type, Int8Type>( | ||
keys, | ||
values, | ||
Some(SortOptions { | ||
descending: false, | ||
nulls_first: false, | ||
}), | ||
None, | ||
vec![Some(1), Some(3), Some(5), Some(5), None, None], | ||
); | ||
|
||
let keys = | ||
Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); | ||
let values = Int8Array::from(vec![1, 3, 5]); | ||
test_sort_primitive_dict_arrays::<Int8Type, Int8Type>( | ||
keys, | ||
values, | ||
Some(SortOptions { | ||
descending: false, | ||
nulls_first: true, | ||
}), | ||
Some(3), | ||
vec![None, None, Some(1)], | ||
); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NGL I worry a bit that this has the combinatorial fanout that absolutely tanks compile times... Perhaps we could compute the sort order of the dictionary values and then use this to compare the keys? This might even be faster
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I agree that sorting the dictionary values and using it to compare the keys might be faster. But does it help on combinatorial fanout from dictionary? For example, in order to sort on the dictionary values, we still need get the value array from dictionary array so a
downcast_dictionary_array!
is also necessary.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, you could call sort_to_indices on the values array, which is only typed on the values type, and then to compare dictionary values you compare the indices you just computed, which is only typed on the dictionary key type. You therefore avoid the fanout?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So
downcast_dictionary_array
is not the one you concern butdowncast_dictionary_array
+sort_primitive_dictionary
which is typed on both key and value types?Then it makes sense to me. I can refactor this and split sorting dictionary to sorting dictionary values and sorting keys based on the computed indices.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes it is the combinatorial fanout that is especially painful, and given it is avoidable I think it makes sense to do so. It's a case of every little helps 😅