Skip to content

Commit

Permalink
Fix Decimal and List ArrayData Validation (#1813) (#1814) (#1816)
Browse files Browse the repository at this point in the history
* Fix DecimalArray validation (#1813)

Fix offset validation for sliced children of list arrays (#1814)

* Update arrow/src/array/data.rs

Co-authored-by: Liang-Chi Hsieh <viirya@gmail.com>

Co-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
  • Loading branch information
tustvold and viirya committed Jun 7, 2022
1 parent a439f7f commit 9c40a87
Showing 1 changed file with 91 additions and 72 deletions.
163 changes: 91 additions & 72 deletions arrow/src/array/data.rs
Expand Up @@ -712,10 +712,10 @@ impl ArrayData {
// Additional Type specific checks
match &self.data_type {
DataType::Utf8 | DataType::Binary => {
self.validate_offsets::<i32>(&self.buffers[0], self.buffers[1].len())?;
self.validate_offsets::<i32>(self.buffers[1].len())?;
}
DataType::LargeUtf8 | DataType::LargeBinary => {
self.validate_offsets::<i64>(&self.buffers[0], self.buffers[1].len())?;
self.validate_offsets::<i64>(self.buffers[1].len())?;
}
DataType::Dictionary(key_type, _value_type) => {
// At the moment, constructing a DictionaryArray will also check this
Expand All @@ -738,40 +738,47 @@ impl ArrayData {
/// entries.
///
/// For an empty array, the `buffer` can also be empty.
fn typed_offsets<'a, T: ArrowNativeType + num::Num + std::fmt::Display>(
&'a self,
buffer: &'a Buffer,
) -> Result<&'a [T]> {
fn typed_offsets<T: ArrowNativeType + num::Num>(&self) -> Result<&[T]> {
// An empty list-like array can have 0 offsets
if buffer.is_empty() && self.len == 0 {
if self.len == 0 && self.buffers[0].is_empty() {
return Ok(&[]);
}

// Validate that there are the correct number of offsets for this array's length
let required_offsets = self.len + self.offset + 1;
self.typed_buffer(0, self.len + 1)
}

/// Returns a reference to the data in `buffers[idx]` as a typed slice after validating
fn typed_buffer<T: ArrowNativeType + num::Num>(
&self,
idx: usize,
len: usize,
) -> Result<&[T]> {
let buffer = &self.buffers[idx];

let required_len = (len + self.offset) * std::mem::size_of::<T>();

if (buffer.len() / std::mem::size_of::<T>()) < required_offsets {
if buffer.len() < required_len {
return Err(ArrowError::InvalidArgumentError(format!(
"Offsets buffer size (bytes): {} isn't large enough for {}. Length {} needs {}",
buffer.len(), self.data_type, self.len, required_offsets
"Buffer {} of {} isn't large enough. Expected {} bytes got {}",
idx,
self.data_type,
required_len,
buffer.len()
)));
}

// Justification: buffer size was validated above
Ok(unsafe {
&(buffer.typed_data::<T>()[self.offset..self.offset + self.len + 1])
})
// SAFETY: Bounds checked above
Ok(unsafe { &(buffer.typed_data::<T>()[self.offset..self.offset + len]) })
}

/// Does a cheap sanity check that the `self.len` values in `buffer` are valid
/// offsets (of type T) into some other buffer of `values_length` bytes long
fn validate_offsets<T: ArrowNativeType + num::Num + std::fmt::Display>(
&self,
buffer: &Buffer,
values_length: usize,
) -> Result<()> {
// Justification: buffer size was validated above
let offsets = self.typed_offsets::<T>(buffer)?;
let offsets = self.typed_offsets::<T>()?;
if offsets.is_empty() {
return Ok(());
}
Expand Down Expand Up @@ -819,12 +826,12 @@ impl ArrayData {
match &self.data_type {
DataType::List(field) | DataType::Map(field, _) => {
let values_data = self.get_single_valid_child_data(field.data_type())?;
self.validate_offsets::<i32>(&self.buffers[0], values_data.len)?;
self.validate_offsets::<i32>(values_data.len)?;
Ok(())
}
DataType::LargeList(field) => {
let values_data = self.get_single_valid_child_data(field.data_type())?;
self.validate_offsets::<i64>(&self.buffers[0], values_data.len)?;
self.validate_offsets::<i64>(values_data.len)?;
Ok(())
}
DataType::FixedSizeList(field, list_size) => {
Expand Down Expand Up @@ -1000,17 +1007,9 @@ impl ArrayData {
pub fn validate_dictionary_offset(&self) -> Result<()> {
match &self.data_type {
DataType::Decimal(p, _) => {
let values_buffer = &self.buffers[0];

for pos in 0..values_buffer.len() {
let raw_val = unsafe {
std::slice::from_raw_parts(
values_buffer.as_ptr().add(pos),
16_usize,
)
};
let value = i128::from_le_bytes(raw_val.try_into().unwrap());
validate_decimal_precision(value, *p)?;
let values_buffer: &[i128] = self.typed_buffer(0, self.len)?;
for value in values_buffer {
validate_decimal_precision(*value, *p)?;
}
Ok(())
}
Expand All @@ -1022,11 +1021,11 @@ impl ArrayData {
}
DataType::List(_) | DataType::Map(_, _) => {
let child = &self.child_data[0];
self.validate_offsets_full::<i32>(child.len + child.offset)
self.validate_offsets_full::<i32>(child.len)
}
DataType::LargeList(_) => {
let child = &self.child_data[0];
self.validate_offsets_full::<i64>(child.len + child.offset)
self.validate_offsets_full::<i64>(child.len)
}
DataType::Union(_, _, _) => {
// Validate Union Array as part of implementing new Union semantics
Expand Down Expand Up @@ -1068,17 +1067,12 @@ impl ArrayData {
///
/// For example, the offsets buffer contained `[1, 2, 4]`, this
/// function would call `validate([1,2])`, and `validate([2,4])`
fn validate_each_offset<T, V>(
&self,
offsets_buffer: &Buffer,
offset_limit: usize,
validate: V,
) -> Result<()>
fn validate_each_offset<T, V>(&self, offset_limit: usize, validate: V) -> Result<()>
where
T: ArrowNativeType + std::convert::TryInto<usize> + num::Num + std::fmt::Display,
T: ArrowNativeType + TryInto<usize> + num::Num + std::fmt::Display,
V: Fn(usize, Range<usize>) -> Result<()>,
{
self.typed_offsets::<T>(offsets_buffer)?
self.typed_offsets::<T>()?
.iter()
.enumerate()
.map(|(i, x)| {
Expand Down Expand Up @@ -1124,50 +1118,39 @@ impl ArrayData {
/// into `buffers[1]` are valid utf8 sequences
fn validate_utf8<T>(&self) -> Result<()>
where
T: ArrowNativeType + std::convert::TryInto<usize> + num::Num + std::fmt::Display,
T: ArrowNativeType + TryInto<usize> + num::Num + std::fmt::Display,
{
let offset_buffer = &self.buffers[0];
let values_buffer = &self.buffers[1].as_slice();

self.validate_each_offset::<T, _>(
offset_buffer,
values_buffer.len(),
|string_index, range| {
std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"Invalid UTF8 sequence at string index {} ({:?}): {}",
string_index, range, e
))
})?;
Ok(())
},
)
self.validate_each_offset::<T, _>(values_buffer.len(), |string_index, range| {
std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"Invalid UTF8 sequence at string index {} ({:?}): {}",
string_index, range, e
))
})?;
Ok(())
})
}

/// Ensures that all offsets in `buffers[0]` into `buffers[1]` are
/// between `0` and `offset_limit`
fn validate_offsets_full<T>(&self, offset_limit: usize) -> Result<()>
where
T: ArrowNativeType + std::convert::TryInto<usize> + num::Num + std::fmt::Display,
T: ArrowNativeType + TryInto<usize> + num::Num + std::fmt::Display,
{
let offset_buffer = &self.buffers[0];

self.validate_each_offset::<T, _>(
offset_buffer,
offset_limit,
|_string_index, _range| {
// No validation applied to each value, but the iteration
// itself applies bounds checking to each range
Ok(())
},
)
self.validate_each_offset::<T, _>(offset_limit, |_string_index, _range| {
// No validation applied to each value, but the iteration
// itself applies bounds checking to each range
Ok(())
})
}

/// Validates that each value in self.buffers (typed as T)
/// is within the range [0, max_value], inclusive
fn check_bounds<T>(&self, max_value: i64) -> Result<()>
where
T: ArrowNativeType + std::convert::TryInto<i64> + num::Num + std::fmt::Display,
T: ArrowNativeType + TryInto<i64> + num::Num + std::fmt::Display,
{
let required_len = self.len + self.offset;
let buffer = &self.buffers[0];
Expand Down Expand Up @@ -1859,7 +1842,7 @@ mod tests {

#[test]
#[should_panic(
expected = "Offsets buffer size (bytes): 4 isn't large enough for LargeUtf8. Length 0 needs 1"
expected = "Buffer 0 of LargeUtf8 isn't large enough. Expected 8 bytes got 4"
)]
fn test_empty_large_utf8_array_with_wrong_type_offsets() {
let data_buffer = Buffer::from(&[]);
Expand All @@ -1877,7 +1860,7 @@ mod tests {

#[test]
#[should_panic(
expected = "Offsets buffer size (bytes): 8 isn't large enough for Utf8. Length 2 needs 3"
expected = "Buffer 0 of Utf8 isn't large enough. Expected 12 bytes got 8"
)]
fn test_validate_offsets_i32() {
let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes());
Expand All @@ -1895,7 +1878,7 @@ mod tests {

#[test]
#[should_panic(
expected = "Offsets buffer size (bytes): 16 isn't large enough for LargeUtf8. Length 2 needs 3"
expected = "Buffer 0 of LargeUtf8 isn't large enough. Expected 24 bytes got 16"
)]
fn test_validate_offsets_i64() {
let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes());
Expand Down Expand Up @@ -2755,4 +2738,40 @@ mod tests {
error.to_string()
);
}

#[test]
fn test_decimal_validation() {
let mut builder = DecimalBuilder::new(4, 10, 4);
builder.append_value(10000).unwrap();
builder.append_value(20000).unwrap();
let array = builder.finish();

array.data().validate_full().unwrap();
}

#[test]
#[cfg(not(feature = "force_validate"))]
fn test_sliced_array_child() {
let values = Int32Array::from_iter_values([1, 2, 3]);
let values_sliced = values.slice(1, 2);
let offsets = Buffer::from_iter([1_i32, 3_i32]);

let list_field = Field::new("element", DataType::Int32, false);
let data_type = DataType::List(Box::new(list_field));

let data = unsafe {
ArrayData::new_unchecked(
data_type,
1,
None,
None,
0,
vec![offsets],
vec![values_sliced.data().clone()],
)
};

let err = data.validate_dictionary_offset().unwrap_err();
assert_eq!(err.to_string(), "Invalid argument error: Offset invariant failure: offset at position 1 out of bounds: 3 > 2");
}
}

0 comments on commit 9c40a87

Please sign in to comment.