Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unary_cmp #1991

Merged
merged 5 commits into from Jul 4, 2022
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
69 changes: 45 additions & 24 deletions arrow/src/compute/kernels/comparison.rs
Expand Up @@ -134,7 +134,7 @@ macro_rules! compare_op_primitive {
}

macro_rules! compare_op_scalar {
($left:expr, $right:expr, $op:expr) => {{
($left:expr, $op:expr) => {{
let null_bit_buffer = $left
.data()
.null_buffer()
Expand All @@ -143,7 +143,7 @@ macro_rules! compare_op_scalar {
// Safety:
// `i < $left.len()`
let comparison =
(0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) });
(0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i)) });
// same as $left.len()
let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };

Expand Down Expand Up @@ -777,7 +777,7 @@ pub fn eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a == b)
compare_op_scalar!(left, |a| a == right)
}

#[inline]
Expand Down Expand Up @@ -870,22 +870,22 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray>

/// Perform `left < right` operation on [`BooleanArray`] and a scalar.
///
/// The per-element predicate `!a & right` is `true` only when the element is `false`
/// and `right` is `true`.
pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a: bool, b: bool| !a & b)
compare_op_scalar!(left, |a: bool| !a & right)
}

/// Perform `left <= right` operation on [`BooleanArray`] and a scalar.
///
/// The per-element predicate is `a <= right` on plain `bool` values, so it is `false`
/// only when the element is `true` and `right` is `false`.
pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a <= b)
compare_op_scalar!(left, |a| a <= right)
}

/// Perform `left > right` operation on [`BooleanArray`] and a scalar.
///
/// The per-element predicate `a & !right` is `true` only when the element is `true`
/// and `right` is `false`.
pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a: bool, b: bool| a & !b)
compare_op_scalar!(left, |a: bool| a & !right)
}

/// Perform `left >= right` operation on [`BooleanArray`] and a scalar.
///
/// The per-element predicate is `a >= right` on plain `bool` values, so it is `false`
/// only when the element is `false` and `right` is `true`.
pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a >= b)
compare_op_scalar!(left, |a| a >= right)
}

/// Perform `left != right` operation on [`BooleanArray`] and a scalar
Expand All @@ -906,7 +906,7 @@ pub fn eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a == b)
compare_op_scalar!(left, |a| a == right)
}

/// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -922,7 +922,7 @@ pub fn neq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a != b)
compare_op_scalar!(left, |a| a != right)
}

/// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -938,7 +938,7 @@ pub fn lt_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a < b)
compare_op_scalar!(left, |a| a < right)
}

/// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -954,7 +954,7 @@ pub fn lt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a <= b)
compare_op_scalar!(left, |a| a <= right)
}

/// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -970,7 +970,7 @@ pub fn gt_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a > b)
compare_op_scalar!(left, |a| a > right)
}

/// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`].
Expand All @@ -986,7 +986,7 @@ pub fn gt_eq_binary_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericBinaryArray<OffsetSize>,
right: &[u8],
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a >= b)
compare_op_scalar!(left, |a| a >= right)
}

/// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1002,7 +1002,7 @@ pub fn neq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a != b)
compare_op_scalar!(left, |a| a != right)
}

/// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1018,7 +1018,7 @@ pub fn lt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a < b)
compare_op_scalar!(left, |a| a < right)
}

/// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1034,7 +1034,7 @@ pub fn lt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a <= b)
compare_op_scalar!(left, |a| a <= right)
}

/// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1050,7 +1050,7 @@ pub fn gt_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a > b)
compare_op_scalar!(left, |a| a > right)
}

/// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`].
Expand All @@ -1066,7 +1066,7 @@ pub fn gt_eq_utf8_scalar<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &str,
) -> Result<BooleanArray> {
compare_op_scalar!(left, right, |a, b| a >= b)
compare_op_scalar!(left, |a| a >= right)
}

/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message.
Expand Down Expand Up @@ -2554,7 +2554,16 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a == b);
return compare_op_scalar!(left, |a| a == right);
}

/// Applies a unary and infallible comparison function to a primitive array.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️ thank you @viirya

///
/// `op` receives each element of `left` by value and its result becomes the boolean at
/// the same position of the returned [`BooleanArray`]. The underlying macro evaluates
/// `op` for every index in `0..left.len()`, so `op` is also called on whatever value
/// sits in a null slot; it should therefore be cheap and must not panic on arbitrary
/// input. Slots that are null in `left` come out null in the result.
pub fn unary_cmp<T, F>(left: &PrimitiveArray<T>, op: F) -> Result<BooleanArray>
where
    T: ArrowNumericType,
    F: Fn(T::Native) -> bool,
{
    // No `#[cfg]` split here, so the macro can simply be the tail expression.
    compare_op_scalar!(left, op)
}

/// Perform `left != right` operation on two [`PrimitiveArray`]s.
Expand All @@ -2576,7 +2585,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a != b);
return compare_op_scalar!(left, |a| a != right);
}

/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
Expand All @@ -2600,7 +2609,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a < b);
return compare_op_scalar!(left, |a| a < right);
}

/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
Expand All @@ -2627,7 +2636,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a <= b);
return compare_op_scalar!(left, |a| a <= right);
}

/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
Expand All @@ -2651,7 +2660,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a > b);
return compare_op_scalar!(left, |a| a > right);
}

/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
Expand All @@ -2678,7 +2687,7 @@ where
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
#[cfg(not(feature = "simd"))]
return compare_op_scalar!(left, right, |a, b| a >= b);
return compare_op_scalar!(left, |a| a >= right);
}

/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]
Expand Down Expand Up @@ -5047,4 +5056,16 @@ mod tests {
let result = gt_eq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));
}

#[test]
fn test_unary_cmp() {
    // Membership predicate over an input that contains a null slot: the null must
    // stay null, and only values present in `allowed` map to `true`.
    let allowed = vec![1_i32, 3];
    let input = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]);
    let expected = BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]);

    let actual = unary_cmp(&input, |value| allowed.contains(&value)).unwrap();

    assert_eq!(actual, expected);
}
}