From 9c40a87f2a83429afbd5e6b76457cea2895372fd Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 7 Jun 2022 23:06:44 +0100 Subject: [PATCH] Fix Decimal and List ArrayData Validation (#1813) (#1814) (#1816) * Fix DecimalArray validation (#1813) Fix offset validation for sliced children of list arrays (#1814) * Update arrow/src/array/data.rs Co-authored-by: Liang-Chi Hsieh Co-authored-by: Liang-Chi Hsieh --- arrow/src/array/data.rs | 163 ++++++++++++++++++++++------------------ 1 file changed, 91 insertions(+), 72 deletions(-) diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index fcf89f473fd..0ccbe6a7017 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -712,10 +712,10 @@ impl ArrayData { // Additional Type specific checks match &self.data_type { DataType::Utf8 | DataType::Binary => { - self.validate_offsets::(&self.buffers[0], self.buffers[1].len())?; + self.validate_offsets::(self.buffers[1].len())?; } DataType::LargeUtf8 | DataType::LargeBinary => { - self.validate_offsets::(&self.buffers[0], self.buffers[1].len())?; + self.validate_offsets::(self.buffers[1].len())?; } DataType::Dictionary(key_type, _value_type) => { // At the moment, constructing a DictionaryArray will also check this @@ -738,40 +738,47 @@ impl ArrayData { /// entries. /// /// For an empty array, the `buffer` can also be empty. - fn typed_offsets<'a, T: ArrowNativeType + num::Num + std::fmt::Display>( - &'a self, - buffer: &'a Buffer, - ) -> Result<&'a [T]> { + fn typed_offsets(&self) -> Result<&[T]> { // An empty list-like array can have 0 offsets - if buffer.is_empty() && self.len == 0 { + if self.len == 0 && self.buffers[0].is_empty() { return Ok(&[]); } - // Validate that there are the correct number of offsets for this array's length - let required_offsets = self.len + self.offset + 1; + self.typed_buffer(0, self.len + 1) + } + + /// Returns a reference to the data in `buffers[idx]` as a typed slice after validating + fn typed_buffer( + &self, + idx: usize, + len: usize, + ) -> Result<&[T]> { + let buffer = &self.buffers[idx]; + + let required_len = (len + self.offset) * std::mem::size_of::(); - if (buffer.len() / std::mem::size_of::()) < required_offsets { + if buffer.len() < required_len { return Err(ArrowError::InvalidArgumentError(format!( - "Offsets buffer size (bytes): {} isn't large enough for {}. Length {} needs {}", - buffer.len(), self.data_type, self.len, required_offsets + "Buffer {} of {} isn't large enough. Expected {} bytes got {}", + idx, + self.data_type, + required_len, + buffer.len() ))); } - // Justification: buffer size was validated above - Ok(unsafe { - &(buffer.typed_data::()[self.offset..self.offset + self.len + 1]) - }) + // SAFETY: Bounds checked above + Ok(unsafe { &(buffer.typed_data::()[self.offset..self.offset + len]) }) } /// Does a cheap sanity check that the `self.len` values in `buffer` are valid /// offsets (of type T) into some other buffer of `values_length` bytes long fn validate_offsets( &self, - buffer: &Buffer, values_length: usize, ) -> Result<()> { // Justification: buffer size was validated above - let offsets = self.typed_offsets::(buffer)?; + let offsets = self.typed_offsets::()?; if offsets.is_empty() { return Ok(()); } @@ -819,12 +826,12 @@ impl ArrayData { match &self.data_type { DataType::List(field) | DataType::Map(field, _) => { let values_data = self.get_single_valid_child_data(field.data_type())?; - self.validate_offsets::(&self.buffers[0], values_data.len)?; + self.validate_offsets::(values_data.len)?; Ok(()) } DataType::LargeList(field) => { let values_data = self.get_single_valid_child_data(field.data_type())?; - self.validate_offsets::(&self.buffers[0], values_data.len)?; + self.validate_offsets::(values_data.len)?; Ok(()) } DataType::FixedSizeList(field, list_size) => { @@ -1000,17 +1007,9 @@ impl ArrayData { pub fn validate_dictionary_offset(&self) -> Result<()> { match &self.data_type { DataType::Decimal(p, _) => { - let values_buffer = &self.buffers[0]; - - for pos in 0..values_buffer.len() { - let raw_val = unsafe { - std::slice::from_raw_parts( - values_buffer.as_ptr().add(pos), - 16_usize, - ) - }; - let value = i128::from_le_bytes(raw_val.try_into().unwrap()); - validate_decimal_precision(value, *p)?; + let values_buffer: &[i128] = self.typed_buffer(0, self.len)?; + for value in values_buffer { + validate_decimal_precision(*value, *p)?; } Ok(()) } @@ -1022,11 +1021,11 @@ impl ArrayData { } DataType::List(_) | DataType::Map(_, _) => { let child = &self.child_data[0]; - self.validate_offsets_full::(child.len + child.offset) + self.validate_offsets_full::(child.len) } DataType::LargeList(_) => { let child = &self.child_data[0]; - self.validate_offsets_full::(child.len + child.offset) + self.validate_offsets_full::(child.len) } DataType::Union(_, _, _) => { // Validate Union Array as part of implementing new Union semantics @@ -1068,17 +1067,12 @@ impl ArrayData { /// /// For example, the offsets buffer contained `[1, 2, 4]`, this /// function would call `validate([1,2])`, and `validate([2,4])` - fn validate_each_offset( - &self, - offsets_buffer: &Buffer, - offset_limit: usize, - validate: V, - ) -> Result<()> + fn validate_each_offset(&self, offset_limit: usize, validate: V) -> Result<()> where - T: ArrowNativeType + std::convert::TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, V: Fn(usize, Range) -> Result<()>, { - self.typed_offsets::(offsets_buffer)? + self.typed_offsets::()? .iter() .enumerate() .map(|(i, x)| { @@ -1124,50 +1118,39 @@ impl ArrayData { /// into `buffers[1]` are valid utf8 sequences fn validate_utf8(&self) -> Result<()> where - T: ArrowNativeType + std::convert::TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, { - let offset_buffer = &self.buffers[0]; let values_buffer = &self.buffers[1].as_slice(); - self.validate_each_offset::( - offset_buffer, - values_buffer.len(), - |string_index, range| { - std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { - ArrowError::InvalidArgumentError(format!( - "Invalid UTF8 sequence at string index {} ({:?}): {}", - string_index, range, e - )) - })?; - Ok(()) - }, - ) + self.validate_each_offset::(values_buffer.len(), |string_index, range| { + std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Invalid UTF8 sequence at string index {} ({:?}): {}", + string_index, range, e + )) + })?; + Ok(()) + }) } /// Ensures that all offsets in `buffers[0]` into `buffers[1]` are /// between `0` and `offset_limit` fn validate_offsets_full(&self, offset_limit: usize) -> Result<()> where - T: ArrowNativeType + std::convert::TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, { - let offset_buffer = &self.buffers[0]; - - self.validate_each_offset::( - offset_buffer, - offset_limit, - |_string_index, _range| { - // No validation applied to each value, but the iteration - // itself applies bounds checking to each range - Ok(()) - }, - ) + self.validate_each_offset::(offset_limit, |_string_index, _range| { + // No validation applied to each value, but the iteration + // itself applies bounds checking to each range + Ok(()) + }) } /// Validates that each value in self.buffers (typed as T) /// is within the range [0, max_value], inclusive fn check_bounds(&self, max_value: i64) -> Result<()> where - T: ArrowNativeType + std::convert::TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, { let required_len = self.len + self.offset; let buffer = &self.buffers[0]; @@ -1859,7 +1842,7 @@ mod tests { #[test] #[should_panic( - expected = "Offsets buffer size (bytes): 4 isn't large enough for LargeUtf8. Length 0 needs 1" + expected = "Buffer 0 of LargeUtf8 isn't large enough. Expected 8 bytes got 4" )] fn test_empty_large_utf8_array_with_wrong_type_offsets() { let data_buffer = Buffer::from(&[]); @@ -1877,7 +1860,7 @@ mod tests { #[test] #[should_panic( - expected = "Offsets buffer size (bytes): 8 isn't large enough for Utf8. Length 2 needs 3" + expected = "Buffer 0 of Utf8 isn't large enough. Expected 12 bytes got 8" )] fn test_validate_offsets_i32() { let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); @@ -1895,7 +1878,7 @@ mod tests { #[test] #[should_panic( - expected = "Offsets buffer size (bytes): 16 isn't large enough for LargeUtf8. Length 2 needs 3" + expected = "Buffer 0 of LargeUtf8 isn't large enough. Expected 24 bytes got 16" )] fn test_validate_offsets_i64() { let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); @@ -2755,4 +2738,40 @@ mod tests { error.to_string() ); } + + #[test] + fn test_decimal_validation() { + let mut builder = DecimalBuilder::new(4, 10, 4); + builder.append_value(10000).unwrap(); + builder.append_value(20000).unwrap(); + let array = builder.finish(); + + array.data().validate_full().unwrap(); + } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_sliced_array_child() { + let values = Int32Array::from_iter_values([1, 2, 3]); + let values_sliced = values.slice(1, 2); + let offsets = Buffer::from_iter([1_i32, 3_i32]); + + let list_field = Field::new("element", DataType::Int32, false); + let data_type = DataType::List(Box::new(list_field)); + + let data = unsafe { + ArrayData::new_unchecked( + data_type, + 1, + None, + None, + 0, + vec![offsets], + vec![values_sliced.data().clone()], + ) + }; + + let err = data.validate_dictionary_offset().unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Offset invariant failure: offset at position 1 out of bounds: 3 > 2"); + } }