From a6379fe9818de3133a80700b7b8d6d9d59b29e37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Horstmann?= Date: Sat, 16 Jul 2022 16:26:08 +0200 Subject: [PATCH] Support nullable indices in boolean take kernel and some optimizations (#2064) * Use iterator in boolean take kernel and support nullable indices * Improve performance of take_boolean kernel by processing validity and value bits separately * Add test for take_boolean with masked out of bounds indices * Test and fix for when only the indices contain null values --- arrow/src/compute/kernels/take.rs | 134 +++++++++++++++++++----------- 1 file changed, 86 insertions(+), 48 deletions(-) diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index fa907656ae8..4ff7d84d6df 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -599,66 +599,58 @@ where Ok(PrimitiveArray::::from(data)) } -/// `take` implementation for boolean arrays -fn take_boolean( - values: &BooleanArray, +fn take_bits( + values: &Buffer, + values_offset: usize, indices: &PrimitiveArray, -) -> Result +) -> Result where IndexType: ArrowNumericType, IndexType::Native: ToPrimitive, { - let data_len = indices.len(); + let len = indices.len(); + let values_slice = values.as_slice(); + let mut output_buffer = MutableBuffer::new_null(len); + let output_slice = output_buffer.as_slice_mut(); - let num_byte = bit_util::ceil(data_len, 8); - let mut val_buf = MutableBuffer::from_len_zeroed(num_byte); - - let val_slice = val_buf.as_slice_mut(); - - let null_count = values.null_count(); - - let nulls = if null_count == 0 { - (0..data_len).try_for_each::<_, Result<()>>(|i| { - let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; + indices + .iter() + .enumerate() + .try_for_each::<_, Result<()>>(|(i, index)| { + if let Some(index) = index { + let index = ToPrimitive::to_usize(&index).ok_or_else(|| { + ArrowError::ComputeError("Cast to usize failed".to_string()) + })?; - if values.value(index) { - bit_util::set_bit(val_slice, i); + if bit_util::get_bit(values_slice, values_offset + index) { + bit_util::set_bit(output_slice, i); + } } Ok(()) })?; - indices.data_ref().null_buffer().cloned() - } else { - let mut null_buf = MutableBuffer::new(num_byte).with_bitset(num_byte, true); - let null_slice = null_buf.as_slice_mut(); - - (0..data_len).try_for_each::<_, Result<()>>(|i| { - let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - if values.is_null(index) { - bit_util::unset_bit(null_slice, i); - } else if values.value(index) { - bit_util::set_bit(val_slice, i); - } - - Ok(()) - })?; + Ok(output_buffer.into()) +} - match indices.data_ref().null_buffer() { - Some(buffer) => Some(buffer_bin_and( - buffer, - indices.offset(), - &null_buf.into(), - 0, - indices.len(), - )), - None => Some(null_buf.into()), +/// `take` implementation for boolean arrays +fn take_boolean( + values: &BooleanArray, + indices: &PrimitiveArray, +) -> Result +where + IndexType: ArrowNumericType, + IndexType::Native: ToPrimitive, +{ + let val_buf = take_bits(values.values(), values.offset(), indices)?; + let null_buf = match values.data().null_buffer() { + Some(buf) if values.null_count() > 0 => { + Some(take_bits(buf, values.offset(), indices)?) } + _ => indices + .data() + .null_buffer() + .map(|b| b.bit_slice(indices.offset(), indices.len())), }; let data = unsafe { @@ -666,9 +658,9 @@ where DataType::Boolean, indices.len(), None, - nulls, + null_buf, 0, - vec![val_buf.into()], + vec![val_buf], vec![], ) }; @@ -1467,6 +1459,52 @@ mod tests { ); } + #[test] + fn test_take_bool_nullable_index() { + // indices where the masked invalid elements would be out of bounds + let index_data = ArrayData::try_new( + DataType::Int32, + 6, + Some(Buffer::from_iter(vec![ + false, true, false, true, false, true, + ])), + 0, + vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])], + vec![], + ) + .unwrap(); + let index = UInt32Array::from(index_data); + test_take_boolean_arrays( + vec![Some(true), None, Some(false)], + &index, + None, + vec![None, Some(true), None, None, None, Some(false)], + ); + } + + #[test] + fn test_take_bool_nullable_index_nonnull_values() { + // indices where the masked invalid elements would be out of bounds + let index_data = ArrayData::try_new( + DataType::Int32, + 6, + Some(Buffer::from_iter(vec![ + false, true, false, true, false, true, + ])), + 0, + vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])], + vec![], + ) + .unwrap(); + let index = UInt32Array::from(index_data); + test_take_boolean_arrays( + vec![Some(true), Some(true), Some(false)], + &index, + None, + vec![None, Some(true), None, Some(true), None, Some(false)], + ); + } + #[test] fn test_take_bool_with_offset() { let index =