Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ord): Support equality of StructArray #5217

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
285 changes: 238 additions & 47 deletions arrow-ord/src/cmp.rs
Expand Up @@ -27,7 +27,7 @@ use arrow_array::cast::AsArray;
use arrow_array::types::ByteArrayType;
use arrow_array::{
downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum,
FixedSizeBinaryArray, GenericByteArray,
FixedSizeBinaryArray, GenericByteArray, StructArray,
};
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
Expand Down Expand Up @@ -169,12 +169,14 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, Ar
/// Perform `op` on the provided `Datum`
#[inline(never)]
fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
use arrow_schema::DataType::*;
let (l, l_s) = lhs.get();
let (r, r_s) = rhs.get();
let (l_array, l_s) = lhs.get();
my-vegetable-has-exploded marked this conversation as resolved.
Show resolved Hide resolved
let (r_array, r_s) = rhs.get();

let l_nulls = l_array.logical_nulls();
let r_nulls = r_array.logical_nulls();

let l_len = l.len();
let r_len = r.len();
let l_len = l_array.len();
let r_len = r_array.len();

if l_len != r_len && !l_s && !r_s {
return Err(ArrowError::InvalidArgumentError(format!(
Expand All @@ -187,47 +189,14 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
false => l_len,
};

let l_nulls = l.logical_nulls();
let r_nulls = r.logical_nulls();

let l_v = l.as_any_dictionary_opt();
let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
let l_t = l.data_type();

let r_v = r.as_any_dictionary_opt();
let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
let r_t = r.data_type();

if l_t != r_t || l_t.is_nested() {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid comparison operation: {l_t} {op} {r_t}"
)));
}

