diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index e89fe7b914a..3e7a8186292 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -114,9 +114,7 @@ where T: ArrowPrimitiveType, F: Fn(T::Native) -> Result, { - if std::mem::discriminant(&array.value_type()) - != std::mem::discriminant(&T::DATA_TYPE) - { + if !PrimitiveArray::::is_compatible(&array.value_type()) { return Err(ArrowError::CastError(format!( "Cannot perform the unary operation of type {} on dictionary array of value type {}", T::DATA_TYPE, @@ -138,7 +136,7 @@ where downcast_dictionary_array! { array => unary_dict::<_, F, T>(array, op), t => { - if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) { + if PrimitiveArray::::is_compatible(t) { Ok(Arc::new(unary::( array.as_any().downcast_ref::>().unwrap(), op, @@ -170,7 +168,7 @@ where ))) }, t => { - if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) { + if PrimitiveArray::::is_compatible(t) { Ok(Arc::new(try_unary::( array.as_any().downcast_ref::>().unwrap(), op, diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 4ff0ed4d93e..01eda724ba4 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -297,6 +297,21 @@ impl PrimitiveArray { PrimitiveBuilder::::with_capacity(capacity) } + /// Returns if this [`PrimitiveArray`] is compatible with the provided [`DataType`] + /// + /// This is equivalent to `data_type == T::DATA_TYPE`, however ignores timestamp + /// timezones and decimal precision and scale + pub fn is_compatible(data_type: &DataType) -> bool { + match T::DATA_TYPE { + DataType::Timestamp(t1, _) => { + matches!(data_type, DataType::Timestamp(t2, _) if &t1 == t2) + } + DataType::Decimal128(_, _) => matches!(data_type, DataType::Decimal128(_, _)), + DataType::Decimal256(_, _) => matches!(data_type, DataType::Decimal256(_, _)), + _ => T::DATA_TYPE.eq(data_type), + } + } + /// Returns the primitive value at index `i`. /// /// # Safety @@ -1042,10 +1057,8 @@ impl PrimitiveArray { /// Constructs a `PrimitiveArray` from an array data reference. impl From for PrimitiveArray { fn from(data: ArrayData) -> Self { - // Use discriminant to allow for decimals - assert_eq!( - std::mem::discriminant(&T::DATA_TYPE), - std::mem::discriminant(data.data_type()), + assert!( + Self::is_compatible(data.data_type()), "PrimitiveArray expected ArrayData with type {} got {}", T::DATA_TYPE, data.data_type() @@ -2205,4 +2218,13 @@ mod tests { let c = array.unary_mut(|x| x * 2 + 1).unwrap(); assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None])); } + + #[test] + #[should_panic( + expected = "PrimitiveArray expected ArrayData with type Interval(MonthDayNano) got Interval(DayTime)" + )] + fn test_invalid_interval_type() { + let array = IntervalDayTimeArray::from(vec![1, 2, 3]); + let _ = IntervalMonthDayNanoArray::from(array.into_data()); + } } diff --git a/arrow-row/src/dictionary.rs b/arrow-row/src/dictionary.rs index 0da6c68d168..e332e11316f 100644 --- a/arrow-row/src/dictionary.rs +++ b/arrow-row/src/dictionary.rs @@ -270,10 +270,7 @@ fn decode_primitive( where T::Native: FixedLengthEncoding, { - assert_eq!( - std::mem::discriminant(&T::DATA_TYPE), - std::mem::discriminant(&data_type), - ); + assert!(PrimitiveArray::::is_compatible(&data_type)); // SAFETY: // Validated data type above diff --git a/arrow-row/src/fixed.rs b/arrow-row/src/fixed.rs index 159eba9adf1..d4b82c2a398 100644 --- a/arrow-row/src/fixed.rs +++ b/arrow-row/src/fixed.rs @@ -343,10 +343,7 @@ pub fn decode_primitive( where T::Native: FixedLengthEncoding, { - assert_eq!( - std::mem::discriminant(&T::DATA_TYPE), - std::mem::discriminant(&data_type), - ); + assert!(PrimitiveArray::::is_compatible(&data_type)); // SAFETY: // Validated data type above unsafe { decode_fixed::(rows, data_type, options).into() }