Skip to content

Commit

Permalink
Validate ArrayData type when converting to Array (#2834) (#2835)
Browse files Browse the repository at this point in the history
* Validate ArrayData type when converting to Array (#2834)

* Fix cast kernel and take kernel tests

* Clippy

* Fix parquet

* Clippy
  • Loading branch information
tustvold committed Oct 13, 2022
1 parent 1397fb4 commit 8adebca
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 24 deletions.
9 changes: 9 additions & 0 deletions arrow-array/src/array/binary_array.rs
Expand Up @@ -297,6 +297,8 @@ impl<OffsetSize: OffsetSizeTrait> From<ArrayData> for GenericBinaryArray<OffsetS
let values = data.buffers()[1].as_ptr();
Self {
data,
// SAFETY:
// ArrayData must be valid, and validated data type above
value_offsets: unsafe { RawPtrBox::new(offsets) },
value_data: unsafe { RawPtrBox::new(values) },
}
Expand Down Expand Up @@ -833,6 +835,13 @@ mod tests {
binary_array.value(4);
}

/// Converting the `ArrayData` of a `BinaryArray` into a `LargeBinaryArray`
/// must panic with the data type mismatch message.
#[test]
#[should_panic(expected = "[Large]BinaryArray expects Datatype::[Large]Binary")]
fn test_binary_array_validation() {
    let binary = BinaryArray::from_iter_values(&[&[1, 2]]);
    let binary_data = binary.into_data();
    // The conversion itself is expected to panic; the result is never used.
    drop(LargeBinaryArray::from(binary_data));
}

#[test]
fn test_binary_array_all_null() {
let data = vec![None];
Expand Down
17 changes: 17 additions & 0 deletions arrow-array/src/array/boolean_array.rs
Expand Up @@ -201,6 +201,13 @@ impl From<Vec<Option<bool>>> for BooleanArray {

impl From<ArrayData> for BooleanArray {
fn from(data: ArrayData) -> Self {
assert_eq!(
data.data_type(),
&DataType::Boolean,
"BooleanArray expected ArrayData with type {} got {}",
DataType::Boolean,
data.data_type()
);
assert_eq!(
data.buffers().len(),
1,
Expand All @@ -209,6 +216,8 @@ impl From<ArrayData> for BooleanArray {
let ptr = data.buffers()[0].as_ptr();
Self {
data,
// SAFETY:
// ArrayData must be valid, and validated data type above
raw_values: unsafe { RawPtrBox::new(ptr) },
}
}
Expand Down Expand Up @@ -414,4 +423,12 @@ mod tests {
};
drop(BooleanArray::from(data));
}

/// Building a `BooleanArray` from `ArrayData` typed as `Int32` must panic
/// with a message naming both the expected and the actual type.
#[test]
#[should_panic(
    expected = "BooleanArray expected ArrayData with type Boolean got Int32"
)]
fn test_from_array_data_validation() {
    let int32_data = ArrayData::new_empty(&DataType::Int32);
    // The conversion itself is expected to panic; the result is never used.
    drop(BooleanArray::from(int32_data));
}
}
25 changes: 21 additions & 4 deletions arrow-array/src/array/decimal_array.rs
Expand Up @@ -407,13 +407,21 @@ impl<T: DecimalType> From<ArrayData> for DecimalArray<T> {
"DecimalArray data should contain 1 buffer only (values)"
);
let values = data.buffers()[0].as_ptr();
let (precision, scale) = match (data.data_type(), Self::VALUE_LENGTH) {
(DataType::Decimal128(precision, scale), 16)
| (DataType::Decimal256(precision, scale), 32) => (*precision, *scale),
_ => panic!("Expected data type to be Decimal"),
let (precision, scale) = match (data.data_type(), Self::DEFAULT_TYPE) {
(DataType::Decimal128(precision, scale), DataType::Decimal128(_, _))
| (DataType::Decimal256(precision, scale), DataType::Decimal256(_, _)) => {
(*precision, *scale)
}
_ => panic!(
"Expected data type to match {} got {}",
Self::DEFAULT_TYPE,
data.data_type()
),
};
Self {
data,
// SAFETY:
// ArrayData must be valid, and verified data type above
value_data: unsafe { RawPtrBox::new(values) },
precision,
scale,
Expand Down Expand Up @@ -977,4 +985,13 @@ mod tests {

array.value(4);
}

/// Feeding `Decimal128` data into `Decimal256Array::from` must panic with a
/// message reporting the expected and the actual decimal type.
#[test]
#[should_panic(
    expected = "Expected data type to match Decimal256(76, 10) got Decimal128(38, 10)"
)]
fn test_from_array_data_validation() {
    let decimal128 = Decimal128Array::from_iter_values(vec![-100, 0, 101].into_iter());
    let decimal128_data = decimal128.into_data();
    // The conversion itself is expected to panic; the result is never used.
    drop(Decimal256Array::from(decimal128_data));
}
}
22 changes: 19 additions & 3 deletions arrow-array/src/array/dictionary_array.rs
Expand Up @@ -408,10 +408,17 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for DictionaryArray<T> {
);

if let DataType::Dictionary(key_data_type, _) = data.data_type() {
if key_data_type.as_ref() != &T::DATA_TYPE {
panic!("DictionaryArray's data type must match.")
};
assert_eq!(
&T::DATA_TYPE,
key_data_type.as_ref(),
"DictionaryArray's data type must match, expected {} got {}",
T::DATA_TYPE,
key_data_type
);

// create a zero-copy of the keys' data
// SAFETY:
// ArrayData is valid and verified type above
let keys = PrimitiveArray::<T>::from(unsafe {
ArrayData::new_unchecked(
T::DATA_TYPE,
Expand Down Expand Up @@ -925,4 +932,13 @@ mod tests {
let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect();
DictionaryArray::<Float32Type>::try_new(&keys, &values).unwrap();
}

/// A dictionary keyed by `Int32` must not be accepted as a dictionary keyed
/// by `Int64`; the conversion panics and reports both key types.
#[test]
#[should_panic(
    expected = "DictionaryArray's data type must match, expected Int64 got Int32"
)]
fn test_from_array_data_validation() {
    let int32_keyed = DictionaryArray::<Int32Type>::from_iter(["32"]);
    let int32_keyed_data = int32_keyed.into_data();
    // The conversion itself is expected to panic; the result is never used.
    drop(DictionaryArray::<Int64Type>::from(int32_keyed_data));
}
}
15 changes: 15 additions & 0 deletions arrow-array/src/array/list_array.rs
Expand Up @@ -257,6 +257,8 @@ impl<OffsetSize: OffsetSizeTrait> GenericListArray<OffsetSize> {
false => data.buffers()[0].as_ptr(),
};

// SAFETY:
// Verified list type in call to `Self::get_type`
let value_offsets = unsafe { RawPtrBox::new(offsets) };
Ok(Self {
data,
Expand Down Expand Up @@ -362,6 +364,7 @@ pub type LargeListArray = GenericListArray<i64>;
#[cfg(test)]
mod tests {
use super::*;
use crate::builder::{Int32Builder, ListBuilder};
use crate::types::Int32Type;
use crate::Int32Array;
use arrow_buffer::{bit_util, Buffer, ToByteSlice};
Expand Down Expand Up @@ -820,6 +823,18 @@ mod tests {
drop(ListArray::from(list_data));
}

/// Data produced by a `ListBuilder` must be rejected when converted into a
/// `LargeListArray`; the conversion panics with the datatype message.
#[test]
#[should_panic(
    expected = "[Large]ListArray's datatype must be [Large]ListArray(). It is List"
)]
fn test_from_array_data_validation() {
    // Build a list array holding the single list `[1]`.
    let mut list_builder = ListBuilder::new(Int32Builder::new());
    list_builder.values().append_value(1);
    list_builder.append(true);
    let list_data = list_builder.finish().into_data();
    // The conversion itself is expected to panic; the result is never used.
    drop(LargeListArray::from(list_data));
}

#[test]
fn test_list_array_offsets_need_not_start_at_zero() {
let value_data = ArrayData::builder(DataType::Int32)
Expand Down
23 changes: 23 additions & 0 deletions arrow-array/src/array/map_array.rs
Expand Up @@ -109,6 +109,12 @@ impl From<MapArray> for ArrayData {

impl MapArray {
fn try_new_from_array_data(data: ArrayData) -> Result<Self, ArrowError> {
assert!(
matches!(data.data_type(), DataType::Map(_, _)),
"MapArray expected ArrayData with DataType::Map got {}",
data.data_type()
);

if data.buffers().len() != 1 {
return Err(ArrowError::InvalidArgumentError(
format!("MapArray data should contain a single buffer only (value offsets), had {}",
Expand Down Expand Up @@ -141,6 +147,8 @@ impl MapArray {
let values = make_array(entries);
let value_offsets = data.buffers()[0].as_ptr();

// SAFETY:
// ArrayData is valid, and verified type above
let value_offsets = unsafe { RawPtrBox::<i32>::new(value_offsets) };
unsafe {
if (*value_offsets.as_ptr().offset(0)) != 0 {
Expand Down Expand Up @@ -467,6 +475,21 @@ mod tests {
map_array.value(map_array.len());
}

/// `MapArray::from` must panic when handed `ArrayData` whose type is a
/// `Dictionary` rather than a `Map`.
#[test]
#[should_panic(
    expected = "MapArray expected ArrayData with DataType::Map got Dictionary"
)]
fn test_from_array_data_validation() {
    // A DictionaryArray has a buffer layout similar to a MapArray,
    // but the meaning of its values differs.
    let entry_fields = vec![
        Field::new("keys", DataType::Int32, true),
        Field::new("values", DataType::UInt32, true),
    ];
    let dictionary_type = DataType::Dictionary(
        Box::new(DataType::Int32),
        Box::new(DataType::Struct(entry_fields)),
    );
    let dictionary_data = ArrayData::new_empty(&dictionary_type);
    // The conversion itself is expected to panic; the result is never used.
    drop(MapArray::from(dictionary_data));
}

#[test]
fn test_new_from_strings() {
let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"];
Expand Down
19 changes: 19 additions & 0 deletions arrow-array/src/array/primitive_array.rs
Expand Up @@ -818,6 +818,14 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
/// Constructs a `PrimitiveArray` from an array data reference.
impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
fn from(data: ArrayData) -> Self {
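// Compare only the enum discriminant so that parameterised types such as
// decimals, whose precision and scale may differ from T::DATA_TYPE, still match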
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
std::mem::discriminant(data.data_type()),
"PrimitiveArray expected ArrayData with type {} got {}",
T::DATA_TYPE,
data.data_type()
);
assert_eq!(
data.buffers().len(),
1,
Expand All @@ -827,6 +835,8 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
let ptr = data.buffers()[0].as_ptr();
Self {
data,
// SAFETY:
// ArrayData must be valid, and validated data type above
raw_values: unsafe { RawPtrBox::new(ptr) },
}
}
Expand Down Expand Up @@ -1352,6 +1362,15 @@ mod tests {
array.value(4);
}

/// `Int32` data must not be accepted by `PrimitiveArray::<Int64Type>::from`;
/// the conversion panics and reports the expected and the actual type.
#[test]
#[should_panic(
    expected = "PrimitiveArray expected ArrayData with type Int64 got Int32"
)]
fn test_from_array_data_validation() {
    let int32_values = PrimitiveArray::<Int32Type>::from_iter([1, 2, 3]);
    let int32_data = int32_values.into_data();
    // The conversion itself is expected to panic; the result is never used.
    drop(PrimitiveArray::<Int64Type>::from(int32_data));
}

#[test]
fn test_decimal128() {
let values: Vec<_> = vec![0, 1, -1, i128::MIN, i128::MAX];
Expand Down
11 changes: 6 additions & 5 deletions arrow/src/compute/kernels/cast.rs
Expand Up @@ -1312,15 +1312,16 @@ pub fn cast_with_options(
)),

(Timestamp(from_unit, _), Timestamp(to_unit, to_tz)) => {
let time_array = Int64Array::from(array.data().clone());
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = as_primitive_array::<Int64Type>(array.as_ref());
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
// we either divide or multiply, depending on size of each unit
// units are never the same when the types are the same
let converted = if from_size >= to_size {
divide_scalar(&time_array, from_size / to_size)?
divide_scalar(time_array, from_size / to_size)?
} else {
multiply_scalar(&time_array, to_size / from_size)?
multiply_scalar(time_array, to_size / from_size)?
};
Ok(make_timestamp_array(
&converted,
Expand All @@ -1329,10 +1330,10 @@ pub fn cast_with_options(
))
}
(Timestamp(from_unit, _), Date32) => {
let time_array = Int64Array::from(array.data().clone());
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = as_primitive_array::<Int64Type>(array.as_ref());
let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY;

// Int32Array::from_iter(tim.iter)
let mut b = Date32Builder::with_capacity(array.len());

for i in 0..array.len() {
Expand Down
4 changes: 2 additions & 2 deletions arrow/src/compute/kernels/take.rs
Expand Up @@ -1398,7 +1398,7 @@ mod tests {
fn test_take_bool_nullable_index() {
// indices where the masked invalid elements would be out of bounds
let index_data = ArrayData::try_new(
DataType::Int32,
DataType::UInt32,
6,
Some(Buffer::from_iter(vec![
false, true, false, true, false, true,
Expand All @@ -1421,7 +1421,7 @@ mod tests {
fn test_take_bool_nullable_index_nonnull_values() {
// indices where the masked invalid elements would be out of bounds
let index_data = ArrayData::try_new(
DataType::Int32,
DataType::UInt32,
6,
Some(Buffer::from_iter(vec![
false, true, false, true, false, true,
Expand Down
27 changes: 17 additions & 10 deletions parquet/src/arrow/array_reader/primitive_array.rs
Expand Up @@ -26,7 +26,8 @@ use crate::errors::{ParquetError, Result};
use crate::schema::types::ColumnDescPtr;
use arrow::array::{
ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, Decimal128Array,
Float32Array, Float64Array, Int32Array, Int64Array,TimestampNanosecondArray, TimestampNanosecondBufferBuilder,
Float32Array, Float64Array, Int32Array, Int64Array, TimestampNanosecondArray,
TimestampNanosecondBufferBuilder, UInt32Array, UInt64Array,
};
use arrow::buffer::Buffer;
use arrow::datatypes::{DataType as ArrowType, TimeUnit};
Expand Down Expand Up @@ -169,15 +170,21 @@ where
.null_bit_buffer(self.record_reader.consume_bitmap_buffer());

let array_data = unsafe { array_data.build_unchecked() };
let array = match T::get_physical_type() {
PhysicalType::BOOLEAN => Arc::new(BooleanArray::from(array_data)) as ArrayRef,
PhysicalType::INT32 => Arc::new(Int32Array::from(array_data)) as ArrayRef,
PhysicalType::INT64 => Arc::new(Int64Array::from(array_data)) as ArrayRef,
PhysicalType::FLOAT => Arc::new(Float32Array::from(array_data)) as ArrayRef,
PhysicalType::DOUBLE => Arc::new(Float64Array::from(array_data)) as ArrayRef,
PhysicalType::INT96 => {
Arc::new(TimestampNanosecondArray::from(array_data)) as ArrayRef
}
let array: ArrayRef = match T::get_physical_type() {
PhysicalType::BOOLEAN => Arc::new(BooleanArray::from(array_data)),
PhysicalType::INT32 => match array_data.data_type() {
ArrowType::UInt32 => Arc::new(UInt32Array::from(array_data)),
ArrowType::Int32 => Arc::new(Int32Array::from(array_data)),
_ => unreachable!(),
},
PhysicalType::INT64 => match array_data.data_type() {
ArrowType::UInt64 => Arc::new(UInt64Array::from(array_data)),
ArrowType::Int64 => Arc::new(Int64Array::from(array_data)),
_ => unreachable!(),
},
PhysicalType::FLOAT => Arc::new(Float32Array::from(array_data)),
PhysicalType::DOUBLE => Arc::new(Float64Array::from(array_data)),
PhysicalType::INT96 => Arc::new(TimestampNanosecondArray::from(array_data)),
PhysicalType::BYTE_ARRAY | PhysicalType::FIXED_LEN_BYTE_ARRAY => {
unreachable!(
"PrimitiveArrayReaders don't support complex physical types"
Expand Down

0 comments on commit 8adebca

Please sign in to comment.