Skip to content

Commit

Permalink
Reduce combinatorial fanout
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Sep 9, 2022
1 parent e66d133 commit 35cb87d
Showing 1 changed file with 49 additions and 32 deletions.
81 changes: 49 additions & 32 deletions arrow/src/compute/kernels/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,14 +315,46 @@ pub fn sort_to_indices(
DataType::Dictionary(_, _) => {
downcast_dictionary_array!(
values => match values.values().data_type() {
DataType::Int8 => sort_primitive_dictionary::<_, Int8Type, _>(values, v, n, &options, limit, cmp),
DataType::Int16 => sort_primitive_dictionary::<_, Int16Type, _>(values, v, n, &options, limit, cmp),
DataType::Int32 => sort_primitive_dictionary::<_, Int32Type, _>(values, v, n, &options, limit, cmp),
DataType::Int64 => sort_primitive_dictionary::<_, Int64Type, _>(values, v, n, &options, limit, cmp),
DataType::UInt8 => sort_primitive_dictionary::<_, UInt8Type, _>(values, v, n, &options, limit, cmp),
DataType::UInt16 => sort_primitive_dictionary::<_, UInt16Type, _>(values, v, n, &options, limit, cmp),
DataType::UInt32 => sort_primitive_dictionary::<_, UInt32Type, _>(values, v, n, &options, limit, cmp),
DataType::UInt64 => sort_primitive_dictionary::<_, UInt64Type, _>(values, v, n, &options, limit, cmp),
DataType::Int8 => {
let dict_values = values.values();
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), limit.clone())?;
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices, v, n, &options, limit, cmp)
},
DataType::Int16 => {
let dict_values = values.values();
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), limit.clone())?;
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices, v, n, &options, limit, cmp)
},
DataType::Int32 => {
let dict_values = values.values();
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), limit.clone())?;
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices, v, n, &options, limit, cmp)
},
DataType::Int64 => {
let dict_values = values.values();
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), limit.clone())?;
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices,v, n, &options, limit, cmp)
},
DataType::UInt8 => {
let dict_values = values.values();
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), limit.clone())?;
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices,v, n, &options, limit, cmp)
},
DataType::UInt16 => {
let dict_values = values.values();
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), limit.clone())?;
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices,v, n, &options, limit, cmp)
},
DataType::UInt32 => {
let dict_values = values.values();
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), limit.clone())?;
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices,v, n, &options, limit, cmp)
},
DataType::UInt64 => {
let dict_values = values.values();
let sorted_value_indices = sort_to_indices(dict_values, Some(SortOptions::default()), limit.clone())?;
sort_primitive_dictionary::<_, _>(values, &sorted_value_indices, v, n, &options, limit, cmp)
},
DataType::Utf8 => sort_string_dictionary::<_>(values, v, n, &options, limit),
t => return Err(ArrowError::ComputeError(format!(
"Unsupported dictionary value type {}", t
Expand Down Expand Up @@ -511,8 +543,9 @@ where
}

/// Sort dictionary encoded primitive values
fn sort_primitive_dictionary<K, T, F>(
fn sort_primitive_dictionary<K, F>(
values: &DictionaryArray<K>,
sorted_value_indices: &UInt32Array,
value_indices: Vec<u32>,
null_indices: Vec<u32>,
options: &SortOptions,
Expand All @@ -521,36 +554,20 @@ fn sort_primitive_dictionary<K, T, F>(
) -> UInt32Array
where
K: ArrowDictionaryKeyType,
T: ArrowPrimitiveType,
T::Native: std::cmp::PartialOrd,
F: Fn(T::Native, T::Native) -> std::cmp::Ordering,
{
let valids = valids_for_sort_primitive_dictionary::<K, T>(values, value_indices);

let keys: &PrimitiveArray<K> = values.keys();
sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, options, limit, valids)
}

fn valids_for_sort_primitive_dictionary<K, T>(
values: &DictionaryArray<K>,
value_indices: Vec<u32>,
) -> Vec<(u32, T::Native)>
where
K: ArrowDictionaryKeyType,
T: ArrowPrimitiveType,
T::Native: std::cmp::PartialOrd,
F: Fn(u32, u32) -> std::cmp::Ordering,
{
let keys: &PrimitiveArray<K> = values.keys();
let dict = values.values();
let values = as_primitive_array::<T>(dict);

value_indices
// create tuples that are used for sorting
let valids = value_indices
.into_iter()
.map(|index| {
let key: K::Native = keys.value(index as usize);
(index, values.value(key.to_usize().unwrap()))
(index, sorted_value_indices.value(key.to_usize().unwrap()))
})
.collect::<Vec<(u32, T::Native)>>()
.collect::<Vec<(u32, u32)>>();

sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, options, limit, valids)
}

// sort is instantiated a lot so we only compile this inner version for each native type
Expand Down

0 comments on commit 35cb87d

Please sign in to comment.