Skip to content

Commit

Permalink
Mark MutableBuffer::typed_data_mut unsafe (#1029)
Browse files Browse the repository at this point in the history
* Mark `MutableBuffer::typed_data_mut` unsafe

* fmt

* Mark use of `typed_data_but` as unsafe in simd kernels
  • Loading branch information
alamb committed Dec 15, 2021
1 parent ab48e69 commit f21fa54
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 45 deletions.
21 changes: 13 additions & 8 deletions arrow/src/buffer/mutable.rs
Expand Up @@ -273,15 +273,20 @@ impl MutableBuffer {
}

/// View this buffer asa slice of a specific type.
///
/// # Safety
/// This function must only be used when this buffer was extended with items of type `T`.
/// Failure to do so results in undefined behavior.
pub fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
unsafe {
let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::<T>();
assert!(prefix.is_empty() && suffix.is_empty());
offsets
}
///
/// This function must only be used with buffers which are treated
/// as type `T` (e.g. extended with items of type `T`).
///
/// # Panics
///
/// This function panics if the underlying buffer is not aligned
/// correctly for type `T`.
pub unsafe fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::<T>();
assert!(prefix.is_empty() && suffix.is_empty());
offsets
}

/// Extends this buffer from a slice of items that can be represented in bytes, increasing its capacity if needed.
Expand Down
5 changes: 4 additions & 1 deletion arrow/src/buffer/ops.rs
Expand Up @@ -168,7 +168,10 @@ where
MutableBuffer::new(ceil(len_in_bits, 8)).with_bitset(len_in_bits / 64 * 8, false);

let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits);
let result_chunks = result.typed_data_mut::<u64>().iter_mut();

// Safety: buffer is always treated as type `u64` in the code
// below.
let result_chunks = unsafe { result.typed_data_mut::<u64>().iter_mut() };

result_chunks
.zip(left_chunks.iter())
Expand Down
33 changes: 24 additions & 9 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -57,7 +57,8 @@ where
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);

let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
// safety: result is newly created above, always written as a T below
let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);

result_chunks
Expand Down Expand Up @@ -111,7 +112,8 @@ where

let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);

let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
// safety: result is newly created above, always written as a T below
let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);

result_chunks
Expand Down Expand Up @@ -398,7 +400,8 @@ where
let buffer_size = left.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);

let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
// safety: result is newly created above, always written as a T below
let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);

Expand Down Expand Up @@ -662,7 +665,10 @@ where
let valid_chunks = b.bit_chunks(0, left.len());

// process data in chunks of 64 elements since we also get 64 bits of validity information at a time
let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);

// safety: result is newly created above, always written as a T below
let mut result_chunks =
unsafe { result.typed_data_mut().chunks_exact_mut(64) };
let mut left_chunks = left.values().chunks_exact(64);
let mut right_chunks = right.values().chunks_exact(64);

Expand Down Expand Up @@ -707,7 +713,9 @@ where
)?;
}
None => {
let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
// safety: result is newly created above, always written as a T below
let mut result_chunks =
unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);

Expand Down Expand Up @@ -784,7 +792,10 @@ where
let valid_chunks = b.bit_chunks(0, left.len());

// process data in chunks of 64 elements since we also get 64 bits of validity information at a time
let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);

// safety: result is newly created above, always written as a T below
let mut result_chunks =
unsafe { result.typed_data_mut().chunks_exact_mut(64) };
let mut left_chunks = left.values().chunks_exact(64);
let mut right_chunks = right.values().chunks_exact(64);

Expand Down Expand Up @@ -829,7 +840,9 @@ where
)?;
}
None => {
let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
// safety: result is newly created above, always written as a T below
let mut result_chunks =
unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);

Expand Down Expand Up @@ -891,7 +904,8 @@ where
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);

let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
// safety: result is newly created above, always written as a T below
let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);

result_chunks
Expand Down Expand Up @@ -942,7 +956,8 @@ where
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);

let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
// safety: result is newly created above, always written as a T below
let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);

result_chunks
Expand Down
47 changes: 24 additions & 23 deletions arrow/src/compute/kernels/comparison.rs
Expand Up @@ -921,23 +921,22 @@ where
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);

