Skip to content

Commit

Permalink
Change to generic function
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Sep 7, 2022
1 parent 4a9bdbd commit 79229f3
Showing 1 changed file with 32 additions and 40 deletions.
72 changes: 32 additions & 40 deletions arrow/src/array/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,31 @@ where
})
}

macro_rules! cmp_dict_primitive {
($KEY_TYPE:expr, $VALUE_TYPE:ident, $LEFT:ident, $RIGHT:ident) => {
match $KEY_TYPE {
UInt8 => compare_dict_primitive::<UInt8Type, $VALUE_TYPE>($LEFT, $RIGHT),
UInt16 => compare_dict_primitive::<UInt16Type, $VALUE_TYPE>($LEFT, $RIGHT),
UInt32 => compare_dict_primitive::<UInt32Type, $VALUE_TYPE>($LEFT, $RIGHT),
UInt64 => compare_dict_primitive::<UInt64Type, $VALUE_TYPE>($LEFT, $RIGHT),
Int8 => compare_dict_primitive::<Int8Type, $VALUE_TYPE>($LEFT, $RIGHT),
Int16 => compare_dict_primitive::<Int16Type, $VALUE_TYPE>($LEFT, $RIGHT),
Int32 => compare_dict_primitive::<Int32Type, $VALUE_TYPE>($LEFT, $RIGHT),
Int64 => compare_dict_primitive::<Int64Type, $VALUE_TYPE>($LEFT, $RIGHT),
t => {
return Err(ArrowError::InvalidArgumentError(format!(
"Dictionaries do not support keys of type {:?}",
t
)));
}
fn cmp_dict_primitive<VT>(
key_type: &DataType,
left: &dyn Array,
right: &dyn Array,
) -> Result<DynComparator>
where
VT: ArrowPrimitiveType,
VT::Native: Ord,
{
Ok(match key_type {
UInt8 => compare_dict_primitive::<UInt8Type, VT>(left, right),
UInt16 => compare_dict_primitive::<UInt16Type, VT>(left, right),
UInt32 => compare_dict_primitive::<UInt32Type, VT>(left, right),
UInt64 => compare_dict_primitive::<UInt64Type, VT>(left, right),
Int8 => compare_dict_primitive::<Int8Type, VT>(left, right),
Int16 => compare_dict_primitive::<Int16Type, VT>(left, right),
Int32 => compare_dict_primitive::<Int32Type, VT>(left, right),
Int64 => compare_dict_primitive::<Int64Type, VT>(left, right),
t => {
return Err(ArrowError::InvalidArgumentError(format!(
"Dictionaries do not support keys of type {:?}",
t
)));
}
};
})
}

/// returns a comparison function that compares two values at two different positions
Expand Down Expand Up @@ -250,28 +256,14 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
let key_type_lhs = key_type_lhs.as_ref();

match value_type_lhs.as_ref() {
Int8 => cmp_dict_primitive!(key_type_lhs, Int8Type, left, right),
Int16 => {
cmp_dict_primitive!(key_type_lhs, Int16Type, left, right)
}
Int32 => {
cmp_dict_primitive!(key_type_lhs, Int32Type, left, right)
}
Int64 => {
cmp_dict_primitive!(key_type_lhs, Int64Type, left, right)
}
UInt8 => {
cmp_dict_primitive!(key_type_lhs, UInt8Type, left, right)
}
UInt16 => {
cmp_dict_primitive!(key_type_lhs, UInt16Type, left, right)
}
UInt32 => {
cmp_dict_primitive!(key_type_lhs, UInt32Type, left, right)
}
UInt64 => {
cmp_dict_primitive!(key_type_lhs, UInt64Type, left, right)
}
Int8 => cmp_dict_primitive::<Int8Type>(key_type_lhs, left, right)?,
Int16 => cmp_dict_primitive::<Int16Type>(key_type_lhs, left, right)?,
Int32 => cmp_dict_primitive::<Int32Type>(key_type_lhs, left, right)?,
Int64 => cmp_dict_primitive::<Int64Type>(key_type_lhs, left, right)?,
UInt8 => cmp_dict_primitive::<UInt8Type>(key_type_lhs, left, right)?,
UInt16 => cmp_dict_primitive::<UInt16Type>(key_type_lhs, left, right)?,
UInt32 => cmp_dict_primitive::<UInt32Type>(key_type_lhs, left, right)?,
UInt64 => cmp_dict_primitive::<UInt64Type>(key_type_lhs, left, right)?,
Utf8 => match key_type_lhs {
UInt8 => compare_dict_string::<UInt8Type>(left, right),
UInt16 => compare_dict_string::<UInt16Type>(left, right),
Expand Down

0 comments on commit 79229f3

Please sign in to comment.