From 3084ee258122910cc491d85a8bf9729b7bed95dc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 12 Nov 2022 23:39:24 -0800 Subject: [PATCH] Use ArrowNativeTypeOp instead of total_cmp directly (#3087) --- arrow/src/compute/kernels/comparison.rs | 112 +++++++----------------- 1 file changed, 32 insertions(+), 80 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 9d89287eebf..a286eedd190 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -2748,30 +2748,22 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_dict_compares!( - left, - right, - |a, b| a == b, - |a, b| a.total_cmp(&b).is_eq(), - |a, b| a == b - ) + typed_dict_compares!(left, right, |a, b| a == b, |a, b| a.is_eq(b), |a, b| a + == b) } DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(left, right, |a, b| a == b, |a, b| a == b, |a, b| a - .total_cmp(&b) - .is_eq()) + .is_eq(b)) } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b, |a, b| a - .total_cmp(&b) - .is_eq()) + typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b, |a, b| b + .is_eq(a)) } _ => { typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b, |a, b| a - .total_cmp(&b) - .is_eq()) + .is_eq(b)) } } } @@ -2801,30 +2793,22 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_dict_compares!( - left, - right, - |a, b| a != b, - |a, b| a.total_cmp(&b).is_ne(), - |a, b| a != b - ) + typed_dict_compares!(left, right, |a, b| a != b, |a, b| a.is_ne(b), |a, b| a + != b) } DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(left, right, |a, b| a != b, |a, b| a != b, |a, b| a - .total_cmp(&b) - .is_ne()) + .is_ne(b)) } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b, |a, b| a - .total_cmp(&b) - .is_ne()) + typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b, |a, b| b + .is_ne(a)) } _ => { typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b, |a, b| a - .total_cmp(&b) - .is_ne()) + .is_ne(b)) } } } @@ -2854,30 +2838,22 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_dict_compares!( - left, - right, - |a, b| a < b, - |a, b| a.total_cmp(&b).is_lt(), - |a, b| a < b - ) + typed_dict_compares!(left, right, |a, b| a < b, |a, b| a.is_lt(b), |a, b| a + < b) } DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(left, right, |a, b| a < b, |a, b| a < b, |a, b| a - .total_cmp(&b) - .is_lt()) + .is_lt(b)) } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(right, left, |a, b| a > b, |a, b| a > b, |a, b| b - .total_cmp(&a) - .is_lt()) + .is_lt(a)) } _ => { typed_compares!(left, right, |a, b| ((!a) & b), |a, b| a < b, |a, b| a - .total_cmp(&b) - .is_lt()) + .is_lt(b)) } } } @@ -2906,30 +2882,22 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_dict_compares!( - left, - right, - |a, b| a <= b, - |a, b| a.total_cmp(&b).is_le(), - |a, b| a <= b - ) + typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a.is_le(b), |a, b| a + <= b) } DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(left, right, |a, b| a <= b, |a, b| a <= b, |a, b| a - .total_cmp(&b) - .is_le()) + .is_le(b)) } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(right, left, |a, b| a >= b, |a, b| a >= b, |a, b| b - .total_cmp(&a) - .is_le()) + .is_le(a)) } _ => { typed_compares!(left, right, |a, b| !(a & (!b)), |a, b| a <= b, |a, b| a - .total_cmp(&b) - .is_le()) + .is_le(b)) } } } @@ -2958,30 +2926,22 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_dict_compares!( - left, - right, - |a, b| a > b, - |a, b| a.total_cmp(&b).is_gt(), - |a, b| a > b - ) + typed_dict_compares!(left, right, |a, b| a > b, |a, b| a.is_gt(b), |a, b| a + > b) } DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(left, right, |a, b| a > b, |a, b| a > b, |a, b| a - .total_cmp(&b) - .is_gt()) + .is_gt(b)) } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(right, left, |a, b| a < b, |a, b| a < b, |a, b| b - .total_cmp(&a) - .is_gt()) + .is_gt(a)) } _ => { typed_compares!(left, right, |a, b| (a & (!b)), |a, b| a > b, |a, b| a - .total_cmp(&b) - .is_gt()) + .is_gt(b)) } } } @@ -3009,30 +2969,22 @@ pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_dict_compares!( - left, - right, - |a, b| a >= b, - |a, b| a.total_cmp(&b).is_ge(), - |a, b| a >= b - ) + typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a.is_ge(b), |a, b| a + >= b) } DataType::Dictionary(_, _) if !matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(left, right, |a, b| a >= b, |a, b| a >= b, |a, b| a - .total_cmp(&b) - .is_ge()) + .is_ge(b)) } _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { typed_cmp_dict_non_dict!(right, left, |a, b| a <= b, |a, b| a <= b, |a, b| b - .total_cmp(&a) - .is_ge()) + .is_ge(a)) } _ => { typed_compares!(left, right, |a, b| !((!a) & b), |a, b| a >= b, |a, b| a - .total_cmp(&b) - .is_ge()) + .is_ge(b)) } } }