Skip to content

Commit

Permalink
repalce the compare kernel for decimal dyn op
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Dec 1, 2022
1 parent d9e58db commit 6b93b8d
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 110 deletions.
96 changes: 84 additions & 12 deletions datafusion/physical-expr/src/expressions/binary.rs
Expand Up @@ -61,13 +61,11 @@ use kernels::{
};
use kernels_arrow::{
add_decimal, add_decimal_scalar, divide_decimal_scalar, divide_opt_decimal,
eq_decimal_scalar, gt_decimal_scalar, gt_eq_decimal_scalar, is_distinct_from,
is_distinct_from_bool, is_distinct_from_decimal, is_distinct_from_null,
is_distinct_from_utf8, is_not_distinct_from, is_not_distinct_from_bool,
is_not_distinct_from_decimal, is_not_distinct_from_null, is_not_distinct_from_utf8,
lt_decimal_scalar, lt_eq_decimal_scalar, modulus_decimal, modulus_decimal_scalar,
multiply_decimal, multiply_decimal_scalar, neq_decimal_scalar, subtract_decimal,
subtract_decimal_scalar,
is_distinct_from, is_distinct_from_bool, is_distinct_from_decimal,
is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_null,
is_not_distinct_from_utf8, modulus_decimal, modulus_decimal_scalar, multiply_decimal,
multiply_decimal_scalar, subtract_decimal, subtract_decimal_scalar,
};

use arrow::datatypes::{DataType, Schema, TimeUnit};
Expand Down Expand Up @@ -124,11 +122,9 @@ impl std::fmt::Display for BinaryExpr {
macro_rules! compute_decimal_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
let ll = as_decimal128_array($LEFT).unwrap();
if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT {
Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}(
ll,
$RIGHT.try_into()?,
)?))
if let ScalarValue::Decimal128(Some(v_i128), _, _) = $RIGHT {
// Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}(
Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}(ll, v_i128)?))
} else {
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE type
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
Expand Down Expand Up @@ -2304,6 +2300,82 @@ mod tests {

#[test]
fn comparison_decimal_expr_test() -> Result<()> {
// scalar of decimal compare with decimal array
let value_i128 = 123;
let decimal_scalar = ScalarValue::Decimal128(Some(value_i128), 25, 3);
let schema = Arc::new(Schema::new(vec![Field::new(
"a",
DataType::Decimal128(25, 3),
true,
)]));
let decimal_array = Arc::new(create_decimal_array(
&[
Some(value_i128),
None,
Some(value_i128 - 1),
Some(value_i128 + 1),
],
25,
3,
)) as ArrayRef;
// array = scalar
apply_logic_op_arr_scalar(
&schema,
&decimal_array,
&decimal_scalar,
Operator::Eq,
&BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
)
.unwrap();
// array != scalar
apply_logic_op_arr_scalar(
&schema,
&decimal_array,
&decimal_scalar,
Operator::NotEq,
&BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
)
.unwrap();
// array < scalar
apply_logic_op_arr_scalar(
&schema,
&decimal_array,
&decimal_scalar,
Operator::Lt,
&BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
)
.unwrap();

// array <= scalar
apply_logic_op_arr_scalar(
&schema,
&decimal_array,
&decimal_scalar,
Operator::LtEq,
&BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
)
.unwrap();
// array > scalar
apply_logic_op_arr_scalar(
&schema,
&decimal_array,
&decimal_scalar,
Operator::Gt,
&BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
)
.unwrap();

// array >= scalar
apply_logic_op_arr_scalar(
&schema,
&decimal_array,
&decimal_scalar,
Operator::GtEq,
&BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
)
.unwrap();

// scalar of different data type with decimal array
let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3);
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
// scalar == array
Expand Down
98 changes: 0 additions & 98 deletions datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
Expand Up @@ -118,67 +118,6 @@ pub(crate) fn is_not_distinct_from_utf8<OffsetSize: OffsetSizeTrait>(
.collect())
}

// TODO move decimal kernels to to arrow-rs
// https://github.com/apache/arrow-rs/issues/1200

/// Creates an BooleanArray the same size as `left`,
/// applying `op` to all non-null elements of left
pub(crate) fn compare_decimal_scalar<F>(
left: &Decimal128Array,
right: i128,
op: F,
) -> Result<BooleanArray>
where
F: Fn(i128, i128) -> bool,
{
Ok(left
.iter()
.map(|left| left.map(|left| op(left, right)))
.collect())
}

pub(crate) fn eq_decimal_scalar(
left: &Decimal128Array,
right: i128,
) -> Result<BooleanArray> {
compare_decimal_scalar(left, right, |left, right| left == right)
}

pub(crate) fn neq_decimal_scalar(
left: &Decimal128Array,
right: i128,
) -> Result<BooleanArray> {
compare_decimal_scalar(left, right, |left, right| left != right)
}

pub(crate) fn lt_decimal_scalar(
left: &Decimal128Array,
right: i128,
) -> Result<BooleanArray> {
compare_decimal_scalar(left, right, |left, right| left < right)
}

pub(crate) fn lt_eq_decimal_scalar(
left: &Decimal128Array,
right: i128,
) -> Result<BooleanArray> {
compare_decimal_scalar(left, right, |left, right| left <= right)
}

pub(crate) fn gt_decimal_scalar(
left: &Decimal128Array,
right: i128,
) -> Result<BooleanArray> {
compare_decimal_scalar(left, right, |left, right| left > right)
}

pub(crate) fn gt_eq_decimal_scalar(
left: &Decimal128Array,
right: i128,
) -> Result<BooleanArray> {
compare_decimal_scalar(left, right, |left, right| left >= right)
}

pub(crate) fn is_distinct_from_decimal(
left: &Decimal128Array,
right: &Decimal128Array,
Expand Down Expand Up @@ -403,43 +342,6 @@ mod tests {
25,
3,
);
// eq: array = i128
let result = eq_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
result
);
// neq: array != i128
let result = neq_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
result
);
// lt: array < i128
let result = lt_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
result
);
// lt_eq: array <= i128
let result = lt_eq_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
result
);
// gt: array > i128
let result = gt_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
result
);
// gt_eq: array >= i128
let result = gt_eq_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
result
);

let left_decimal_array = decimal_array;
let right_decimal_array = create_decimal_array(
&[
Expand Down

0 comments on commit 6b93b8d

Please sign in to comment.