Skip to content

Commit

Permalink
Use ArrowNativeTypeOp instead of total_cmp directly (#3087)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 13, 2022
1 parent c7210ce commit 3084ee2
Showing 1 changed file with 32 additions and 80 deletions.
112 changes: 32 additions & 80 deletions arrow/src/compute/kernels/comparison.rs
Expand Up @@ -2748,30 +2748,22 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _)
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_dict_compares!(
left,
right,
|a, b| a == b,
|a, b| a.total_cmp(&b).is_eq(),
|a, b| a == b
)
typed_dict_compares!(left, right, |a, b| a == b, |a, b| a.is_eq(b), |a, b| a
== b)
}
DataType::Dictionary(_, _)
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_cmp_dict_non_dict!(left, right, |a, b| a == b, |a, b| a == b, |a, b| a
.total_cmp(&b)
.is_eq())
.is_eq(b))
}
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b, |a, b| a
.total_cmp(&b)
.is_eq())
typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b, |a, b| b
.is_eq(a))
}
_ => {
typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b, |a, b| a
.total_cmp(&b)
.is_eq())
.is_eq(b))
}
}
}
Expand Down Expand Up @@ -2801,30 +2793,22 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _)
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_dict_compares!(
left,
right,
|a, b| a != b,
|a, b| a.total_cmp(&b).is_ne(),
|a, b| a != b
)
typed_dict_compares!(left, right, |a, b| a != b, |a, b| a.is_ne(b), |a, b| a
!= b)
}
DataType::Dictionary(_, _)
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_cmp_dict_non_dict!(left, right, |a, b| a != b, |a, b| a != b, |a, b| a
.total_cmp(&b)
.is_ne())
.is_ne(b))
}
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b, |a, b| a
.total_cmp(&b)
.is_ne())
typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b, |a, b| b
.is_ne(a))
}
_ => {
typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b, |a, b| a
.total_cmp(&b)
.is_ne())
.is_ne(b))
}
}
}
Expand Down Expand Up @@ -2854,30 +2838,22 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _)
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_dict_compares!(
left,
right,
|a, b| a < b,
|a, b| a.total_cmp(&b).is_lt(),
|a, b| a < b
)
typed_dict_compares!(left, right, |a, b| a < b, |a, b| a.is_lt(b), |a, b| a
< b)
}
DataType::Dictionary(_, _)
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_cmp_dict_non_dict!(left, right, |a, b| a < b, |a, b| a < b, |a, b| a
.total_cmp(&b)
.is_lt())
.is_lt(b))
}
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
typed_cmp_dict_non_dict!(right, left, |a, b| a > b, |a, b| a > b, |a, b| b
.total_cmp(&a)
.is_lt())
.is_lt(a))
}
_ => {
typed_compares!(left, right, |a, b| ((!a) & b), |a, b| a < b, |a, b| a
.total_cmp(&b)
.is_lt())
.is_lt(b))
}
}
}
Expand Down Expand Up @@ -2906,30 +2882,22 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _)
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_dict_compares!(
left,
right,
|a, b| a <= b,
|a, b| a.total_cmp(&b).is_le(),
|a, b| a <= b
)
typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a.is_le(b), |a, b| a
<= b)
}
DataType::Dictionary(_, _)
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_cmp_dict_non_dict!(left, right, |a, b| a <= b, |a, b| a <= b, |a, b| a
.total_cmp(&b)
.is_le())
.is_le(b))
}
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
typed_cmp_dict_non_dict!(right, left, |a, b| a >= b, |a, b| a >= b, |a, b| b
.total_cmp(&a)
.is_le())
.is_le(a))
}
_ => {
typed_compares!(left, right, |a, b| !(a & (!b)), |a, b| a <= b, |a, b| a
.total_cmp(&b)
.is_le())
.is_le(b))
}
}
}
Expand Down Expand Up @@ -2958,30 +2926,22 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _)
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_dict_compares!(
left,
right,
|a, b| a > b,
|a, b| a.total_cmp(&b).is_gt(),
|a, b| a > b
)
typed_dict_compares!(left, right, |a, b| a > b, |a, b| a.is_gt(b), |a, b| a
> b)
}
DataType::Dictionary(_, _)
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_cmp_dict_non_dict!(left, right, |a, b| a > b, |a, b| a > b, |a, b| a
.total_cmp(&b)
.is_gt())
.is_gt(b))
}
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
typed_cmp_dict_non_dict!(right, left, |a, b| a < b, |a, b| a < b, |a, b| b
.total_cmp(&a)
.is_gt())
.is_gt(a))
}
_ => {
typed_compares!(left, right, |a, b| (a & (!b)), |a, b| a > b, |a, b| a
.total_cmp(&b)
.is_gt())
.is_gt(b))
}
}
}
Expand Down Expand Up @@ -3009,30 +2969,22 @@ pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
DataType::Dictionary(_, _)
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_dict_compares!(
left,
right,
|a, b| a >= b,
|a, b| a.total_cmp(&b).is_ge(),
|a, b| a >= b
)
typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a.is_ge(b), |a, b| a
>= b)
}
DataType::Dictionary(_, _)
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
{
typed_cmp_dict_non_dict!(left, right, |a, b| a >= b, |a, b| a >= b, |a, b| a
.total_cmp(&b)
.is_ge())
.is_ge(b))
}
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
typed_cmp_dict_non_dict!(right, left, |a, b| a <= b, |a, b| a <= b, |a, b| b
.total_cmp(&a)
.is_ge())
.is_ge(a))
}
_ => {
typed_compares!(left, right, |a, b| !((!a) & b), |a, b| a >= b, |a, b| a
.total_cmp(&b)
.is_ge())
.is_ge(b))
}
}
}
Expand Down

0 comments on commit 3084ee2

Please sign in to comment.