diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c7aa7f014254..0bb4fa514e77 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -61,13 +61,11 @@ use kernels::{ }; use kernels_arrow::{ add_decimal, add_decimal_scalar, divide_decimal_scalar, divide_opt_decimal, - eq_decimal_scalar, gt_decimal_scalar, gt_eq_decimal_scalar, is_distinct_from, - is_distinct_from_bool, is_distinct_from_decimal, is_distinct_from_null, - is_distinct_from_utf8, is_not_distinct_from, is_not_distinct_from_bool, - is_not_distinct_from_decimal, is_not_distinct_from_null, is_not_distinct_from_utf8, - lt_decimal_scalar, lt_eq_decimal_scalar, modulus_decimal, modulus_decimal_scalar, - multiply_decimal, multiply_decimal_scalar, neq_decimal_scalar, subtract_decimal, - subtract_decimal_scalar, + is_distinct_from, is_distinct_from_bool, is_distinct_from_decimal, + is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from, + is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_null, + is_not_distinct_from_utf8, modulus_decimal, modulus_decimal_scalar, multiply_decimal, + multiply_decimal_scalar, subtract_decimal, subtract_decimal_scalar, }; use arrow::datatypes::{DataType, Schema, TimeUnit}; @@ -124,11 +122,9 @@ impl std::fmt::Display for BinaryExpr { macro_rules! compute_decimal_op_dyn_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ let ll = as_decimal128_array($LEFT).unwrap(); - if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT { - Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}( - ll, - $RIGHT.try_into()?, - )?)) + if let ScalarValue::Decimal128(Some(v_i128), _, _) = $RIGHT { + // Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}( + Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}(ll, v_i128)?)) } else { // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE type Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) @@ -2304,6 +2300,82 @@ mod tests { #[test] fn comparison_decimal_expr_test() -> Result<()> { + // scalar of decimal compare with decimal array + let value_i128 = 123; + let decimal_scalar = ScalarValue::Decimal128(Some(value_i128), 25, 3); + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Decimal128(25, 3), + true, + )])); + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value_i128), + None, + Some(value_i128 - 1), + Some(value_i128 + 1), + ], + 25, + 3, + )) as ArrayRef; + // array = scalar + apply_logic_op_arr_scalar( + &schema, + &decimal_array, + &decimal_scalar, + Operator::Eq, + &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]), + ) + .unwrap(); + // array != scalar + apply_logic_op_arr_scalar( + &schema, + &decimal_array, + &decimal_scalar, + Operator::NotEq, + &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]), + ) + .unwrap(); + // array < scalar + apply_logic_op_arr_scalar( + &schema, + &decimal_array, + &decimal_scalar, + Operator::Lt, + &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]), + ) + .unwrap(); + + // array <= scalar + apply_logic_op_arr_scalar( + &schema, + &decimal_array, + &decimal_scalar, + Operator::LtEq, + &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]), + ) + .unwrap(); + // array > scalar + apply_logic_op_arr_scalar( + &schema, + &decimal_array, + &decimal_scalar, + Operator::Gt, + &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]), + ) + .unwrap(); + + // array >= scalar + apply_logic_op_arr_scalar( + &schema, + &decimal_array, + &decimal_scalar, + Operator::GtEq, + &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]), + ) + .unwrap(); + + // scalar of different data type with decimal array let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); // scalar == array diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs index 1d86d171f37e..d1c927284f4b 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs @@ -118,67 +118,6 @@ pub(crate) fn is_not_distinct_from_utf8( .collect()) } -// TODO move decimal kernels to to arrow-rs -// https://github.com/apache/arrow-rs/issues/1200 - -/// Creates an BooleanArray the same size as `left`, -/// applying `op` to all non-null elements of left -pub(crate) fn compare_decimal_scalar( - left: &Decimal128Array, - right: i128, - op: F, -) -> Result -where - F: Fn(i128, i128) -> bool, -{ - Ok(left - .iter() - .map(|left| left.map(|left| op(left, right))) - .collect()) -} - -pub(crate) fn eq_decimal_scalar( - left: &Decimal128Array, - right: i128, -) -> Result { - compare_decimal_scalar(left, right, |left, right| left == right) -} - -pub(crate) fn neq_decimal_scalar( - left: &Decimal128Array, - right: i128, -) -> Result { - compare_decimal_scalar(left, right, |left, right| left != right) -} - -pub(crate) fn lt_decimal_scalar( - left: &Decimal128Array, - right: i128, -) -> Result { - compare_decimal_scalar(left, right, |left, right| left < right) -} - -pub(crate) fn lt_eq_decimal_scalar( - left: &Decimal128Array, - right: i128, -) -> Result { - compare_decimal_scalar(left, right, |left, right| left <= right) -} - -pub(crate) fn gt_decimal_scalar( - left: &Decimal128Array, - right: i128, -) -> Result { - compare_decimal_scalar(left, right, |left, right| left > right) -} - -pub(crate) fn gt_eq_decimal_scalar( - left: &Decimal128Array, - right: i128, -) -> Result { - compare_decimal_scalar(left, right, |left, right| left >= right) -} - pub(crate) fn is_distinct_from_decimal( left: &Decimal128Array, right: &Decimal128Array, @@ -403,43 +342,6 @@ mod tests { 25, 3, ); - // eq: array = i128 - let result = eq_decimal_scalar(&decimal_array, value_i128)?; - assert_eq!( - BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]), - result - ); - // neq: array != i128 - let result = neq_decimal_scalar(&decimal_array, value_i128)?; - assert_eq!( - BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]), - result - ); - // lt: array < i128 - let result = lt_decimal_scalar(&decimal_array, value_i128)?; - assert_eq!( - BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]), - result - ); - // lt_eq: array <= i128 - let result = lt_eq_decimal_scalar(&decimal_array, value_i128)?; - assert_eq!( - BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]), - result - ); - // gt: array > i128 - let result = gt_decimal_scalar(&decimal_array, value_i128)?; - assert_eq!( - BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]), - result - ); - // gt_eq: array >= i128 - let result = gt_eq_decimal_scalar(&decimal_array, value_i128)?; - assert_eq!( - BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]), - result - ); - let left_decimal_array = decimal_array; let right_decimal_array = create_decimal_array( &[