diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index f7b025a0d9b..580e347a91e 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -41,9 +41,13 @@ use std::collections::HashMap; /// Helper function to perform boolean lambda function on values from two array accessors, this /// version does not attempt to use SIMD. -fn compare_op(left: T, right: T, op: F) -> Result +fn compare_op( + left: T, + right: S, + op: F, +) -> Result where - F: Fn(T::Item, T::Item) -> bool, + F: Fn(T::Item, S::Item) -> bool, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -1861,6 +1865,99 @@ where compare_op(left_array, right_array, op) } +macro_rules! typed_dict_non_dict_cmp { + ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP_BOOL: expr, $OP: expr) => {{ + match $LEFT_KEY_TYPE { + DataType::Int8 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int16 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int32 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int64 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt8 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt16 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt32 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt64 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + t => 
Err(ArrowError::NotYetImplemented(format!( + "Cannot compare dictionary array of key type {}", + t + ))), + } + }}; +} + +macro_rules! typed_cmp_dict_non_dict { + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{ + match ($LEFT.data_type(), $RIGHT.data_type()) { + (DataType::Dictionary(left_key_type, left_value_type), right_type) => { + match (left_value_type.as_ref(), right_type) { + (DataType::Int8, DataType::Int8) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int8Type, $OP_BOOL, $OP) + } + (DataType::Int16, DataType::Int16) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int16Type, $OP_BOOL, $OP) + } + (DataType::Int32, DataType::Int32) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int32Type, $OP_BOOL, $OP) + } + (DataType::Int64, DataType::Int64) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int64Type, $OP_BOOL, $OP) + } + (DataType::UInt8, DataType::UInt8) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt8Type, $OP_BOOL, $OP) + } + (DataType::UInt16, DataType::UInt16) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt16Type, $OP_BOOL, $OP) + } + (DataType::UInt32, DataType::UInt32) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt32Type, $OP_BOOL, $OP) + } + (DataType::UInt64, DataType::UInt64) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt64Type, $OP_BOOL, $OP) + } + (DataType::Float32, DataType::Float32) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float32Type, $OP_BOOL, $OP) + } + (DataType::Float64, DataType::Float64) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float64Type, $OP_BOOL, $OP) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing dictionary array of type {} with array of type {} is not yet implemented", + t1, t2 + ))), + (t1, t2) => 
Err(ArrowError::CastError(format!( + "Cannot compare dictionary array with array of different value types ({} and {})", + t1, t2 + ))), + } + } + _ => unreachable!("Should not reach this branch"), + } + }}; +} + macro_rules! typed_compares { ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { @@ -2173,45 +2270,28 @@ macro_rules! typed_dict_compares { }}; } -/// Helper function to perform boolean lambda function on values from two dictionary arrays, this -/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) -fn compare_dict_op<'a, K, V, F>( - left: TypedDictionaryArray<'a, K, V>, - right: TypedDictionaryArray<'a, K, V>, +/// Perform the given boolean operation `op` on a `DictionaryArray` and a `PrimitiveArray`. The value +/// type of the `DictionaryArray` is expected to be the same as the `PrimitiveArray`'s type (the dictionary downcast is unwrapped). +fn cmp_dict_primitive( + left: &DictionaryArray, + right: &dyn Array, op: F, ) -> Result where K: ArrowNumericType, - V: Sync + Send, - &'a V: ArrayAccessor, - F: Fn(<&V as ArrayAccessor>::Item, <&V as ArrayAccessor>::Item) -> bool, + T: ArrowNumericType + Sync + Send, + F: Fn(T::Native, T::Native) -> bool, { - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let left_iter = left.into_iter(); - let right_iter = right.into_iter(); - - let result = left_iter - .zip(right_iter) - .map(|(left_value, right_value)| { - if let (Some(left), Some(right)) = (left_value, right_value) { - Some(op(left, right)) - } else { - None - } - }) - .collect(); - - Ok(result) + compare_op( + left.downcast_dict::>().unwrap(), + as_primitive_array::(right), + op, + ) } -/// Perform given operation on two `DictionaryArray`s. 
-/// Returns an error if the two arrays have different value type +/// Perform the given operation on two `DictionaryArray`s whose value type is +/// a primitive type.
Returns an error if the two arrays have different value +/// type pub fn cmp_dict( left: &DictionaryArray, right: &DictionaryArray, @@ -2222,7 +2302,7 @@ where T: ArrowNumericType + Sync + Send, F: Fn(T::Native, T::Native) -> bool, { - compare_dict_op( + compare_op( left.downcast_dict::>().unwrap(), right.downcast_dict::>().unwrap(), op, @@ -2240,7 +2320,7 @@ where K: ArrowNumericType, F: Fn(bool, bool) -> bool, { - compare_dict_op( + compare_op( left.downcast_dict::().unwrap(), right.downcast_dict::().unwrap(), op, @@ -2258,7 +2338,7 @@ where K: ArrowNumericType, F: Fn(&str, &str) -> bool, { - compare_dict_op( + compare_op( left.downcast_dict::>() .unwrap(), right @@ -2279,7 +2359,7 @@ where K: ArrowNumericType, F: Fn(&[u8], &[u8]) -> bool, { - compare_dict_op( + compare_op( left.downcast_dict::>() .unwrap(), right @@ -2305,9 +2385,19 @@ where /// ``` pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { - DataType::Dictionary(_, _) => { + DataType::Dictionary(_, _) + if matches!(right.data_type(), DataType::Dictionary(_, _)) => + { typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b) } + DataType::Dictionary(_, _) + if !matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_cmp_dict_non_dict!(left, right, |a, b| a == b, |a, b| a == b) + } + _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { + typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b) + } _ => typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b), } } @@ -2330,9 +2420,19 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// ``` pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { - DataType::Dictionary(_, _) => { + DataType::Dictionary(_, _) + if matches!(right.data_type(), DataType::Dictionary(_, _)) => + { typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b) } + DataType::Dictionary(_, _) + if !matches!(right.data_type(), 
DataType::Dictionary(_, _)) => + { + typed_cmp_dict_non_dict!(left, right, |a, b| a != b, |a, b| a != b) + } + _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { + typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b) + } _ => typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b), } } @@ -5046,4 +5146,38 @@ mod tests { BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]) ); } + + #[test] + fn test_eq_dyn_neq_dyn_dictionary_i8_i8_array() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + + let array = Int8Array::from_iter([Some(12_i8), None, Some(14)]); + + let result = eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + let result = eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + let result = neq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, Some(false)]) + ); + + let result = neq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, Some(false)]) + ); + } }