diff --git a/arrow-array/src/array/binary_array.rs b/arrow-array/src/array/binary_array.rs index 851fb60c078..c8407b252ef 100644 --- a/arrow-array/src/array/binary_array.rs +++ b/arrow-array/src/array/binary_array.rs @@ -297,6 +297,8 @@ impl From for GenericBinaryArray>> for BooleanArray { impl From for BooleanArray { fn from(data: ArrayData) -> Self { + assert_eq!( + data.data_type(), + &DataType::Boolean, + "BooleanArray expected ArrayData with type {} got {}", + DataType::Boolean, + data.data_type() + ); assert_eq!( data.buffers().len(), 1, @@ -209,6 +216,8 @@ impl From for BooleanArray { let ptr = data.buffers()[0].as_ptr(); Self { data, + // SAFETY: + // ArrayData must be valid, and validated data type above raw_values: unsafe { RawPtrBox::new(ptr) }, } } @@ -414,4 +423,12 @@ mod tests { }; drop(BooleanArray::from(data)); } + + #[test] + #[should_panic( + expected = "BooleanArray expected ArrayData with type Boolean got Int32" + )] + fn test_from_array_data_validation() { + let _ = BooleanArray::from(ArrayData::new_empty(&DataType::Int32)); + } } diff --git a/arrow-array/src/array/decimal_array.rs b/arrow-array/src/array/decimal_array.rs index 34b424092e4..5ca9b0715cf 100644 --- a/arrow-array/src/array/decimal_array.rs +++ b/arrow-array/src/array/decimal_array.rs @@ -407,13 +407,21 @@ impl From for DecimalArray { "DecimalArray data should contain 1 buffer only (values)" ); let values = data.buffers()[0].as_ptr(); - let (precision, scale) = match (data.data_type(), Self::VALUE_LENGTH) { - (DataType::Decimal128(precision, scale), 16) - | (DataType::Decimal256(precision, scale), 32) => (*precision, *scale), - _ => panic!("Expected data type to be Decimal"), + let (precision, scale) = match (data.data_type(), Self::DEFAULT_TYPE) { + (DataType::Decimal128(precision, scale), DataType::Decimal128(_, _)) + | (DataType::Decimal256(precision, scale), DataType::Decimal256(_, _)) => { + (*precision, *scale) + } + _ => panic!( + "Expected data type to match {} got {}", + Self::DEFAULT_TYPE, + data.data_type() + ), }; Self { data, + // SAFETY: + // ArrayData must be valid, and verified data type above value_data: unsafe { RawPtrBox::new(values) }, precision, scale, @@ -977,4 +985,13 @@ mod tests { array.value(4); } + + #[test] + #[should_panic( + expected = "Expected data type to match Decimal256(76, 10) got Decimal128(38, 10)" + )] + fn test_from_array_data_validation() { + let array = Decimal128Array::from_iter_values(vec![-100, 0, 101].into_iter()); + let _ = Decimal256Array::from(array.into_data()); + } } diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index 96e91f729ab..002ee6f4782 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -408,10 +408,17 @@ impl From for DictionaryArray { ); if let DataType::Dictionary(key_data_type, _) = data.data_type() { - if key_data_type.as_ref() != &T::DATA_TYPE { - panic!("DictionaryArray's data type must match.") - }; + assert_eq!( + &T::DATA_TYPE, + key_data_type.as_ref(), + "DictionaryArray's data type must match, expected {} got {}", + T::DATA_TYPE, + key_data_type + ); + // create a zero-copy of the keys' data + // SAFETY: + // ArrayData is valid and verified type above let keys = PrimitiveArray::::from(unsafe { ArrayData::new_unchecked( T::DATA_TYPE, @@ -925,4 +932,13 @@ mod tests { let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect(); DictionaryArray::::try_new(&keys, &values).unwrap(); } + + #[test] + #[should_panic( + expected = "DictionaryArray's data type must match, expected Int64 got Int32" + )] + fn test_from_array_data_validation() { + let a = DictionaryArray::::from_iter(["32"]); + let _ = DictionaryArray::::from(a.into_data()); + } } diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs index 3022db023ab..cdc7531d99f 100644 --- a/arrow-array/src/array/list_array.rs +++ b/arrow-array/src/array/list_array.rs @@ -257,6 +257,8 @@ impl GenericListArray { false => data.buffers()[0].as_ptr(), }; + // SAFETY: + // Verified list type in call to `Self::get_type` let value_offsets = unsafe { RawPtrBox::new(offsets) }; Ok(Self { data, @@ -362,6 +364,7 @@ pub type LargeListArray = GenericListArray; #[cfg(test)] mod tests { use super::*; + use crate::builder::{Int32Builder, ListBuilder}; use crate::types::Int32Type; use crate::Int32Array; use arrow_buffer::{bit_util, Buffer, ToByteSlice}; @@ -820,6 +823,18 @@ mod tests { drop(ListArray::from(list_data)); } + #[test] + #[should_panic( + expected = "[Large]ListArray's datatype must be [Large]ListArray(). It is List" + )] + fn test_from_array_data_validation() { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.values().append_value(1); + builder.append(true); + let array = builder.finish(); + let _ = LargeListArray::from(array.into_data()); + } + #[test] fn test_list_array_offsets_need_not_start_at_zero() { let value_data = ArrayData::builder(DataType::Int32) diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs index bfe8d407274..0f3ae2e689a 100644 --- a/arrow-array/src/array/map_array.rs +++ b/arrow-array/src/array/map_array.rs @@ -109,6 +109,12 @@ impl From for ArrayData { impl MapArray { fn try_new_from_array_data(data: ArrayData) -> Result { + assert!( + matches!(data.data_type(), DataType::Map(_, _)), + "MapArray expected ArrayData with DataType::Map got {}", + data.data_type() + ); + if data.buffers().len() != 1 { return Err(ArrowError::InvalidArgumentError( format!("MapArray data should contain a single buffer only (value offsets), had {}", @@ -141,6 +147,8 @@ impl MapArray { let values = make_array(entries); let value_offsets = data.buffers()[0].as_ptr(); + // SAFETY: + // ArrayData is valid, and verified type above let value_offsets = unsafe { RawPtrBox::::new(value_offsets) }; unsafe { if (*value_offsets.as_ptr().offset(0)) != 0 { @@ -467,6 +475,21 @@ mod tests { map_array.value(map_array.len()); } + #[test] + #[should_panic( + expected = "MapArray expected ArrayData with DataType::Map got Dictionary" + )] + fn test_from_array_data_validation() { + // A DictionaryArray has similar buffer layout to a MapArray + // but the meaning of the values differs + let struct_t = DataType::Struct(vec![ + Field::new("keys", DataType::Int32, true), + Field::new("values", DataType::UInt32, true), + ]); + let dict_t = DataType::Dictionary(Box::new(DataType::Int32), Box::new(struct_t)); + let _ = MapArray::from(ArrayData::new_empty(&dict_t)); + } + #[test] fn test_new_from_strings() { let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 928135463cc..895c80b0753 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -818,6 +818,14 @@ impl PrimitiveArray { /// Constructs a `PrimitiveArray` from an array data reference. impl From for PrimitiveArray { fn from(data: ArrayData) -> Self { + // Use discriminant to allow for decimals + assert_eq!( + std::mem::discriminant(&T::DATA_TYPE), + std::mem::discriminant(data.data_type()), + "PrimitiveArray expected ArrayData with type {} got {}", + T::DATA_TYPE, + data.data_type() + ); assert_eq!( data.buffers().len(), 1, @@ -827,6 +835,8 @@ impl From for PrimitiveArray { let ptr = data.buffers()[0].as_ptr(); Self { data, + // SAFETY: + // ArrayData must be valid, and validated data type above raw_values: unsafe { RawPtrBox::new(ptr) }, } } @@ -1352,6 +1362,15 @@ mod tests { array.value(4); } + #[test] + #[should_panic( + expected = "PrimitiveArray expected ArrayData with type Int64 got Int32" + )] + fn test_from_array_data_validation() { + let foo = PrimitiveArray::::from_iter([1, 2, 3]); + let _ = PrimitiveArray::::from(foo.into_data()); + } + #[test] fn test_decimal128() { let values: Vec<_> = vec![0, 1, -1, i128::MIN, i128::MAX]; diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index b573c65d026..49a9b18d85f 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -1312,15 +1312,16 @@ pub fn cast_with_options( )), (Timestamp(from_unit, _), Timestamp(to_unit, to_tz)) => { - let time_array = Int64Array::from(array.data().clone()); + let array = cast_with_options(array, &Int64, cast_options)?; + let time_array = as_primitive_array::(array.as_ref()); let from_size = time_unit_multiple(from_unit); let to_size = time_unit_multiple(to_unit); // we either divide or multiply, depending on size of each unit // units are never the same when the types are the same let converted = if from_size >= to_size { - divide_scalar(&time_array, from_size / to_size)? + divide_scalar(time_array, from_size / to_size)? } else { - multiply_scalar(&time_array, to_size / from_size)? + multiply_scalar(time_array, to_size / from_size)? }; Ok(make_timestamp_array( &converted, @@ -1329,10 +1330,10 @@ pub fn cast_with_options( )) } (Timestamp(from_unit, _), Date32) => { - let time_array = Int64Array::from(array.data().clone()); + let array = cast_with_options(array, &Int64, cast_options)?; + let time_array = as_primitive_array::(array.as_ref()); let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY; - // Int32Array::from_iter(tim.iter) let mut b = Date32Builder::with_capacity(array.len()); for i in 0..array.len() { diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 1aa4473c044..b9cfae516f8 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -1398,7 +1398,7 @@ mod tests { fn test_take_bool_nullable_index() { // indices where the masked invalid elements would be out of bounds let index_data = ArrayData::try_new( - DataType::Int32, + DataType::UInt32, 6, Some(Buffer::from_iter(vec![ false, true, false, true, false, true, @@ -1421,7 +1421,7 @@ mod tests { fn test_take_bool_nullable_index_nonnull_values() { // indices where the masked invalid elements would be out of bounds let index_data = ArrayData::try_new( - DataType::Int32, + DataType::UInt32, 6, Some(Buffer::from_iter(vec![ false, true, false, true, false, true, diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs index d4f96e6a8d6..5fc5e639de9 100644 --- a/parquet/src/arrow/array_reader/primitive_array.rs +++ b/parquet/src/arrow/array_reader/primitive_array.rs @@ -26,7 +26,8 @@ use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; use arrow::array::{ ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, Decimal128Array, - Float32Array, Float64Array, Int32Array, Int64Array,TimestampNanosecondArray, TimestampNanosecondBufferBuilder, + Float32Array, Float64Array, Int32Array, Int64Array, TimestampNanosecondArray, + TimestampNanosecondBufferBuilder, UInt32Array, UInt64Array, }; use arrow::buffer::Buffer; use arrow::datatypes::{DataType as ArrowType, TimeUnit}; @@ -169,15 +170,21 @@ where .null_bit_buffer(self.record_reader.consume_bitmap_buffer()); let array_data = unsafe { array_data.build_unchecked() }; - let array = match T::get_physical_type() { - PhysicalType::BOOLEAN => Arc::new(BooleanArray::from(array_data)) as ArrayRef, - PhysicalType::INT32 => Arc::new(Int32Array::from(array_data)) as ArrayRef, - PhysicalType::INT64 => Arc::new(Int64Array::from(array_data)) as ArrayRef, - PhysicalType::FLOAT => Arc::new(Float32Array::from(array_data)) as ArrayRef, - PhysicalType::DOUBLE => Arc::new(Float64Array::from(array_data)) as ArrayRef, - PhysicalType::INT96 => { - Arc::new(TimestampNanosecondArray::from(array_data)) as ArrayRef - } + let array: ArrayRef = match T::get_physical_type() { + PhysicalType::BOOLEAN => Arc::new(BooleanArray::from(array_data)), + PhysicalType::INT32 => match array_data.data_type() { + ArrowType::UInt32 => Arc::new(UInt32Array::from(array_data)), + ArrowType::Int32 => Arc::new(Int32Array::from(array_data)), + _ => unreachable!(), + }, + PhysicalType::INT64 => match array_data.data_type() { + ArrowType::UInt64 => Arc::new(UInt64Array::from(array_data)), + ArrowType::Int64 => Arc::new(Int64Array::from(array_data)), + _ => unreachable!(), + }, + PhysicalType::FLOAT => Arc::new(Float32Array::from(array_data)), + PhysicalType::DOUBLE => Arc::new(Float64Array::from(array_data)), + PhysicalType::INT96 => Arc::new(TimestampNanosecondArray::from(array_data)), PhysicalType::BYTE_ARRAY | PhysicalType::FIXED_LEN_BYTE_ARRAY => { unreachable!( "PrimitiveArrayReaders don't support complex physical types"