From 4c1c527e5c56545442d47c2570643ddfa45c044e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Aug 2022 01:16:13 -0700 Subject: [PATCH] Replace macro with TypedDictionaryArray --- arrow/src/compute/kernels/comparison.rs | 102 +++++++++++++----------- 1 file changed, 57 insertions(+), 45 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 02e0b65a3e0..9640554e6c1 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2154,49 +2154,39 @@ macro_rules! typed_dict_compares { /// Helper function to perform boolean lambda function on values from two dictionary arrays, this /// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) -macro_rules! compare_dict_op { - ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - // Safety justification: Since the inputs are valid Arrow arrays, all values are - // valid indexes into the dictionary (which is verified during construction) - - let left_iter = unsafe { - $left - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($left.keys_iter()) - }; +fn compare_dict_op<'a, K, V, F>( + left: TypedDictionaryArray<'a, K, V>, + right: TypedDictionaryArray<'a, K, V>, + op: F, +) -> Result +where + K: ArrowNumericType, + V: Sync + Send, + &'a V: ArrayAccessor, + F: Fn(<&V as ArrayAccessor>::Item, <&V as ArrayAccessor>::Item) -> bool, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } - let right_iter = unsafe { - $right - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($right.keys_iter()) - }; + let left_iter = left.into_iter(); + let right_iter = right.into_iter(); - let result = left_iter - .zip(right_iter) - .map(|(left_value, right_value)| { - if let (Some(left), Some(right)) = (left_value, right_value) { - Some($op(left, right)) - } else { - None - } - }) - .collect(); + let result = left_iter + .zip(right_iter) + .map(|(left_value, right_value)| { + if let (Some(left), Some(right)) = (left_value, right_value) { + Some(op(left, right)) + } else { + None + } + }) + .collect(); - Ok(result) - }}; + Ok(result) } /// Perform given operation on two `DictionaryArray`s. @@ -2208,10 +2198,14 @@ pub fn cmp_dict( ) -> Result where K: ArrowNumericType, - T: ArrowNumericType, + T: ArrowNumericType + Sync + Send, F: Fn(T::Native, T::Native) -> bool, { - compare_dict_op!(left, right, op, PrimitiveArray) + compare_dict_op( + left.downcast_dict::>().unwrap(), + right.downcast_dict::>().unwrap(), + op, + ) } /// Perform the given operation on two `DictionaryArray`s which value type is @@ -2225,7 +2219,11 @@ where K: ArrowNumericType, F: Fn(bool, bool) -> bool, { - compare_dict_op!(left, right, op, BooleanArray) + compare_dict_op( + left.downcast_dict::().unwrap(), + right.downcast_dict::().unwrap(), + op, + ) } /// Perform the given operation on two `DictionaryArray`s which value type is @@ -2239,7 +2237,14 @@ where K: ArrowNumericType, F: Fn(&str, &str) -> bool, { - compare_dict_op!(left, right, op, GenericStringArray) + compare_dict_op( + left.downcast_dict::>() + .unwrap(), + right + .downcast_dict::>() + .unwrap(), + op, + ) } /// Perform the given operation on two `DictionaryArray`s which value type is @@ -2253,7 +2258,14 @@ where K: ArrowNumericType, F: Fn(&[u8], &[u8]) -> bool, { - compare_dict_op!(left, right, op, GenericBinaryArray) + compare_dict_op( + left.downcast_dict::>() + .unwrap(), + right + .downcast_dict::>() + .unwrap(), + op, + ) } /// Perform `left == right` operation on two (dynamic) [`Array`]s.