diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 068b9dedf59..0a6d60cea47 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -134,7 +134,7 @@ macro_rules! compare_op_primitive { } macro_rules! compare_op_scalar { - ($left:expr, $right:expr, $op:expr) => {{ + ($left:expr, $op:expr) => {{ let null_bit_buffer = $left .data() .null_buffer() @@ -143,7 +143,7 @@ macro_rules! compare_op_scalar { // Safety: // `i < $left.len()` let comparison = - (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) }); + (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i)) }); // same as $left.len() let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; @@ -777,7 +777,7 @@ pub fn eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, right, |a, b| a == b) + compare_op_scalar!(left, |a| a == right) } #[inline] @@ -870,22 +870,22 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result /// Perform `left < right` operation on [`BooleanArray`] and a scalar pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, right, |a: bool, b: bool| !a & b) + compare_op_scalar!(left, |a: bool| !a & right) } /// Perform `left <= right` operation on [`BooleanArray`] and a scalar pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, right, |a, b| a <= b) + compare_op_scalar!(left, |a| a <= right) } /// Perform `left > right` operation on [`BooleanArray`] and a scalar pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, right, |a: bool, b: bool| a & !b) + compare_op_scalar!(left, |a: bool| a & !right) } /// Perform `left >= right` operation on [`BooleanArray`] and a scalar pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, right, |a, b| a >= b) + compare_op_scalar!(left, |a| a >= right) } /// Perform `left != right` operation on [`BooleanArray`] and a scalar @@ -906,7 +906,7 @@ pub fn eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, right, |a, b| a == b) + compare_op_scalar!(left, |a| a == right) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -922,7 +922,7 @@ pub fn neq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, right, |a, b| a != b) + compare_op_scalar!(left, |a| a != right) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -938,7 +938,7 @@ pub fn lt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, right, |a, b| a < b) + compare_op_scalar!(left, |a| a < right) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -954,7 +954,7 @@ pub fn lt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, right, |a, b| a <= b) + compare_op_scalar!(left, |a| a <= right) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -970,7 +970,7 @@ pub fn gt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, right, |a, b| a > b) + compare_op_scalar!(left, |a| a > right) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -986,7 +986,7 @@ pub fn gt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, right, |a, b| a >= b) + compare_op_scalar!(left, |a| a >= right) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1002,7 +1002,7 @@ pub fn neq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, right, |a, b| a != b) + compare_op_scalar!(left, |a| a != right) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1018,7 +1018,7 @@ pub fn lt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, right, |a, b| a < b) + compare_op_scalar!(left, |a| a < right) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1034,7 +1034,7 @@ pub fn lt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, right, |a, b| a <= b) + compare_op_scalar!(left, |a| a <= right) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1050,7 +1050,7 @@ pub fn gt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, right, |a, b| a > b) + compare_op_scalar!(left, |a| a > right) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1066,7 +1066,7 @@ pub fn gt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, right, |a, b| a >= b) + compare_op_scalar!(left, |a| a >= right) } /// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message. @@ -2554,7 +2554,16 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, right, |a, b| a == b); + return compare_op_scalar!(left, |a| a == right); +} + +/// Applies an unary and infallible comparison function to a primitive array. +pub fn unary_cmp(left: &PrimitiveArray, op: F) -> Result +where + T: ArrowNumericType, + F: Fn(T::Native) -> bool, +{ + return compare_op_scalar!(left, op); } /// Perform `left != right` operation on two [`PrimitiveArray`]s. @@ -2576,7 +2585,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, right, |a, b| a != b); + return compare_op_scalar!(left, |a| a != right); } /// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2600,7 +2609,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, right, |a, b| a < b); + return compare_op_scalar!(left, |a| a < right); } /// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2627,7 +2636,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, right, |a, b| a <= b); + return compare_op_scalar!(left, |a| a <= right); } /// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2651,7 +2660,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, right, |a, b| a > b); + return compare_op_scalar!(left, |a| a > right); } /// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2678,7 +2687,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, right, |a, b| a >= b); + return compare_op_scalar!(left, |a| a >= right); } /// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`] @@ -5047,4 +5056,16 @@ mod tests { let result = gt_eq_dyn(&dict_array1, &dict_array2); assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); } + + #[test] + fn test_unary_cmp() { + let a = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); + let values = vec![1_i32, 3]; + + let a_eq = unary_cmp(&a, |a| values.contains(&a)).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]) + ); + } }