From 7d64fc6fb652cf7312e34b5f113543bb0714a6bf Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Mon, 13 Jun 2022 17:41:06 +0100 Subject: [PATCH] Mark typed buffer APIs safe (#996) (#1027) --- arrow/src/array/array_union.rs | 8 +++--- arrow/src/array/builder.rs | 4 +-- arrow/src/array/data.rs | 5 ++-- arrow/src/buffer/immutable.rs | 25 ++++++++----------- arrow/src/buffer/mutable.rs | 13 ++++------ arrow/src/buffer/ops.rs | 4 +-- arrow/src/compute/kernels/cast.rs | 2 +- arrow/src/compute/kernels/sort.rs | 6 ++--- arrow/src/compute/kernels/take.rs | 3 +-- parquet/src/arrow/array_reader/byte_array.rs | 4 +-- .../array_reader/byte_array_dictionary.rs | 6 ++--- parquet/src/arrow/array_reader/mod.rs | 8 +++--- parquet/src/arrow/buffer/dictionary_buffer.rs | 2 +- parquet/src/arrow/record_reader/mod.rs | 4 +-- 14 files changed, 39 insertions(+), 55 deletions(-) diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs index 5cfab0bbf85..37a84a9001a 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow/src/array/array_union.rs @@ -185,7 +185,7 @@ impl UnionArray { } // Check the type_ids - let type_id_slice: &[i8] = unsafe { type_ids.typed_data() }; + let type_id_slice: &[i8] = type_ids.typed_data(); let invalid_type_ids = type_id_slice .iter() .filter(|i| *i < &0) @@ -201,7 +201,7 @@ impl UnionArray { // Check the value offsets if provided if let Some(offset_buffer) = &value_offsets { let max_len = type_ids.len() as i32; - let offsets_slice: &[i32] = unsafe { offset_buffer.typed_data() }; + let offsets_slice: &[i32] = offset_buffer.typed_data(); let invalid_offsets = offsets_slice .iter() .filter(|i| *i < &0 || *i > &max_len) @@ -255,9 +255,7 @@ impl UnionArray { pub fn value_offset(&self, index: usize) -> i32 { assert!(index - self.offset() < self.len()); if self.is_dense() { - // safety: reinterpreting is safe since the offset buffer contains `i32` values and is - // properly aligned. - unsafe { self.data().buffers()[1].typed_data::()[index] } + self.data().buffers()[1].typed_data::()[index] } else { index as i32 } diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs index 041b7a92c33..f8b3ac57cf8 100644 --- a/arrow/src/array/builder.rs +++ b/arrow/src/array/builder.rs @@ -76,7 +76,7 @@ pub(crate) fn builder_to_mutable_buffer( /// builder.append(45); /// let buffer = builder.finish(); /// -/// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 43, 44, 45]); +/// assert_eq!(buffer.typed_data::(), &[42, 43, 44, 45]); /// # Ok(()) /// # } /// ``` @@ -291,7 +291,7 @@ impl BufferBuilder { /// /// let buffer = builder.finish(); /// - /// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 44, 46]); + /// assert_eq!(buffer.typed_data::(), &[42, 44, 46]); /// ``` #[inline] pub fn finish(&mut self) -> Buffer { diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index 0ccbe6a7017..65fbc4df970 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -767,8 +767,7 @@ impl ArrayData { ))); } - // SAFETY: Bounds checked above - Ok(unsafe { &(buffer.typed_data::()[self.offset..self.offset + len]) }) + Ok(&buffer.typed_data::()[self.offset..self.offset + len]) } /// Does a cheap sanity check that the `self.len` values in `buffer` are valid @@ -1161,7 +1160,7 @@ impl ArrayData { // Justification: buffer size was validated above let indexes: &[T] = - unsafe { &(buffer.typed_data::()[self.offset..self.offset + self.len]) }; + &buffer.typed_data::()[self.offset..self.offset + self.len]; indexes.iter().enumerate().try_for_each(|(i, &dict_index)| { // Do not check the value is null (value can be arbitrary) diff --git a/arrow/src/buffer/immutable.rs b/arrow/src/buffer/immutable.rs index c34ea101bb3..d563a53adc5 100644 --- a/arrow/src/buffer/immutable.rs +++ b/arrow/src/buffer/immutable.rs @@ -181,19 +181,14 @@ impl Buffer { /// View buffer as typed slice. /// - /// # Safety + /// # Panics /// - /// `ArrowNativeType` is public so that it can be used as a trait bound for other public - /// components, such as the `ToByteSlice` trait. However, this means that it can be - /// implemented by user defined types, which it is not intended for. - pub unsafe fn typed_data(&self) -> &[T] { - // JUSTIFICATION - // Benefit - // Many of the buffers represent specific types, and consumers of `Buffer` often need to re-interpret them. - // Soundness - // * The pointer is non-null by construction - // * alignment asserted below. - let (prefix, offsets, suffix) = self.as_slice().align_to::(); + /// This function panics if the underlying buffer is not aligned + /// correctly for type `T`. + pub fn typed_data(&self) -> &[T] { + // SAFETY + // ArrowNativeType are trivially transmutable, and this method checks alignment + let (prefix, offsets, suffix) = unsafe { self.as_slice().align_to::() }; assert!(prefix.is_empty() && suffix.is_empty()); offsets } @@ -451,7 +446,7 @@ mod tests { macro_rules! check_as_typed_data { ($input: expr, $native_t: ty) => {{ let buffer = Buffer::from_slice_ref($input); - let slice: &[$native_t] = unsafe { buffer.typed_data::<$native_t>() }; + let slice: &[$native_t] = buffer.typed_data::<$native_t>(); assert_eq!($input, slice); }}; } @@ -573,12 +568,12 @@ mod tests { ) }; - let slice = unsafe { buffer.typed_data::() }; + let slice = buffer.typed_data::(); assert_eq!(slice, &[1, 2, 3, 4, 5]); let buffer = buffer.slice(std::mem::size_of::()); - let slice = unsafe { buffer.typed_data::() }; + let slice = buffer.typed_data::(); assert_eq!(slice, &[2, 3, 4, 5]); } } diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs index 709973b4401..8dd9b4f8f95 100644 --- a/arrow/src/buffer/mutable.rs +++ b/arrow/src/buffer/mutable.rs @@ -275,17 +275,14 @@ impl MutableBuffer { /// View this buffer asa slice of a specific type. /// - /// # Safety - /// - /// This function must only be used with buffers which are treated - /// as type `T` (e.g. extended with items of type `T`). - /// /// # Panics /// /// This function panics if the underlying buffer is not aligned /// correctly for type `T`. - pub unsafe fn typed_data_mut(&mut self) -> &mut [T] { - let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::(); + pub fn typed_data_mut(&mut self) -> &mut [T] { + // SAFETY + // ArrowNativeType are trivially transmutable, and this method checks alignment + let (prefix, offsets, suffix) = unsafe { self.as_slice_mut().align_to_mut::() }; assert!(prefix.is_empty() && suffix.is_empty()); offsets } @@ -299,7 +296,7 @@ impl MutableBuffer { /// assert_eq!(buffer.len(), 8) // u32 has 4 bytes /// ``` #[inline] - pub fn extend_from_slice(&mut self, items: &[T]) { + pub fn extend_from_slice(&mut self, items: &[T]) { let len = items.len(); let additional = len * std::mem::size_of::(); self.reserve(additional); diff --git a/arrow/src/buffer/ops.rs b/arrow/src/buffer/ops.rs index b3571d1740b..ea155c8d78e 100644 --- a/arrow/src/buffer/ops.rs +++ b/arrow/src/buffer/ops.rs @@ -68,9 +68,7 @@ where let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits); - // Safety: buffer is always treated as type `u64` in the code - // below. - let result_chunks = unsafe { result.typed_data_mut::().iter_mut() }; + let result_chunks = result.typed_data_mut::().iter_mut(); result_chunks .zip(left_chunks.iter()) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 93a8ebcb6b5..9a4638d9773 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -2084,7 +2084,7 @@ where let list_data = array.data(); let str_values_buf = str_array.value_data(); - let offsets = unsafe { list_data.buffers()[0].typed_data::() }; + let offsets = list_data.buffers()[0].typed_data::(); let mut offset_builder = BufferBuilder::::new(offsets.len()); offsets.iter().try_for_each::<_, Result<_>>(|offset| { diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 140a57f33ed..72ee8b68da2 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -452,8 +452,7 @@ fn sort_boolean( let mut result = MutableBuffer::new(result_capacity); // sets len to capacity so we can access the whole buffer as a typed slice result.resize(result_capacity, 0); - // Safety: the buffer is always treated as `u32` in the code below - let result_slice: &mut [u32] = unsafe { result.typed_data_mut() }; + let result_slice: &mut [u32] = result.typed_data_mut(); if options.nulls_first { let size = nulls_len.min(len); @@ -565,8 +564,7 @@ where let mut result = MutableBuffer::new(result_capacity); // sets len to capacity so we can access the whole buffer as a typed slice result.resize(result_capacity, 0); - // Safety: the buffer is always treated as `u32` in the code below - let result_slice: &mut [u32] = unsafe { result.typed_data_mut() }; + let result_slice: &mut [u32] = result.typed_data_mut(); if options.nulls_first { let size = nulls_len.min(len); diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 567bf5c8ba2..03637ec81dd 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -688,8 +688,7 @@ where let bytes_offset = (data_len + 1) * std::mem::size_of::(); let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset); - // Safety: the buffer is always treated as as a type of `OffsetSize` in the code below - let offsets = unsafe { offsets_buffer.typed_data_mut() }; + let offsets = offsets_buffer.typed_data_mut(); let mut values = MutableBuffer::new(0); let mut length_so_far = OffsetSize::zero(); offsets[0] = length_so_far; diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs index 2e29b609474..9e0f83fa945 100644 --- a/parquet/src/arrow/array_reader/byte_array.rs +++ b/parquet/src/arrow/array_reader/byte_array.rs @@ -125,13 +125,13 @@ impl ArrayReader for ByteArrayReader { fn get_def_levels(&self) -> Option<&[i16]> { self.def_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { self.rep_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } } diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs index 0e64f0d25b7..0cd67206f00 100644 --- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs +++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs @@ -187,13 +187,13 @@ where fn get_def_levels(&self) -> Option<&[i16]> { self.def_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { self.rep_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } } @@ -356,7 +356,7 @@ where assert_eq!(dict.data_type(), &self.value_type); let dict_buffers = dict.data().buffers(); - let dict_offsets = unsafe { dict_buffers[0].typed_data::() }; + let dict_offsets = dict_buffers[0].typed_data::(); let dict_values = dict_buffers[1].as_slice(); values.extend_from_dictionary( diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs index 21c49b33878..6207b377d13 100644 --- a/parquet/src/arrow/array_reader/mod.rs +++ b/parquet/src/arrow/array_reader/mod.rs @@ -226,13 +226,13 @@ where fn get_def_levels(&self) -> Option<&[i16]> { self.def_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { self.rep_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } } @@ -447,13 +447,13 @@ where fn get_def_levels(&self) -> Option<&[i16]> { self.def_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { self.rep_levels_buffer .as_ref() - .map(|buf| unsafe { buf.typed_data() }) + .map(|buf| buf.typed_data()) } } diff --git a/parquet/src/arrow/buffer/dictionary_buffer.rs b/parquet/src/arrow/buffer/dictionary_buffer.rs index 7f445850700..ffa3a4843c5 100644 --- a/parquet/src/arrow/buffer/dictionary_buffer.rs +++ b/parquet/src/arrow/buffer/dictionary_buffer.rs @@ -106,7 +106,7 @@ impl Self::Dict { keys, values } => { let mut spilled = OffsetBuffer::default(); let dict_buffers = values.data().buffers(); - let dict_offsets = unsafe { dict_buffers[0].typed_data::() }; + let dict_offsets = dict_buffers[0].typed_data::(); let dict_values = dict_buffers[1].as_slice(); if values.is_empty() { diff --git a/parquet/src/arrow/record_reader/mod.rs b/parquet/src/arrow/record_reader/mod.rs index 89d782b1aca..023a538a274 100644 --- a/parquet/src/arrow/record_reader/mod.rs +++ b/parquet/src/arrow/record_reader/mod.rs @@ -573,7 +573,7 @@ mod tests { // Verify result record data let actual = record_reader.consume_record_data().unwrap(); - let actual_values = unsafe { actual.typed_data::() }; + let actual_values = actual.typed_data::(); let expected = &[0, 7, 0, 6, 3, 0, 8]; assert_eq!(actual_values.len(), expected.len()); @@ -687,7 +687,7 @@ mod tests { // Verify result record data let actual = record_reader.consume_record_data().unwrap(); - let actual_values = unsafe { actual.typed_data::() }; + let actual_values = actual.typed_data::(); let expected = &[4, 0, 0, 7, 6, 3, 2, 8, 9]; assert_eq!(actual_values.len(), expected.len());