Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate ArrayData type when converting to Array (#2834) #2835

Merged
merged 10 commits into from Oct 13, 2022
9 changes: 9 additions & 0 deletions arrow-array/src/array/binary_array.rs
Expand Up @@ -290,6 +290,8 @@ impl<OffsetSize: OffsetSizeTrait> From<ArrayData> for GenericBinaryArray<OffsetS
let values = data.buffers()[1].as_ptr();
Self {
data,
// SAFETY:
// ArrayData must be valid, and validated data type above
value_offsets: unsafe { RawPtrBox::new(offsets) },
value_data: unsafe { RawPtrBox::new(values) },
}
Expand Down Expand Up @@ -826,6 +828,13 @@ mod tests {
binary_array.value(4);
}

#[test]
#[should_panic(expected = "[Large]BinaryArray expects Datatype::[Large]Binary")]
fn test_binary_array_validation() {
    // Data produced by a BinaryArray must be rejected when converted to a LargeBinaryArray.
    let binary = BinaryArray::from_iter_values(&[&[1, 2]]);
    let binary_data = binary.into_data();
    let _ = LargeBinaryArray::from(binary_data);
}

#[test]
fn test_binary_array_all_null() {
let data = vec![None];
Expand Down
17 changes: 17 additions & 0 deletions arrow-array/src/array/boolean_array.rs
Expand Up @@ -201,6 +201,13 @@ impl From<Vec<Option<bool>>> for BooleanArray {

impl From<ArrayData> for BooleanArray {
fn from(data: ArrayData) -> Self {
assert_eq!(
data.data_type(),
&DataType::Boolean,
"BooleanArray expected ArrayData with type {} got {}",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

DataType::Boolean,
data.data_type()
);
assert_eq!(
data.buffers().len(),
1,
Expand All @@ -209,6 +216,8 @@ impl From<ArrayData> for BooleanArray {
let ptr = data.buffers()[0].as_ptr();
Self {
data,
// SAFETY:
// ArrayData must be valid, and validated data type above
raw_values: unsafe { RawPtrBox::new(ptr) },
}
}
Expand Down Expand Up @@ -414,4 +423,12 @@ mod tests {
};
drop(BooleanArray::from(data));
}

#[test]
#[should_panic(
    expected = "BooleanArray expected ArrayData with type Boolean got Int32"
)]
fn test_from_array_data_validation() {
    // Empty Int32 data must be rejected by the BooleanArray conversion.
    let int32_data = ArrayData::new_empty(&DataType::Int32);
    let _ = BooleanArray::from(int32_data);
}
}
25 changes: 21 additions & 4 deletions arrow-array/src/array/decimal_array.rs
Expand Up @@ -407,13 +407,21 @@ impl<T: DecimalType> From<ArrayData> for DecimalArray<T> {
"DecimalArray data should contain 1 buffer only (values)"
);
let values = data.buffers()[0].as_ptr();
let (precision, scale) = match (data.data_type(), Self::VALUE_LENGTH) {
(DataType::Decimal128(precision, scale), 16)
| (DataType::Decimal256(precision, scale), 32) => (*precision, *scale),
_ => panic!("Expected data type to be Decimal"),
let (precision, scale) = match (data.data_type(), Self::DEFAULT_TYPE) {
(DataType::Decimal128(precision, scale), DataType::Decimal128(_, _))
| (DataType::Decimal256(precision, scale), DataType::Decimal256(_, _)) => {
(*precision, *scale)
}
_ => panic!(
"Expected data type to match {} got {}",
Self::DEFAULT_TYPE,
data.data_type()
),
};
Self {
data,
// SAFETY:
// ArrayData must be valid, and verified data type above
value_data: unsafe { RawPtrBox::new(values) },
precision,
scale,
Expand Down Expand Up @@ -977,4 +985,13 @@ mod tests {

array.value(4);
}

#[test]
#[should_panic(
    expected = "Expected data type to match Decimal256(76, 10) got Decimal128(38, 10)"
)]
fn test_from_array_data_validation() {
    // Data produced by a Decimal128Array must be rejected when converted to a Decimal256Array.
    let decimal128 = Decimal128Array::from_iter_values(vec![-100, 0, 101].into_iter());
    let decimal128_data = decimal128.into_data();
    let _ = Decimal256Array::from(decimal128_data);
}
}
22 changes: 19 additions & 3 deletions arrow-array/src/array/dictionary_array.rs
Expand Up @@ -408,10 +408,17 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for DictionaryArray<T> {
);

if let DataType::Dictionary(key_data_type, _) = data.data_type() {
if key_data_type.as_ref() != &T::DATA_TYPE {
panic!("DictionaryArray's data type must match.")
};
assert_eq!(
&T::DATA_TYPE,
key_data_type.as_ref(),
"DictionaryArray's data type must match, expected {} got {}",
T::DATA_TYPE,
key_data_type
);

// create a zero-copy of the keys' data
// SAFETY:
// ArrayData is valid and verified type above
let keys = PrimitiveArray::<T>::from(unsafe {
ArrayData::new_unchecked(
T::DATA_TYPE,
Expand Down Expand Up @@ -925,4 +932,13 @@ mod tests {
let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect();
DictionaryArray::<Float32Type>::try_new(&keys, &values).unwrap();
}

#[test]
#[should_panic(
    expected = "DictionaryArray's data type must match, expected Int64 got Int32"
)]
fn test_from_array_data_validation() {
    // A dictionary keyed by Int32 must be rejected when converted to one keyed by Int64.
    let int32_keyed = DictionaryArray::<Int32Type>::from_iter(["32"]);
    let int32_keyed_data = int32_keyed.into_data();
    let _ = DictionaryArray::<Int64Type>::from(int32_keyed_data);
}
}
16 changes: 16 additions & 0 deletions arrow-array/src/array/list_array.rs
Expand Up @@ -241,6 +241,9 @@ impl<OffsetSize: OffsetSizeTrait> GenericListArray<OffsetSize> {

let values = make_array(values);
let value_offsets = data.buffers()[0].as_ptr();

// SAFETY:
// Verified list type in call to `Self::get_type`
let value_offsets = unsafe { RawPtrBox::<OffsetSize>::new(value_offsets) };
Ok(Self {
data,
Expand Down Expand Up @@ -346,6 +349,7 @@ pub type LargeListArray = GenericListArray<i64>;
#[cfg(test)]
mod tests {
use super::*;
use crate::builder::{Int32Builder, ListBuilder};
use crate::types::Int32Type;
use crate::Int32Array;
use arrow_buffer::{bit_util, Buffer, ToByteSlice};
Expand Down Expand Up @@ -804,6 +808,18 @@ mod tests {
drop(ListArray::from(list_data));
}

#[test]
#[should_panic(
    expected = "[Large]ListArray's datatype must be [Large]ListArray(). It is List"
)]
fn test_from_array_data_validation() {
    // Build a one-element list array, then try to convert its data to a LargeListArray.
    let mut list_builder = ListBuilder::new(Int32Builder::new());
    list_builder.values().append_value(1);
    list_builder.append(true);
    let list_data = list_builder.finish().into_data();
    let _ = LargeListArray::from(list_data);
}

#[test]
fn test_list_array_offsets_need_not_start_at_zero() {
let value_data = ArrayData::builder(DataType::Int32)
Expand Down
23 changes: 23 additions & 0 deletions arrow-array/src/array/map_array.rs
Expand Up @@ -109,6 +109,12 @@ impl From<MapArray> for ArrayData {

impl MapArray {
fn try_new_from_array_data(data: ArrayData) -> Result<Self, ArrowError> {
assert!(
matches!(data.data_type(), DataType::Map(_, _)),
"MapArray expected ArrayData with DataType::Map got {}",
data.data_type()
);

if data.buffers().len() != 1 {
return Err(ArrowError::InvalidArgumentError(
format!("MapArray data should contain a single buffer only (value offsets), had {}",
Expand Down Expand Up @@ -141,6 +147,8 @@ impl MapArray {
let values = make_array(entries);
let value_offsets = data.buffers()[0].as_ptr();

// SAFETY:
// ArrayData is valid, and verified type above
let value_offsets = unsafe { RawPtrBox::<i32>::new(value_offsets) };
unsafe {
if (*value_offsets.as_ptr().offset(0)) != 0 {
Expand Down Expand Up @@ -467,6 +475,21 @@ mod tests {
map_array.value(map_array.len());
}

#[test]
#[should_panic(
    expected = "MapArray expected ArrayData with DataType::Map got Dictionary"
)]
fn test_from_array_data_validation() {
    // A DictionaryArray has a similar buffer layout to a MapArray, but the meaning
    // of the values differs, so the conversion is expected to panic.
    let entry_fields = vec![
        Field::new("keys", DataType::Int32, true),
        Field::new("values", DataType::UInt32, true),
    ];
    let value_type = DataType::Struct(entry_fields);
    let dictionary_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type));
    let dictionary_data = ArrayData::new_empty(&dictionary_type);
    let _ = MapArray::from(dictionary_data);
}

#[test]
fn test_new_from_strings() {
let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"];
Expand Down
19 changes: 19 additions & 0 deletions arrow-array/src/array/primitive_array.rs
Expand Up @@ -809,6 +809,14 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
/// Constructs a `PrimitiveArray` from an array data reference.
impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
fn from(data: ArrayData) -> Self {
// Use discriminant to allow for decimals
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
std::mem::discriminant(data.data_type()),
"PrimitiveArray expected ArrayData with type {} got {}",
T::DATA_TYPE,
data.data_type()
);
assert_eq!(
data.buffers().len(),
1,
Expand All @@ -818,6 +826,8 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
let ptr = data.buffers()[0].as_ptr();
Self {
data,
// SAFETY:
// ArrayData must be valid, and validated data type above
raw_values: unsafe { RawPtrBox::new(ptr) },
}
}
Expand Down Expand Up @@ -1342,4 +1352,13 @@ mod tests {

array.value(4);
}

#[test]
#[should_panic(
    expected = "PrimitiveArray expected ArrayData with type Int64 got Int32"
)]
fn test_from_array_data_validation() {
    // Data produced by an Int32 primitive array must be rejected by the Int64 conversion.
    let int32_array = PrimitiveArray::<Int32Type>::from_iter([1, 2, 3]);
    let int32_data = int32_array.into_data();
    let _ = PrimitiveArray::<Int64Type>::from(int32_data);
}
}
23 changes: 11 additions & 12 deletions arrow/src/compute/kernels/cast.rs
Expand Up @@ -43,12 +43,12 @@ use std::str;
use std::sync::Arc;

use crate::buffer::MutableBuffer;
use crate::compute::divide_scalar;
use crate::compute::kernels::arithmetic::{divide, multiply};
use crate::compute::kernels::arity::unary;
use crate::compute::kernels::cast_utils::string_to_timestamp_nanos;
use crate::compute::kernels::temporal::extract_component_from_array;
use crate::compute::kernels::temporal::return_compute_error_with;
use crate::compute::{divide_scalar, multiply_scalar};
use crate::compute::{try_unary, using_chrono_tz_and_utc_naive_date_time};
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
Expand Down Expand Up @@ -1241,14 +1241,14 @@ pub fn cast_with_options(
}
//(Time32(TimeUnit::Second), Time64(_)) => {},
(Time32(from_unit), Time64(to_unit)) => {
let time_array = Int32Array::from(array.data().clone());
let array = cast_with_options(array, &Int32, cast_options)?;
let time_array = as_primitive_array::<Int32Type>(array.as_ref());
// note: (numeric_cast + SIMD multiply) is faster than (cast & multiply)
let c: Int64Array = numeric_cast(&time_array);
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
// from is only smaller than to if 64milli/64second don't exist
let mult = Int64Array::from(vec![to_size / from_size; array.len()]);
let converted = multiply(&c, &mult)?;
let converted = multiply_scalar(&c, to_size / from_size)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive by cleanup

let array_ref = Arc::new(converted) as ArrayRef;
use TimeUnit::*;
match to_unit {
Expand Down Expand Up @@ -1284,7 +1284,8 @@ pub fn cast_with_options(
Ok(Arc::new(values) as ArrayRef)
}
(Time64(from_unit), Time32(to_unit)) => {
let time_array = Int64Array::from(array.data().clone());
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = as_primitive_array::<Int64Type>(array.as_ref());
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
let divisor = from_size / to_size;
Expand Down Expand Up @@ -1321,18 +1322,16 @@ pub fn cast_with_options(
}
}
(Timestamp(from_unit, _), Timestamp(to_unit, _)) => {
let time_array = Int64Array::from(array.data().clone());
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = as_primitive_array::<Int64Type>(array.as_ref());
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
// we either divide or multiply, depending on size of each unit
// units are never the same when the types are the same
let converted = if from_size >= to_size {
divide_scalar(&time_array, from_size / to_size)?
} else {
multiply(
&time_array,
&Int64Array::from(vec![to_size / from_size; array.len()]),
)?
multiply_scalar(&time_array, to_size / from_size)?
};
let array_ref = Arc::new(converted) as ArrayRef;
use TimeUnit::*;
Expand All @@ -1355,10 +1354,10 @@ pub fn cast_with_options(
}
}
(Timestamp(from_unit, _), Date32) => {
let time_array = Int64Array::from(array.data().clone());
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = as_primitive_array::<Int64Type>(array.as_ref());
let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY;

// Int32Array::from_iter(tim.iter)
let mut b = Date32Builder::with_capacity(array.len());

for i in 0..array.len() {
Expand Down
4 changes: 2 additions & 2 deletions arrow/src/compute/kernels/take.rs
Expand Up @@ -1398,7 +1398,7 @@ mod tests {
fn test_take_bool_nullable_index() {
// indices where the masked invalid elements would be out of bounds
let index_data = ArrayData::try_new(
DataType::Int32,
DataType::UInt32,
6,
Some(Buffer::from_iter(vec![
false, true, false, true, false, true,
Expand All @@ -1421,7 +1421,7 @@ mod tests {
fn test_take_bool_nullable_index_nonnull_values() {
// indices where the masked invalid elements would be out of bounds
let index_data = ArrayData::try_new(
DataType::Int32,
DataType::UInt32,
6,
Some(Buffer::from_iter(vec![
false, true, false, true, false, true,
Expand Down