// Defer computation as may not be necessary
let values = || -> BooleanBuffer {
let d = downcast_primitive_array! {
(l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
(Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
(Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
(LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
(Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
(LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
(FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
(Null, Null) => None,
_ => unreachable!(),
};
d.unwrap_or_else(|| BooleanBuffer::new_unset(len))
};

let l_nulls = l_nulls.filter(|n| n.null_count() > 0);
let r_nulls = r_nulls.filter(|n| n.null_count() > 0);
Ok(match (l_nulls, l_s, r_nulls, r_s) {
(Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => {
// Either both sides are scalar or neither side is scalar
match op {
Op::Distinct => {
let values = values();
let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?;
let l = l.inner().bit_chunks().iter_padded();
let r = r.inner().bit_chunks().iter_padded();
let ne = values.bit_chunks().iter_padded();
Expand All @@ -237,7 +206,7 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
BooleanBuffer::new(buffer, 0, len).into()
}
Op::NotDistinct => {
let values = values();
let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?;
let l = l.inner().bit_chunks().iter_padded();
let r = r.inner().bit_chunks().iter_padded();
let e = values.bit_chunks().iter_padded();
Expand All @@ -246,7 +215,10 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
let buffer = l.zip(r).zip(e).map(c).collect();
BooleanBuffer::new(buffer, 0, len).into()
}
_ => BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))),
_ => BooleanArray::new(
compare_op_values(op, l_array, l_s, r_array, r_s, len)?,
NullBuffer::union(Some(&l), Some(&r)),
),
}
}
(Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => {
Expand All @@ -268,23 +240,122 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
},
false => match op {
Op::Distinct => {
let values = values();
let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?;
let l = nulls.inner().bit_chunks().iter_padded();
let ne = values.bit_chunks().iter_padded();
let c = |(l, n)| u64::not(l) | n;
let buffer = l.zip(ne).map(c).collect();
BooleanBuffer::new(buffer, 0, len).into()
}
Op::NotDistinct => (nulls.inner() & &values()).into(),
_ => BooleanArray::new(values(), Some(nulls)),
Op::NotDistinct => (nulls.inner()
& &compare_op_values(op, l_array, l_s, r_array, r_s, len)?)
.into(),
_ => BooleanArray::new(
compare_op_values(op, l_array, l_s, r_array, r_s, len)?,
Some(nulls),
),
},
}
}
// Neither side is nullable
(None, _, None, _) => BooleanArray::new(values(), None),
(None, _, None, _) => BooleanArray::new(
compare_op_values(op, l_array, l_s, r_array, r_s, len)?,
None,
),
})
}

/// Computes the value bits of `l op r` as a [`BooleanBuffer`] of `len` bits.
///
/// Only values are compared here: no null buffer is consulted, so callers are
/// responsible for combining the returned bits with the operands' validity.
///
/// Dictionary arrays are unwrapped to their values before the type check; the
/// dictionary handles are forwarded to `apply` for the non-nested arms.
///
/// * `l_s` / `r_s` - whether the left / right operand is a scalar
/// * `len` - number of bits in the returned buffer
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when
/// * the operand types differ,
/// * the operands are nested and are not structs compared with
///   `Op::Equal` / `Op::NotEqual`, or
/// * a nested operand is dictionary-encoded.
fn compare_op_values(
    op: Op,
    l: &dyn Array,
    l_s: bool,
    r: &dyn Array,
    r_s: bool,
    len: usize,
) -> Result<BooleanBuffer, ArrowError> {
    use arrow_schema::DataType::*;
    let l_v = l.as_any_dictionary_opt();
    let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
    let l_t = l.data_type();

    let r_v = r.as_any_dictionary_opt();
    let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
    let r_t = r.data_type();

    // Single place that builds the "unsupported comparison" error.
    let invalid = || {
        ArrowError::InvalidArgumentError(format!(
            "Invalid comparison operation: {l_t} {op} {r_t}"
        ))
    };

    if l_t.is_nested() {
        // The struct arm below hands only the unwrapped values to
        // `compare_op_struct_values`, so dictionary keys would be ignored and
        // the child columns would not line up with `len`. Reject instead of
        // producing a wrong result.
        if l_v.is_some() || r_v.is_some() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Invalid comparison operation: dictionary-encoded nested types are not supported: {l_t} {op} {r_t}"
            )));
        }
        // Nested types are only supported for struct (in)equality.
        let supported = matches!((l_t, op), (Struct(_), Op::Equal | Op::NotEqual));
        if !supported || !l_t.equals_datatype(r_t) {
            return Err(invalid());
        }
    } else if r_t != l_t {
        return Err(invalid());
    }

    let d = downcast_primitive_array! {
        (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
        (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
        (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
        (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
        (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
        (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
        (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
        (Null, Null) => None,
        (Struct(_), Struct(_)) => Some(compare_op_struct_values(op, l, l_s, r, r_s, len)?),
        // The type checks above reject every other combination.
        _ => unreachable!(),
    };
    // `None` (e.g. the `Null` arm) is materialised as `len` unset bits.
    Ok(d.unwrap_or_else(|| BooleanBuffer::new_unset(len)))
}

/// Compares two struct arrays for (in)equality, column by column.
///
/// Each pair of child columns is compared with `compare_op_values` using
/// `Op::Equal`, which recurses into this function again for nested structs.
/// The per-column results are AND-ed together, and the combined buffer is
/// negated when `op` is `Op::NotEqual`.
///
/// * `l` / `r` - the operands; both are downcast to [`StructArray`]
/// * `l_s` / `r_s` - scalar flags, forwarded unchanged to every column comparison
/// * `len` - number of bits in the returned buffer
///
/// # Panics
///
/// Panics if `op` is neither `Op::Equal` nor `Op::NotEqual`, or if either
/// operand is not a [`StructArray`].
///
/// NOTE(review): this function never reads the null buffers of the child
/// columns; confirm that this gives the intended null semantics.
fn compare_op_struct_values(
op: Op,
l: &dyn Array,
l_s: bool,
r: &dyn Array,
r_s: bool,
len: usize,
) -> Result<BooleanBuffer, ArrowError> {
// Two structs are equal only when every field is equal, so a single
// differing field makes them not equal. `NotEqual` is therefore computed
// as the negation of the combined `Equal` result, tracked by `neg`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just pass the operator into compare_op_values?

let neg = match op {
my-vegetable-has-exploded marked this conversation as resolved.
Show resolved Hide resolved
Op::Equal => false,
Op::NotEqual => true,
_ => unreachable!(),
};

// The downcasts panic on any other array type.
let l = l.as_any().downcast_ref::<StructArray>().unwrap();
my-vegetable-has-exploded marked this conversation as resolved.
Show resolved Hide resolved
let r = r.as_any().downcast_ref::<StructArray>().unwrap();

// NOTE(review): the capacity is the row count `len`, although one buffer is
// pushed per column; the column count looks like the intended size.
let mut child_res: Vec<BooleanBuffer> = Vec::with_capacity(len);
// Compare the columns pairwise, always with `Op::Equal`; the first error
// from a column comparison is returned immediately.
// NOTE(review): `.to_vec()` copies both column lists before iterating.
for item in l
.columns()
.to_vec()
.iter()
.zip(r.columns().to_vec().iter())
.map(|(col_l, col_r)| compare_op_values(Op::Equal, col_l, l_s, col_r, r_s, len))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this will correctly handle the null masks for a Distinct?

{
child_res.push(item?);
}
// AND the per-column results together, starting from an all-set buffer, so
// a row is equal only if every column is equal (and a struct with no
// columns compares equal everywhere).
let equality = child_res
.iter()
.fold(BooleanBuffer::new_set(len), |acc, x| &acc & x);
Ok(if neg { !&equality } else { equality })
}

/// Perform a potentially vectored `op` on the provided `ArrayOrd`
fn apply<T: ArrayOrd>(
op: Op,
Expand Down Expand Up @@ -544,7 +615,9 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
mod tests {
use std::sync::Arc;

use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray};
use arrow_array::{ArrayRef, DictionaryArray, Int32Array, Scalar, StringArray, StructArray};
use arrow_buffer::Buffer;
use arrow_schema::{DataType, Field};

use super::*;

Expand Down Expand Up @@ -702,4 +775,122 @@ mod tests {

neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap();
}

// Exercises `eq` / `neq` on struct arrays in three scenarios: identical
// structs, structs differing in one row of field `a`, and structs that
// contain a nested struct field.
#[test]
fn test_struct_equality() {
// Case 1: struct('a', 'b') = struct('a', 'b') with identical children on
// both sides. The struct null buffer 0b0111 marks rows 0..=2 valid and
// row 3 null.
let left_a = Arc::new(Int32Array::new(
vec![0, 1, 2, 3].into(),
Some(vec![true, false, true, false].into()),
));
let right_a = Arc::new(Int32Array::new(
vec![0, 1, 2, 3].into(),
my-vegetable-has-exploded marked this conversation as resolved.
Show resolved Hide resolved
Some(vec![true, false, true, false].into()),
));
let left_b = Arc::new(Int32Array::new(
vec![0, 1, 2, 3].into(),
Some(vec![true, true, true, false].into()),
));
let right_b = Arc::new(Int32Array::new(
vec![0, 1, 2, 3].into(),
Some(vec![true, true, true, false].into()),
));
let field_a = Arc::new(Field::new("a", DataType::Int32, true));
let field_b = Arc::new(Field::new("b", DataType::Int32, true));
let left_struct = StructArray::from((
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let left_struct = StructArray::from((
// [{a: 0, b: 0}, {a: NULL, b: 1}, {a: 2, b: 20}, {a: 3, b: 3}]
let left_struct = StructArray::from((

vec![
(field_a.clone(), left_a.clone() as ArrayRef),
(field_b.clone(), left_b.clone() as ArrayRef),
],
Buffer::from([0b0111]),
));
let right_struct = StructArray::from((
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let right_struct = StructArray::from((
// right [{a: 0, b: 0}, {a: NULL, b: 1}, {a: 2, b: 2}, {a: 3, b: 3} ]
let right_struct = StructArray::from((

vec![
(field_a.clone(), right_a.clone() as ArrayRef),
(field_b.clone(), right_b.clone() as ArrayRef),
],
Buffer::from([0b0111]),
));
// Every row is expected to compare equal; row 3 is null in the result.
// Field `a` is null at row 1 on both sides, and the expected value bit for
// that row is still `true`.
let expected = BooleanArray::new(
vec![true, true, true, true].into(),
Some(vec![true, true, true, false].into()),
);
assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected);
// `neq` is expected to be the exact negation, with the same validity.
let expected = BooleanArray::new(
vec![false, false, false, false].into(),
Some(vec![true, true, true, false].into()),
);
assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected);

// Kept for the nested case below, which moves `left_struct`.
let sub_struct_fields = left_struct.fields().clone();

// Case 2: struct('a', 'b') = struct('a', 'b') where right a[1] (value 2)
// differs from left a[1] (value 1); the null buffer is 0b0111.
let right_a2 = Arc::new(Int32Array::new(
vec![0, 2, 2, 3].into(),
Some(vec![true, true, true, false].into()),
));
let right_struct = StructArray::from((
vec![
(field_a.clone(), right_a2.clone() as ArrayRef),
(field_b.clone(), right_b.clone() as ArrayRef),
],
Buffer::from([0b0111]),
));
// Only row 1 is expected to differ.
let expected = BooleanArray::new(
vec![true, false, true, true].into(),
Some(vec![true, true, true, false].into()),
);
assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected);
let expected = BooleanArray::new(
vec![false, true, false, false].into(),
Some(vec![true, true, true, false].into()),
);
assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected);

// Case 3: struct('a', struct('a', 'b')) = struct('a', struct('a', 'b')).
// The outer `a` columns are identical; the nested structs are the two
// structs from case 2, so they differ at row 1 of their field `a`. The
// outer null buffer is 0b0111.
my-vegetable-has-exploded marked this conversation as resolved.
Show resolved Hide resolved
let left_struct = StructArray::from((
vec![
(field_a.clone(), left_a.clone() as ArrayRef),
(
Arc::new(Field::new(
"SubStruct",
DataType::Struct(sub_struct_fields.clone()),
true,
)),
Arc::new(left_struct) as ArrayRef,
),
],
Buffer::from([0b0111]),
));
let right_struct = StructArray::from((
vec![
(field_a.clone(), right_a.clone() as ArrayRef),
(
Arc::new(Field::new(
"SubStruct",
DataType::Struct(sub_struct_fields.clone()),
true,
)),
Arc::new(right_struct) as ArrayRef,
),
],
Buffer::from([0b0111]),
));
// The difference inside the nested struct is expected to surface at row 1
// of the outer comparison.
let expected = BooleanArray::new(
vec![true, false, true, true].into(),
Some(vec![true, true, true, false].into()),
);
assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected);
let expected = BooleanArray::new(
vec![false, true, false, false].into(),
Some(vec![true, true, true, false].into()),
);
assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected);
}
}