diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index c38107b2587..17b51cd3b7b 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -23,6 +23,7 @@ use crate::datatypes::{ UnionMode, }; use crate::error::{ArrowError, Result}; +use crate::util::bit_iterator::BitSliceIterator; use crate::{bitmap::Bitmap, datatypes::ArrowNativeType}; use crate::{ buffer::{Buffer, MutableBuffer}, @@ -37,6 +38,21 @@ use std::sync::Arc; use super::equal::equal; +#[inline] +pub(crate) fn contains_nulls( + null_bit_buffer: Option<&Buffer>, + offset: usize, + len: usize, +) -> bool { + match null_bit_buffer { + Some(buffer) => match BitSliceIterator::new(buffer, offset, len).next() { + Some((start, end)) => start != 0 || end != len, + None => len != 0, // No non-null values + }, + None => false, // No null buffer + } +} + #[inline] pub(crate) fn count_nulls( null_bit_buffer: Option<&Buffer>, @@ -2865,4 +2881,15 @@ mod tests { let err = data.validate_values().unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Offset invariant failure: offset at position 1 out of bounds: 3 > 2"); } + + #[test] + fn test_contains_nulls() { + let buffer: Buffer = + MutableBuffer::from_iter([false, false, false, true, true, false]).into(); + + assert!(contains_nulls(Some(&buffer), 0, 6)); + assert!(contains_nulls(Some(&buffer), 0, 3)); + assert!(!contains_nulls(Some(&buffer), 3, 2)); + assert!(!contains_nulls(Some(&buffer), 0, 0)); + } } diff --git a/arrow/src/array/equal/boolean.rs b/arrow/src/array/equal/boolean.rs index 1a7179fa858..fddf21b963a 100644 --- a/arrow/src/array/equal/boolean.rs +++ b/arrow/src/array/equal/boolean.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::count_nulls, ArrayData}; +use crate::array::{data::contains_nulls, ArrayData}; +use crate::util::bit_iterator::BitIndexIterator; use crate::util::bit_util::get_bit; use super::utils::{equal_bits, equal_len}; +/// Returns true if the value data for the arrays is equal, assuming the null masks have +/// already been checked for equality pub(super) fn boolean_equal( lhs: &ArrayData, rhs: &ArrayData, @@ -30,10 +33,9 @@ pub(super) fn boolean_equal( let lhs_values = lhs.buffers()[0].as_slice(); let rhs_values = rhs.buffers()[0].as_slice(); - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + let contains_nulls = contains_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - if lhs_null_count == 0 && rhs_null_count == 0 { + if !contains_nulls { // Optimize performance for starting offset at u8 boundary. if lhs_start % 8 == 0 && rhs_start % 8 == 0 @@ -75,20 +77,14 @@ pub(super) fn boolean_equal( } else { // get a ref of the null buffer bytes, to use in testing for nullness let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); let lhs_start = lhs.offset() + lhs_start; let rhs_start = rhs.offset() + rhs_start; - (0..len).all(|i| { + BitIndexIterator::new(lhs_null_bytes, lhs_start, len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos); - - lhs_is_null - || (lhs_is_null == rhs_is_null) - && equal_bits(lhs_values, rhs_values, lhs_pos, rhs_pos, 1) + get_bit(lhs_values, lhs_pos) == get_bit(rhs_values, rhs_pos) }) } } @@ -109,4 +105,13 @@ mod tests { let slice = array.slice(8, 24); assert_eq!(slice.data(), slice.data()); } + + #[test] + fn test_sliced_nullable_boolean_array() { + let a = BooleanArray::from(vec![None; 32]); + let b = BooleanArray::from(vec![true; 32]); + let slice_a = a.slice(1, 12); + let slice_b = b.slice(1, 12); + assert_ne!(slice_a.data(), slice_b.data()); + } } diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs index fed3933a089..449055d366e 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow/src/array/equal/utils.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::count_nulls, ArrayData}; +use crate::array::data::contains_nulls; +use crate::array::ArrayData; use crate::datatypes::DataType; -use crate::util::bit_util; +use crate::util::bit_chunk_iterator::BitChunks; // whether bits along the positions are equal // `lhs_start`, `rhs_start` and `len` are _measured in bits_. @@ -29,10 +30,16 @@ pub(super) fn equal_bits( rhs_start: usize, len: usize, ) -> bool { - (0..len).all(|i| { - bit_util::get_bit(lhs_values, lhs_start + i) - == bit_util::get_bit(rhs_values, rhs_start + i) - }) + let lhs = BitChunks::new(lhs_values, lhs_start, len); + let rhs = BitChunks::new(rhs_values, rhs_start, len); + + for (a, b) in lhs.iter().zip(rhs.iter()) { + if a != b { + return false; + } + } + + lhs.remainder_bits() == rhs.remainder_bits() } #[inline] @@ -43,25 +50,16 @@ pub(super) fn equal_nulls( rhs_start: usize, len: usize, ) -> bool { - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + let lhs_offset = lhs_start + lhs.offset(); + let rhs_offset = rhs_start + rhs.offset(); - if lhs_null_count != rhs_null_count { - return false; - } - - if lhs_null_count > 0 || rhs_null_count > 0 { - let lhs_values = lhs.null_buffer().unwrap().as_slice(); - let rhs_values = rhs.null_buffer().unwrap().as_slice(); - equal_bits( - lhs_values, - rhs_values, - lhs_start + lhs.offset(), - rhs_start + rhs.offset(), - len, - ) - } else { - true + match (lhs.null_buffer(), rhs.null_buffer()) { + (Some(lhs), Some(rhs)) => { + equal_bits(lhs.as_slice(), rhs.as_slice(), lhs_offset, rhs_offset, len) + } + (Some(lhs), None) => !contains_nulls(Some(lhs), lhs_offset, len), + (None, Some(rhs)) => !contains_nulls(Some(rhs), rhs_offset, len), + (None, None) => true, } }