Skip to content

Commit

Permalink
Compare dictionary decimal arrays (#2982)
Browse files Browse the repository at this point in the history
* Compare dictionary decimal arrays

* Use wildcard import
  • Loading branch information
viirya committed Oct 31, 2022
1 parent 3c1f323 commit 99e205f
Showing 1 changed file with 85 additions and 10 deletions.
95 changes: 85 additions & 10 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,7 @@
use crate::array::*;
use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
use crate::compute::util::combine_option_bitmap;
#[allow(unused_imports)]
use crate::datatypes::{
ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, Int32Type,
Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
use crate::datatypes::*;
#[allow(unused_imports)]
use crate::downcast_dictionary_array;
use crate::error::{ArrowError, Result};
Expand Down Expand Up @@ -2388,6 +2379,12 @@ macro_rules! typed_dict_cmp {
(DataType::Float64, DataType::Float64) => {
cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP_FLOAT)
}
(DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => {
cmp_dict::<$KT, Decimal128Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => {
cmp_dict::<$KT, Decimal256Type, _>($LEFT, $RIGHT, $OP)
}
(DataType::Utf8, DataType::Utf8) => {
cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP)
}
Expand Down Expand Up @@ -6660,6 +6657,43 @@ mod tests {
);
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_cmp_dict_decimal128() {
let values = Decimal128Array::from_iter_values([0, 1, 2, 3, 4, 5]);
let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]);
let array1 = DictionaryArray::try_new(&keys, &values).unwrap();

let values = Decimal128Array::from_iter_values([7, -3, 4, 3, 5]);
let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]);
let array2 = DictionaryArray::try_new(&keys, &values).unwrap();

let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(true), Some(true), Some(false)],
);
assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected);

let expected = BooleanArray::from(
vec![Some(true), Some(true), Some(false), Some(false), Some(false), Some(true)],
);
assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected);

let expected = BooleanArray::from(
vec![Some(true), Some(true), Some(false), Some(true), Some(true), Some(true)],
);
assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected);

let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(true), Some(false), Some(false), Some(false)],
);
assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected);

let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(true), Some(true), Some(true), Some(false)],
);
assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected);
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_cmp_dict_non_dict_decimal128() {
Expand Down Expand Up @@ -6696,6 +6730,47 @@ mod tests {
assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected);
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_cmp_dict_decimal256() {
let values = Decimal256Array::from_iter_values(
[0, 1, 2, 3, 4, 5].into_iter().map(i256::from_i128),
);
let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]);
let array1 = DictionaryArray::try_new(&keys, &values).unwrap();

let values = Decimal256Array::from_iter_values(
[7, -3, 4, 3, 5].into_iter().map(i256::from_i128),
);
let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]);
let array2 = DictionaryArray::try_new(&keys, &values).unwrap();

let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(false), Some(true), Some(true), Some(false)],
);
assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected);

let expected = BooleanArray::from(
vec![Some(true), Some(true), Some(false), Some(false), Some(false), Some(true)],
);
assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected);

let expected = BooleanArray::from(
vec![Some(true), Some(true), Some(false), Some(true), Some(true), Some(true)],
);
assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected);

let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(true), Some(false), Some(false), Some(false)],
);
assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected);

let expected = BooleanArray::from(
vec![Some(false), Some(false), Some(true), Some(true), Some(true), Some(false)],
);
assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected);
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_cmp_dict_non_dict_decimal256() {
Expand Down

0 comments on commit 99e205f

Please sign in to comment.