Skip to content

Commit

Permalink
Compare dictionary with primitive array (#2533)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Aug 20, 2022
1 parent 4949a3d commit 1d5656e
Showing 1 changed file with 174 additions and 40 deletions.
214 changes: 174 additions & 40 deletions arrow/src/compute/kernels/comparison.rs
Expand Up @@ -41,9 +41,13 @@ use std::collections::HashMap;

/// Helper function to perform a boolean lambda function on values from two array accessors. This
/// version does not attempt to use SIMD.
fn compare_op<T: ArrayAccessor, F>(left: T, right: T, op: F) -> Result<BooleanArray>
fn compare_op<T: ArrayAccessor, S: ArrayAccessor, F>(
left: T,
right: S,
op: F,
) -> Result<BooleanArray>
where
F: Fn(T::Item, T::Item) -> bool,
F: Fn(T::Item, S::Item) -> bool,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
Expand Down Expand Up @@ -1861,6 +1865,99 @@ where
compare_op(left_array, right_array, op)
}

/// Dispatches a dictionary-vs-non-dictionary comparison on the dictionary's
/// key type.
///
/// `$LEFT` is downcast with `as_dictionary_array` using the key type that
/// matches `$LEFT_KEY_TYPE`, and the result is handed to `cmp_dict_primitive`
/// together with `$RIGHT` and `$OP`. `$RIGHT_TYPE` is forwarded as the second
/// type argument of `cmp_dict_primitive`. Any key type other than the eight
/// integer types listed below evaluates to `ArrowError::NotYetImplemented`.
/// `$OP_BOOL` is accepted but not used by this macro.
macro_rules! typed_dict_non_dict_cmp {
    ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP_BOOL: expr, $OP: expr) => {{
        match $LEFT_KEY_TYPE {
            DataType::Int8 => cmp_dict_primitive::<_, $RIGHT_TYPE, _>(
                as_dictionary_array::<Int8Type>($LEFT),
                $RIGHT,
                $OP,
            ),
            DataType::Int16 => cmp_dict_primitive::<_, $RIGHT_TYPE, _>(
                as_dictionary_array::<Int16Type>($LEFT),
                $RIGHT,
                $OP,
            ),
            DataType::Int32 => cmp_dict_primitive::<_, $RIGHT_TYPE, _>(
                as_dictionary_array::<Int32Type>($LEFT),
                $RIGHT,
                $OP,
            ),
            DataType::Int64 => cmp_dict_primitive::<_, $RIGHT_TYPE, _>(
                as_dictionary_array::<Int64Type>($LEFT),
                $RIGHT,
                $OP,
            ),
            DataType::UInt8 => cmp_dict_primitive::<_, $RIGHT_TYPE, _>(
                as_dictionary_array::<UInt8Type>($LEFT),
                $RIGHT,
                $OP,
            ),
            DataType::UInt16 => cmp_dict_primitive::<_, $RIGHT_TYPE, _>(
                as_dictionary_array::<UInt16Type>($LEFT),
                $RIGHT,
                $OP,
            ),
            DataType::UInt32 => cmp_dict_primitive::<_, $RIGHT_TYPE, _>(
                as_dictionary_array::<UInt32Type>($LEFT),
                $RIGHT,
                $OP,
            ),
            DataType::UInt64 => cmp_dict_primitive::<_, $RIGHT_TYPE, _>(
                as_dictionary_array::<UInt64Type>($LEFT),
                $RIGHT,
                $OP,
            ),
            // Every remaining key type is reported as unsupported.
            unsupported => Err(ArrowError::NotYetImplemented(format!(
                "Cannot compare dictionary array of key type {}",
                unsupported
            ))),
        }
    }};
}

/// Compares a dictionary array (`$LEFT`) with a non-dictionary array
/// (`$RIGHT`).
///
/// When the dictionary's value type and `$RIGHT`'s data type are the same
/// integer or floating-point type listed below, the comparison is delegated
/// to `typed_dict_non_dict_cmp!` with the matching primitive type. If the two
/// types are equal but not listed, the result is
/// `ArrowError::NotYetImplemented`; if they differ, it is
/// `ArrowError::CastError`.
///
/// Hits `unreachable!` when `$LEFT`'s data type is not `DataType::Dictionary`.
macro_rules! typed_cmp_dict_non_dict {
    ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
        match $LEFT.data_type() {
            DataType::Dictionary(key_type, value_type) => {
                match (value_type.as_ref(), $RIGHT.data_type()) {
                    (DataType::Int8, DataType::Int8) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), Int8Type, $OP_BOOL, $OP
                    ),
                    (DataType::Int16, DataType::Int16) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), Int16Type, $OP_BOOL, $OP
                    ),
                    (DataType::Int32, DataType::Int32) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), Int32Type, $OP_BOOL, $OP
                    ),
                    (DataType::Int64, DataType::Int64) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), Int64Type, $OP_BOOL, $OP
                    ),
                    (DataType::UInt8, DataType::UInt8) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), UInt8Type, $OP_BOOL, $OP
                    ),
                    (DataType::UInt16, DataType::UInt16) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), UInt16Type, $OP_BOOL, $OP
                    ),
                    (DataType::UInt32, DataType::UInt32) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), UInt32Type, $OP_BOOL, $OP
                    ),
                    (DataType::UInt64, DataType::UInt64) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), UInt64Type, $OP_BOOL, $OP
                    ),
                    (DataType::Float32, DataType::Float32) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), Float32Type, $OP_BOOL, $OP
                    ),
                    (DataType::Float64, DataType::Float64) => typed_dict_non_dict_cmp!(
                        $LEFT, $RIGHT, key_type.as_ref(), Float64Type, $OP_BOOL, $OP
                    ),
                    // Same type on both sides, but no kernel for it yet.
                    (lhs, rhs) if lhs == rhs => Err(ArrowError::NotYetImplemented(format!(
                        "Comparing dictionary array of type {} with array of type {} is not yet implemented",
                        lhs, rhs
                    ))),
                    // The dictionary's values and the other array disagree on type.
                    (lhs, rhs) => Err(ArrowError::CastError(format!(
                        "Cannot compare dictionary array with array of different value types ({} and {})",
                        lhs, rhs
                    ))),
                }
            }
            _ => unreachable!("Should not reach this branch"),
        }
    }};
}

