Skip to content

Commit

Permalink
Add unary_cmp (#1991)
Browse files Browse the repository at this point in the history
* Add unary_cmp

* Fix clippy

* Trigger Build

* Trigger Build

* Trigger Build
  • Loading branch information
viirya committed Jul 4, 2022
1 parent e436041 commit 932ffc5
Showing 1 changed file with 45 additions and 24 deletions.
69 changes: 45 additions & 24 deletions arrow/src/compute/kernels/comparison.rs
Expand Up @@ -134,7 +134,7 @@ macro_rules! compare_op_primitive {
}

macro_rules! compare_op_scalar {
($left:expr, $right:expr, $op:expr) => {{
($left:expr, $op:expr) => {{
let null_bit_buffer = $left
.data()
.null_buffer()
Expand All @@ -143,7 +143,7 @@ macro_rules! compare_op_scalar {
// Safety:
// `i < $left.len()`
let comparison =
(0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) });
(0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i)) });
// same as $left.len()
let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };

Expand Down Expand Up @@ -777,7 +777,7 @@ pub fn eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a == b)
compare_op_scalar!(left, |a| a == right)
}

#[inline]
Expand Down Expand Up @@ -870,22 +870,22 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray>

/// Perform `left < right` operation on [`BooleanArray`] and a scalar
pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a: bool, b: bool| !a & b)
compare_op_scalar!(left, |a: bool| !a & right)
}

/// Perform `left <= right` operation on [`BooleanArray`] and a scalar
pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a <= b)
compare_op_scalar!(left, |a| a <= right)
}

/// Perform `left > right` operation on [`BooleanArray`] and a scalar
pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a: bool, b: bool| a & !b)
compare_op_scalar!(left, |a: bool| a & !right)
}

/// Perform `left >= right` operation on [`BooleanArray`] and a scalar
pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a >= b)
compare_op_scalar!(left, |a| a >= right)
}

/// Perform `left != right` operation on [`BooleanArray`] and a scalar
Expand All @@ -906,7 +906,7 @@ pub fn eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a == b)
compare_op_scalar!(left, |a| a == right)
}

/// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -922,7 +922,7 @@ pub fn neq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a != b)
compare_op_scalar!(left, |a| a != right)
}

/// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -938,7 +938,7 @@ pub fn lt_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a < b)
compare_op_scalar!(left, |a| a < right)
}

/// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -954,7 +954,7 @@ pub fn lt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a <= b)
compare_op_scalar!(left, |a| a <= right)
}

/// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -970,7 +970,7 @@ pub fn gt_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a > b)
compare_op_scalar!(left, |a| a > right)
}

/// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -986,7 +986,7 @@ pub fn gt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a >= b)
compare_op_scalar!(left, |a| a >= right)
}

/// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1002,7 +1002,7 @@ pub fn neq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a != b)
compare_op_scalar!(left, |a| a != right)
}

/// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1018,7 +1018,7 @@ pub fn lt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a < b)
compare_op_scalar!(left, |a| a < right)
}

/// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1034,7 +1034,7 @@ pub fn lt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a <= b)
compare_op_scalar!(left, |a| a <= right)
}

/// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1050,7 +1050,7 @@ pub fn gt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a > b)
compare_op_scalar!(left, |a| a > right)
}

/// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1066,7 +1066,7 @@ pub fn gt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a >= b)
compare_op_scalar!(left, |a| a >= right)
}

/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message.
Expand Down Expand Up @@ -2554,7 +2554,16 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a == b);
return compare_op_scalar!(left, |a| a == right);
}

/// Applies an unary and infallible comparison function to a primitive array.
pub fn unary_cmp<T, F>(left: &PrimitiveArray<T>, op: F) -> Result<BooleanArray>
where
T: ArrowNumericType,
F: Fn(T::Native) -> bool,
{
return compare_op_scalar!(left, op);
}

/// Perform `left != right` operation on two [`PrimitiveArray`]s.
Expand All @@ -2576,7 +2585,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a != b);
return compare_op_scalar!(left, |a| a != right);
}

/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
Expand All @@ -2600,7 +2609,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a < b);
return compare_op_scalar!(left, |a| a < right);
}

/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
Expand All @@ -2627,7 +2636,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a <= b);
return compare_op_scalar!(left, |a| a <= right);
}

/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
Expand All @@ -2651,7 +2660,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a > b);
return compare_op_scalar!(left, |a| a > right);
}

/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
Expand All @@ -2678,7 +2687,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a >= b);
return compare_op_scalar!(left, |a| a >= right);
}

/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]
Expand Down Expand Up @@ -5047,4 +5056,16 @@ mod tests {
let result = gt_eq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));
}

#[test]
fn test_unary_cmp() {
let a = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]);
let values = vec![1_i32, 3];

let a_eq = unary_cmp(&a, |a| values.contains(&a)).unwrap();
assert_eq!(
a_eq,
BooleanArray::from(vec![Some(true), None, Some(false), Some(true)])
);
}
}

0 comments on commit 932ffc5

Please sign in to comment.