// safety: result is newly created above, always written as a T below
let result_chunks = unsafe { result.typed_data_mut() };
let result_remainder = left_chunks
.borrow_mut()
.zip(right_chunks.borrow_mut())
.fold(
result.typed_data_mut(),
|result_slice, (left_slice, right_slice)| {
let simd_left = T::load(left_slice);
let simd_right = T::load(right_slice);
let simd_result = simd_op(simd_left, simd_right);
.fold(result_chunks, |result_slice, (left_slice, right_slice)| {
let simd_left = T::load(left_slice);
let simd_right = T::load(right_slice);
let simd_result = simd_op(simd_left, simd_right);

let bitmask = T::mask_to_u64(&simd_result);
let bytes = bitmask.to_le_bytes();
result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
let bitmask = T::mask_to_u64(&simd_result);
let bytes = bitmask.to_le_bytes();
result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);

&mut result_slice[lanes / 8..]
},
);
&mut result_slice[lanes / 8..]
});

let left_remainder = left_chunks.remainder();
let right_remainder = right_chunks.remainder();
Expand Down Expand Up @@ -1005,19 +1004,21 @@ where
let mut left_chunks = left.values().chunks_exact(lanes);
let simd_right = T::init(right);

let result_remainder = left_chunks.borrow_mut().fold(
result.typed_data_mut(),
|result_slice, left_slice| {
let simd_left = T::load(left_slice);
let simd_result = simd_op(simd_left, simd_right);
// safety: result is newly created above, always written as a T below
let result_chunks = unsafe { result.typed_data_mut() };
let result_remainder =
left_chunks
.borrow_mut()
.fold(result_chunks, |result_slice, left_slice| {
let simd_left = T::load(left_slice);
let simd_result = simd_op(simd_left, simd_right);

let bitmask = T::mask_to_u64(&simd_result);
let bytes = bitmask.to_le_bytes();
result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
let bitmask = T::mask_to_u64(&simd_result);
let bytes = bitmask.to_le_bytes();
result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);

&mut result_slice[lanes / 8..]
},
);
&mut result_slice[lanes / 8..]
});

let left_remainder = left_chunks.remainder();

Expand Down
6 changes: 4 additions & 2 deletions arrow/src/compute/kernels/sort.rs
Expand Up @@ -471,7 +471,8 @@ fn sort_boolean(
let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
result.resize(result_capacity, 0);
let result_slice: &mut [u32] = result.typed_data_mut();
// Safety: the buffer is always treated as `u32` in the code below
let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };

if options.nulls_first {
let size = nulls_len.min(len);
Expand Down Expand Up @@ -559,7 +560,8 @@ where
let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
result.resize(result_capacity, 0);
let result_slice: &mut [u32] = result.typed_data_mut();
// Safety: the buffer is always treated as `u32` in the code below
let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };

if options.nulls_first {
let size = nulls_len.min(len);
Expand Down
3 changes: 2 additions & 1 deletion arrow/src/compute/kernels/take.rs
Expand Up @@ -632,7 +632,8 @@ where
let bytes_offset = (data_len + 1) * std::mem::size_of::<OffsetSize>();
let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);

let offsets = offsets_buffer.typed_data_mut();
// Safety: the buffer is always treated as as a type of `OffsetSize` in the code below
let offsets = unsafe { offsets_buffer.typed_data_mut() };
let mut values = MutableBuffer::new(0);
let mut length_so_far = OffsetSize::zero();
offsets[0] = length_so_far;
Expand Down
3 changes: 2 additions & 1 deletion parquet/src/arrow/array_reader.rs
Expand Up @@ -1140,7 +1140,8 @@ impl ArrayReader for StructArrayReader {
let mut def_level_data_buffer = MutableBuffer::new(buffer_size);
def_level_data_buffer.resize(buffer_size, 0);

let def_level_data = def_level_data_buffer.typed_data_mut();
// Safety: the buffer is always treated as `u16` in the code below
let def_level_data = unsafe { def_level_data_buffer.typed_data_mut() };

def_level_data
.iter_mut()
Expand Down

0 comments on commit f21fa54

Please sign in to comment.