diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index bfb1f64e2eb..181b5f8f047 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -169,7 +169,6 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result Result { - use arrow_schema::DataType::*; let (l, l_s) = lhs.get(); let (r, r_s) = rhs.get(); @@ -186,10 +185,78 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result r_len, false => l_len, }; + Ok(BooleanArray::new( + compare_op_values(op, l, l_s, r, r_s, len)?, + compare_op_nulls(op, l, l_s, r, r_s, len)?, + )) +} - let l_nulls = l.logical_nulls(); - let r_nulls = r.logical_nulls(); +/// get the NullBuffer result of the comparison +fn compare_op_nulls( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, + len: usize, +) -> Result, ArrowError> { + use arrow_schema::DataType::*; + if matches!(op, Op::Distinct | Op::NotDistinct) { + // for [not]Distinct, the result is never null + return Ok(None); + } + let l_t = l.data_type(); + let r_t = r.data_type(); + let l_nulls = l.logical_nulls().filter(|n| n.null_count() > 0); + let r_nulls = r.logical_nulls().filter(|n| n.null_count() > 0); + let nulls = match (l_nulls, l_s, r_nulls, r_s) { + // Either both sides are scalar or neither side is scalar + (Some(l_nulls), true, Some(r_nulls), true) + | (Some(l_nulls), false, Some(r_nulls), false) => { + NullBuffer::union(Some(&l_nulls), Some(&r_nulls)) + } + // Scalar is null, other side is non-scalar and nullable + (Some(_), true, Some(_), false) | (Some(_), false, Some(_), true) => { + Some(NullBuffer::new_null(len)) + } + // Only one side is nullable + (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => match is_scalar { + true => Some(NullBuffer::new_null(len)), + false => Some(nulls), + }, + // Neither side is nullable + (None, _, None, _) => None, + }; + match (l_t, r_t) { + (Struct(_), Struct(_)) => { + // union all nulls from children, because any child in certain slot is null, the struct in the slot is uncomparable + let child_nulls = l + .as_struct() + .columns() + .iter() + .zip(r.as_struct().columns().iter()) + .map(|(l, r)| compare_op_nulls(op, l, l_s, r, r_s, len)) + .collect::, _>>()?; + Ok(child_nulls.iter().fold(nulls, |nulls, child_null| { + NullBuffer::union(nulls.as_ref(), child_null.as_ref()) + })) + } + _ => Ok(nulls), + } +} + +/// Defer computation as may not be necessary +/// get the BooleanBuffer result of the comparison +fn compare_op_values( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, + len: usize, +) -> Result { + use arrow_schema::DataType::*; let l_v = l.as_any_dictionary_opt(); let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); let l_t = l.data_type(); @@ -198,15 +265,16 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result BooleanBuffer { - let d = downcast_primitive_array! { + let l_nulls = l.logical_nulls().filter(|n| n.null_count() > 0); + let r_nulls = r.logical_nulls().filter(|n| n.null_count() > 0); + let values = || -> Result { + let values = downcast_primitive_array! { (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v), (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v), (Utf8, Utf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), @@ -215,46 +283,48 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), (Null, Null) => None, + (Struct(_), Struct(_)) => Some(compare_op_struct_values(op, l, l_s, r, r_s, len)?), _ => unreachable!(), }; - d.unwrap_or_else(|| BooleanBuffer::new_unset(len)) + Ok(values.unwrap_or_else(|| BooleanBuffer::new_unset(len))) }; - - let l_nulls = l_nulls.filter(|n| n.null_count() > 0); - let r_nulls = r_nulls.filter(|n| n.null_count() > 0); Ok(match (l_nulls, l_s, r_nulls, r_s) { - (Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => { + (Some(l_nulls), true, Some(r_nulls), true) + | (Some(l_nulls), false, Some(r_nulls), false) => { // Either both sides are scalar or neither side is scalar match op { Op::Distinct => { - let values = values(); - let l = l.inner().bit_chunks().iter_padded(); - let r = r.inner().bit_chunks().iter_padded(); + let values = values()?; + let l_nulls = l_nulls.inner().bit_chunks().iter_padded(); + let r_nulls = r_nulls.inner().bit_chunks().iter_padded(); let ne = values.bit_chunks().iter_padded(); - let c = |((l, r), n)| ((l ^ r) | (l & r & n)); - let buffer = l.zip(r).zip(ne).map(c).collect(); - BooleanBuffer::new(buffer, 0, len).into() + let c = + |((l_nulls, r_nulls), n)| ((l_nulls ^ r_nulls) | (l_nulls & r_nulls & n)); + let buffer = l_nulls.zip(r_nulls).zip(ne).map(c).collect(); + BooleanBuffer::new(buffer, 0, len) } Op::NotDistinct => { - let values = values(); - let l = l.inner().bit_chunks().iter_padded(); - let r = r.inner().bit_chunks().iter_padded(); + let values = values()?; + let l_nulls = l_nulls.inner().bit_chunks().iter_padded(); + let r_nulls = r_nulls.inner().bit_chunks().iter_padded(); let e = values.bit_chunks().iter_padded(); - let c = |((l, r), e)| u64::not(l | r) | (l & r & e); - let buffer = l.zip(r).zip(e).map(c).collect(); - BooleanBuffer::new(buffer, 0, len).into() + let c = |((l_nulls, r_nulls), e)| { + u64::not(l_nulls | r_nulls) | (l_nulls & r_nulls & e) + }; + let buffer = l_nulls.zip(r_nulls).zip(e).map(c).collect(); + BooleanBuffer::new(buffer, 0, len) } - _ => BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))), + _ => values()?, } } (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => { // Scalar is null, other side is non-scalar and nullable match op { - Op::Distinct => a.into_inner().into(), - Op::NotDistinct => a.into_inner().not().into(), - _ => BooleanArray::new_null(len), + Op::Distinct => a.into_inner(), + Op::NotDistinct => a.into_inner().not(), + _ => BooleanBuffer::new_unset(len), } } (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => { @@ -262,29 +332,75 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result match op { // Scalar is null, other side is not nullable - Op::Distinct => BooleanBuffer::new_set(len).into(), - Op::NotDistinct => BooleanBuffer::new_unset(len).into(), - _ => BooleanArray::new_null(len), + Op::Distinct => BooleanBuffer::new_set(len), + Op::NotDistinct => BooleanBuffer::new_unset(len), + _ => BooleanBuffer::new_unset(len), }, false => match op { Op::Distinct => { - let values = values(); - let l = nulls.inner().bit_chunks().iter_padded(); + let values = values()?; + let l_nulls = nulls.inner().bit_chunks().iter_padded(); let ne = values.bit_chunks().iter_padded(); - let c = |(l, n)| u64::not(l) | n; - let buffer = l.zip(ne).map(c).collect(); - BooleanBuffer::new(buffer, 0, len).into() + let c = |(l_nulls, n)| u64::not(l_nulls) | n; + let buffer = l_nulls.zip(ne).map(c).collect(); + BooleanBuffer::new(buffer, 0, len) } - Op::NotDistinct => (nulls.inner() & &values()).into(), - _ => BooleanArray::new(values(), Some(nulls)), + Op::NotDistinct => nulls.inner() & &values()?, + _ => values()?, }, } } // Neither side is nullable - (None, _, None, _) => BooleanArray::new(values(), None), + (None, _, None, _) => values()?, }) } +/// recursively compare fields of struct arrays +fn compare_op_struct_values( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, + len: usize, +) -> Result { + // when one of field is not equal(notdistinct), the result is false for equal(notdistinct) + // so we use neg to reverse the result of equal when handle not equal + let neg = match op { + Op::Equal | Op::NotDistinct => false, + Op::NotEqual | Op::Distinct => true, + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid comparison operation: Struct {op} Struct" + ))) + } + }; + + let op = match op { + Op::NotEqual => Op::Equal, + Op::Distinct => Op::NotDistinct, + _ => op, + }; + + let l = l.as_struct(); + let r = r.as_struct(); + + // compare each field of struct + let child_values = l + .columns() + .iter() + .zip(r.columns().iter()) + .map(|(col_l, col_r)| compare_op_values(op, col_l, l_s, col_r, r_s, len)) + .collect::, ArrowError>>()?; + // combine the result of each field + let equality = child_values + .iter() + .fold(BooleanBuffer::new_set(len), |values, child_value| { + &values & child_value + }); + Ok(if neg { !&equality } else { equality }) +} + /// Perform a potentially vectored `op` on the provided `ArrayOrd` fn apply( op: Op, @@ -544,7 +660,9 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray { mod tests { use std::sync::Arc; - use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray}; + use arrow_array::{ArrayRef, DictionaryArray, Int32Array, Scalar, StringArray, StructArray}; + use arrow_buffer::Buffer; + use arrow_schema::{DataType, Field}; use super::*; @@ -702,4 +820,140 @@ mod tests { neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap(); } + + #[test] + fn test_struct_uncomparable() { + // test struct('a') == struct('a','b') + let left_a = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, false, true, false].into()), + )); + let right_a = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, false, true, false].into()), + )); + let right_b = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, true, true, false].into()), + )); + let field_a = Arc::new(Field::new("a", DataType::Int32, true)); + let field_b = Arc::new(Field::new("b", DataType::Int32, true)); + let left = StructArray::from(vec![(field_a.clone(), left_a.clone() as ArrayRef)]); + let right = StructArray::from(vec![ + (field_a.clone(), right_a.clone() as ArrayRef), + (field_b.clone(), right_b.clone() as ArrayRef), + ]); + assert_eq!(eq(&left, &right).unwrap_err().to_string(), "Invalid argument error: Invalid comparison operation: Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) == Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"b\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])"); + + // test struct('a') <= struct('a') + assert_eq!( + lt(&left, &left).unwrap_err().to_string(), + "Invalid argument error: Invalid comparison operation: Struct < Struct" + ); + } + + #[test] + fn test_struct_compare() { + let left_a = Arc::new(Int32Array::new( + vec![0, 1, 2, 3, 4, 5, 6, 7].into(), + Some(vec![true, false, true, true, false, true, true, false].into()), + )); + let right_a = Arc::new(Int32Array::new( + vec![0, 1, 2, 3, 4, 5, 6, 72].into(), + Some(vec![true, false, true, true, false, true, true, false].into()), + )); + let left_b = Arc::new(Int32Array::new( + vec![0, 1, 2, 3, 4, 5, 7, 7].into(), + Some(vec![true, true, true, true, true, true, true, true].into()), + )); + let right_b = Arc::new(Int32Array::new( + vec![0, 1, 20, 13, 72, 6, 6, 7].into(), + Some(vec![true, true, true, true, true, true, false, true].into()), + )); + let field_a = Arc::new(Field::new("a", DataType::Int32, true)); + let field_b = Arc::new(Field::new("b", DataType::Int32, true)); + // left [{a: 0, b: 0}, {a: NULL, b: 1}, {a: 2, b: 2}, NULL({a: 3, b: 3}), {a: NULL, b: 4}, NULL({a: 5, b: 5}), {a:6, b: 7}, {a: NULL, b: 7}] + let left_struct = StructArray::from(( + vec![ + (field_a.clone(), left_a.clone() as ArrayRef), + (field_b.clone(), left_b.clone() as ArrayRef), + ], + Buffer::from([0b11010111]), + )); + // right [{a: 0, b: 0}, {a: NULL, b: 1}, {a: 2, b: 20}, Null({a: 3, b: 13}), {a: NULL, b: 72}, Null({a: 5, b: 6}), {a:6, b: Null}, {a: NULL, b: 7}] + let right_struct = StructArray::from(( + vec![ + (field_a.clone(), right_a.clone() as ArrayRef), + (field_b.clone(), right_b.clone() as ArrayRef), + ], + Buffer::from([0b11010111]), + )); + let expected = BooleanArray::new( + vec![true, true, false, false, false, false, false, false].into(), + Some(vec![true, false, true, false, false, false, false, false].into()), + ); + assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected); + let expected = BooleanArray::new( + vec![false, false, true, true, true, true, true, true].into(), + Some(vec![true, false, true, false, false, false, false, false].into()), + ); + assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected); + let expected = BooleanArray::new( + vec![false, false, true, false, true, false, true, false].into(), + None, + ); + assert_eq!(distinct(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(distinct(&right_struct, &left_struct).unwrap(), expected); + let expected = BooleanArray::new( + vec![true, true, false, true, false, true, false, true].into(), + None, + ); + assert_eq!(not_distinct(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(not_distinct(&right_struct, &left_struct).unwrap(), expected); + + let sub_struct_fields = left_struct.fields().clone(); + + let left_struct = StructArray::from(( + vec![ + (field_a.clone(), left_a.clone() as ArrayRef), + ( + Arc::new(Field::new( + "SubStruct", + DataType::Struct(sub_struct_fields.clone()), + true, + )), + Arc::new(left_struct) as ArrayRef, + ), + ], + Buffer::from([0b11010111]), + )); + let right_struct = StructArray::from(( + vec![ + (field_a.clone(), right_a.clone() as ArrayRef), + ( + Arc::new(Field::new( + "SubStruct", + DataType::Struct(sub_struct_fields.clone()), + true, + )), + Arc::new(right_struct) as ArrayRef, + ), + ], + Buffer::from([0b11010111]), + )); + let expected = BooleanArray::new( + vec![true, true, false, false, false, false, false, false].into(), + Some(vec![true, false, true, false, false, false, false, false].into()), + ); + assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected); + let expected = BooleanArray::new( + vec![false, false, true, true, true, true, true, true].into(), + Some(vec![true, false, true, false, false, false, false, false].into()), + ); + assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected); + } }