macro_rules! typed_compares {
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
Expand Down Expand Up @@ -2173,45 +2270,28 @@ macro_rules! typed_dict_compares {
}};
}

/// Helper function to perform boolean lambda function on values from two dictionary arrays, this
/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize)
fn compare_dict_op<'a, K, V, F>(
left: TypedDictionaryArray<'a, K, V>,
right: TypedDictionaryArray<'a, K, V>,
/// Perform the given operation on a `DictionaryArray` and a `PrimitiveArray`. The value
/// type of the `DictionaryArray` is the same as the `PrimitiveArray`'s type.
fn cmp_dict_primitive<K, T, F>(
left: &DictionaryArray<K>,
right: &dyn Array,
op: F,
) -> Result<BooleanArray>
where
K: ArrowNumericType,
V: Sync + Send,
&'a V: ArrayAccessor,
F: Fn(<&V as ArrayAccessor>::Item, <&V as ArrayAccessor>::Item) -> bool,
T: ArrowNumericType + Sync + Send,
F: Fn(T::Native, T::Native) -> bool,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform comparison operation on arrays of different length"
.to_string(),
));
}

let left_iter = left.into_iter();
let right_iter = right.into_iter();

let result = left_iter
.zip(right_iter)
.map(|(left_value, right_value)| {
if let (Some(left), Some(right)) = (left_value, right_value) {
Some(op(left, right))
} else {
None
}
})
.collect();

Ok(result)
compare_op(
left.downcast_dict::<PrimitiveArray<T>>().unwrap(),
as_primitive_array::<T>(right),
op,
)
}

/// Perform given operation on two `DictionaryArray`s.
/// Returns an error if the two arrays have different value type
/// Perform the given operation on two `DictionaryArray`s whose value type is a
/// primitive type. Returns an error if the two arrays have different value
/// types.
pub fn cmp_dict<K, T, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand All @@ -2222,7 +2302,7 @@ where
T: ArrowNumericType + Sync + Send,
F: Fn(T::Native, T::Native) -> bool,
{
compare_dict_op(
compare_op(
left.downcast_dict::<PrimitiveArray<T>>().unwrap(),
right.downcast_dict::<PrimitiveArray<T>>().unwrap(),
op,
Expand All @@ -2240,7 +2320,7 @@ where
K: ArrowNumericType,
F: Fn(bool, bool) -> bool,
{
compare_dict_op(
compare_op(
left.downcast_dict::<BooleanArray>().unwrap(),
right.downcast_dict::<BooleanArray>().unwrap(),
op,
Expand All @@ -2258,7 +2338,7 @@ where
K: ArrowNumericType,
F: Fn(&str, &str) -> bool,
{
compare_dict_op(
compare_op(
left.downcast_dict::<GenericStringArray<OffsetSize>>()
.unwrap(),
right
Expand All @@ -2279,7 +2359,7 @@ where
K: ArrowNumericType,
F: Fn(&[u8], &[u8]) -> bool,
{
compare_dict_op(
compare_op(
left.downcast_dict::<GenericBinaryArray<OffsetSize>>()
.unwrap(),
right
Expand All @@ -2305,9 +2385,19 @@ where
/// ```
pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
match left.data_type() {
DataType::Dictionary(_, _) => {
DataType::Dictionary(_, _)
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b)
}
DataType::Dictionary(_, _)
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_cmp_dict_non_dict!(left, right, |a, b| a == b, |a, b| a == b)
}
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b)
}
_ => typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b),
}
}
Expand All @@ -2330,9 +2420,19 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// ```
pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
match left.data_type() {
DataType::Dictionary(_, _) => {
DataType::Dictionary(_, _)
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b)
}
DataType::Dictionary(_, _)
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_cmp_dict_non_dict!(left, right, |a, b| a != b, |a, b| a != b)
}
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b)
}
_ => typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b),
}
}
Expand Down Expand Up @@ -5046,4 +5146,38 @@ mod tests {
BooleanArray::from(vec![Some(true), None, Some(false), Some(true)])
);
}

#[test]
fn test_eq_dyn_neq_dyn_dictionary_i8_i8_array() {
    // Keys 2, 3 and 4 select the dictionary values 12, 13 and 14.
    let dict_values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
    let dict_keys = Int8Array::from_iter_values([2_i8, 3, 4]);
    let dictionary = DictionaryArray::try_new(&dict_keys, &dict_values).unwrap();

    // The middle slot is null, so every comparison yields null there.
    let primitive = Int8Array::from_iter([Some(12_i8), None, Some(14)]);

    let expected_eq = BooleanArray::from(vec![Some(true), None, Some(true)]);
    let expected_neq = BooleanArray::from(vec![Some(false), None, Some(false)]);

    // Both argument orders are exercised for each kernel.
    assert_eq!(eq_dyn(&dictionary, &primitive).unwrap(), expected_eq);
    assert_eq!(eq_dyn(&primitive, &dictionary).unwrap(), expected_eq);

    assert_eq!(neq_dyn(&dictionary, &primitive).unwrap(), expected_neq);
    assert_eq!(neq_dyn(&primitive, &dictionary).unwrap(), expected_neq);
}
}

0 comments on commit 1d5656e

Please sign in to comment.