From b28ec07ba00cd1921e91b946ca00adcbb974f5ba Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 11 Dec 2021 08:12:21 -0500 Subject: [PATCH 1/3] Mark `MutableBuffer::typed_data_mut` unsafe --- arrow/src/buffer/mutable.rs | 21 +++++++++++++-------- arrow/src/buffer/ops.rs | 5 ++++- arrow/src/compute/kernels/sort.rs | 6 ++++-- arrow/src/compute/kernels/take.rs | 3 ++- parquet/src/arrow/array_reader.rs | 5 ++++- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs index 61593af704d..7beada96528 100644 --- a/arrow/src/buffer/mutable.rs +++ b/arrow/src/buffer/mutable.rs @@ -273,15 +273,20 @@ impl MutableBuffer { } /// View this buffer asa slice of a specific type. + /// /// # Safety - /// This function must only be used when this buffer was extended with items of type `T`. - /// Failure to do so results in undefined behavior. - pub fn typed_data_mut(&mut self) -> &mut [T] { - unsafe { - let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::(); - assert!(prefix.is_empty() && suffix.is_empty()); - offsets - } + /// + /// This function must only be used with buffers which are treated + /// as type `T` (e.g. extended with items of type `T`). + /// + /// # Panics + /// + /// This function panics if the underlying buffer is not aligned + /// correctly for type `T`. + pub unsafe fn typed_data_mut(&mut self) -> &mut [T] { + let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::(); + assert!(prefix.is_empty() && suffix.is_empty()); + offsets } /// Extends this buffer from a slice of items that can be represented in bytes, increasing its capacity if needed. 
diff --git a/arrow/src/buffer/ops.rs b/arrow/src/buffer/ops.rs index c37fd14bd22..14d381199bd 100644 --- a/arrow/src/buffer/ops.rs +++ b/arrow/src/buffer/ops.rs @@ -168,7 +168,10 @@ where MutableBuffer::new(ceil(len_in_bits, 8)).with_bitset(len_in_bits / 64 * 8, false); let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits); - let result_chunks = result.typed_data_mut::().iter_mut(); + + // Safety: buffer is always treated as type `u64` in the code + // below. + let result_chunks = unsafe { result.typed_data_mut::().iter_mut() }; result_chunks .zip(left_chunks.iter()) diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 6a72224979c..1046853fbb2 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -471,7 +471,8 @@ fn sort_boolean( let mut result = MutableBuffer::new(result_capacity); // sets len to capacity so we can access the whole buffer as a typed slice result.resize(result_capacity, 0); - let result_slice: &mut [u32] = result.typed_data_mut(); + // Safety: the buffer is always treated as `u32` in the code below + let result_slice: &mut [u32] = unsafe { result.typed_data_mut() }; if options.nulls_first { let size = nulls_len.min(len); @@ -559,7 +560,8 @@ where let mut result = MutableBuffer::new(result_capacity); // sets len to capacity so we can access the whole buffer as a typed slice result.resize(result_capacity, 0); - let result_slice: &mut [u32] = result.typed_data_mut(); + // Safety: the buffer is always treated as `u32` in the code below + let result_slice: &mut [u32] = unsafe { result.typed_data_mut() }; if options.nulls_first { let size = nulls_len.min(len); diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 9fe00ea9a7b..63df3aba83d 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -632,7 +632,8 @@ where let bytes_offset = (data_len + 1) * std::mem::size_of::(); let mut offsets_buffer = 
MutableBuffer::from_len_zeroed(bytes_offset); - let offsets = offsets_buffer.typed_data_mut(); + // Safety: the buffer is always treated as a type of `OffsetSize` in the code below + let offsets = unsafe { offsets_buffer.typed_data_mut() }; let mut values = MutableBuffer::new(0); let mut length_so_far = OffsetSize::zero(); offsets[0] = length_so_far; diff --git a/parquet/src/arrow/array_reader.rs b/parquet/src/arrow/array_reader.rs index ae001ed7339..a11befe5416 100644 --- a/parquet/src/arrow/array_reader.rs +++ b/parquet/src/arrow/array_reader.rs @@ -1154,7 +1154,10 @@ impl ArrayReader for StructArrayReader { let mut def_level_data_buffer = MutableBuffer::new(buffer_size); def_level_data_buffer.resize(buffer_size, 0); - let def_level_data = def_level_data_buffer.typed_data_mut(); + // Safety: the buffer is always treated as `u16` in the code below + let def_level_data = unsafe { + def_level_data_buffer.typed_data_mut() + }; def_level_data .iter_mut() From f1f32a6c45138b861aa83b5909f5eb084088ecc1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 11 Dec 2021 08:17:08 -0500 Subject: [PATCH 2/3] fmt --- parquet/src/arrow/array_reader.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/parquet/src/arrow/array_reader.rs b/parquet/src/arrow/array_reader.rs index a11befe5416..922d2cb54aa 100644 --- a/parquet/src/arrow/array_reader.rs +++ b/parquet/src/arrow/array_reader.rs @@ -1154,10 +1154,8 @@ impl ArrayReader for StructArrayReader { let mut def_level_data_buffer = MutableBuffer::new(buffer_size); def_level_data_buffer.resize(buffer_size, 0); - // Safety: the buffer is always treated as `u16` in the code below - let def_level_data = unsafe { - def_level_data_buffer.typed_data_mut() - }; + // Safety: the buffer is always treated as `u16` in the code below + let def_level_data = unsafe { def_level_data_buffer.typed_data_mut() }; def_level_data .iter_mut() From 99c1b7d5217fc6425643be5611d810cc28ab5607 Mon Sep 17 00:00:00 2001 From: Andrew Lamb
Date: Tue, 14 Dec 2021 14:10:49 -0500 Subject: [PATCH 3/3] Mark use of `typed_data_mut` as unsafe in simd kernels --- arrow/src/compute/kernels/arithmetic.rs | 33 ++++++++++++----- arrow/src/compute/kernels/comparison.rs | 47 +++++++++++++------------ 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index f92888b3796..09d4b9fd6cd 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -57,7 +57,8 @@ where let buffer_size = array.len() * std::mem::size_of::(); let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + // safety: result is newly created above, always written as a T below + let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; let mut array_chunks = array.values().chunks_exact(lanes); result_chunks @@ -111,7 +112,8 @@ where let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + // safety: result is newly created above, always written as a T below + let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; let mut array_chunks = array.values().chunks_exact(lanes); result_chunks @@ -398,7 +400,8 @@ where let buffer_size = left.len() * std::mem::size_of::(); let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + // safety: result is newly created above, always written as a T below + let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; let mut left_chunks = left.values().chunks_exact(lanes); let mut right_chunks = right.values().chunks_exact(lanes); @@ -662,7 +665,10 @@ where let valid_chunks = b.bit_chunks(0, left.len()); //
process data in chunks of 64 elements since we also get 64 bits of validity information at a time - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64); + + // safety: result is newly created above, always written as a T below + let mut result_chunks = + unsafe { result.typed_data_mut().chunks_exact_mut(64) }; let mut left_chunks = left.values().chunks_exact(64); let mut right_chunks = right.values().chunks_exact(64); @@ -707,7 +713,9 @@ where )?; } None => { - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + // safety: result is newly created above, always written as a T below + let mut result_chunks = + unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; let mut left_chunks = left.values().chunks_exact(lanes); let mut right_chunks = right.values().chunks_exact(lanes); @@ -784,7 +792,10 @@ where let valid_chunks = b.bit_chunks(0, left.len()); // process data in chunks of 64 elements since we also get 64 bits of validity information at a time - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64); + + // safety: result is newly created above, always written as a T below + let mut result_chunks = + unsafe { result.typed_data_mut().chunks_exact_mut(64) }; let mut left_chunks = left.values().chunks_exact(64); let mut right_chunks = right.values().chunks_exact(64); @@ -829,7 +840,9 @@ where )?; } None => { - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + // safety: result is newly created above, always written as a T below + let mut result_chunks = + unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; let mut left_chunks = left.values().chunks_exact(lanes); let mut right_chunks = right.values().chunks_exact(lanes); @@ -891,7 +904,8 @@ where let buffer_size = array.len() * std::mem::size_of::(); let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + // safety: result 
is newly created above, always written as a T below + let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; let mut array_chunks = array.values().chunks_exact(lanes); result_chunks @@ -942,7 +956,8 @@ where let buffer_size = array.len() * std::mem::size_of::(); let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + // safety: result is newly created above, always written as a T below + let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; let mut array_chunks = array.values().chunks_exact(lanes); result_chunks diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 125be912caf..33644c4ff2b 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -921,23 +921,22 @@ where let mut left_chunks = left.values().chunks_exact(lanes); let mut right_chunks = right.values().chunks_exact(lanes); + // safety: result is newly created above, always written as a T below + let result_chunks = unsafe { result.typed_data_mut() }; let result_remainder = left_chunks .borrow_mut() .zip(right_chunks.borrow_mut()) - .fold( - result.typed_data_mut(), - |result_slice, (left_slice, right_slice)| { - let simd_left = T::load(left_slice); - let simd_right = T::load(right_slice); - let simd_result = simd_op(simd_left, simd_right); + .fold(result_chunks, |result_slice, (left_slice, right_slice)| { + let simd_left = T::load(left_slice); + let simd_right = T::load(right_slice); + let simd_result = simd_op(simd_left, simd_right); - let bitmask = T::mask_to_u64(&simd_result); - let bytes = bitmask.to_le_bytes(); - result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]); + let bitmask = T::mask_to_u64(&simd_result); + let bytes = bitmask.to_le_bytes(); + result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]); - &mut 
result_slice[lanes / 8..] - }, - ); + &mut result_slice[lanes / 8..] + }); let left_remainder = left_chunks.remainder(); let right_remainder = right_chunks.remainder(); @@ -1005,19 +1004,21 @@ where let mut left_chunks = left.values().chunks_exact(lanes); let simd_right = T::init(right); - let result_remainder = left_chunks.borrow_mut().fold( - result.typed_data_mut(), - |result_slice, left_slice| { - let simd_left = T::load(left_slice); - let simd_result = simd_op(simd_left, simd_right); + // safety: result is newly created above, always written as a T below + let result_chunks = unsafe { result.typed_data_mut() }; + let result_remainder = + left_chunks + .borrow_mut() + .fold(result_chunks, |result_slice, left_slice| { + let simd_left = T::load(left_slice); + let simd_result = simd_op(simd_left, simd_right); - let bitmask = T::mask_to_u64(&simd_result); - let bytes = bitmask.to_le_bytes(); - result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]); + let bitmask = T::mask_to_u64(&simd_result); + let bytes = bitmask.to_le_bytes(); + result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]); - &mut result_slice[lanes / 8..] - }, - ); + &mut result_slice[lanes / 8..] + }); let left_remainder = left_chunks.remainder();