diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs index fd57133bc84..2b68cbbe642 100644 --- a/arrow/src/array/cast.rs +++ b/arrow/src/array/cast.rs @@ -20,6 +20,213 @@ use crate::array::*; use crate::datatypes::*; +/// Downcast an [`Array`] to a [`PrimitiveArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow::downcast_primitive_array; +/// # use arrow::array::Array; +/// # use arrow::datatypes::DataType; +/// # use arrow::array::as_string_array; +/// +/// fn print_primitive(array: &dyn Array) { +/// downcast_primitive_array!( +/// array => { +/// for v in array { +/// println!("{:?}", v); +/// } +/// } +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +#[macro_export] +macro_rules! downcast_primitive_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + downcast_primitive_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + match $values.data_type() { + $crate::datatypes::DataType::Int8 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int8Type, + >($values); + $e + } + $crate::datatypes::DataType::Int16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int16Type, + >($values); + $e + } + $crate::datatypes::DataType::Int32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int32Type, + >($values); + $e + } + $crate::datatypes::DataType::Int64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int64Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt8 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt8Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt16Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt32Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt64Type, + >($values); + $e + } + $crate::datatypes::DataType::Float16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float16Type, + >($values); + $e + } + $crate::datatypes::DataType::Float32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float32Type, + >($values); + $e + } + $crate::datatypes::DataType::Float64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float64Type, + >($values); + $e + } + $crate::datatypes::DataType::Date32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Date32Type, + >($values); + $e + } + $crate::datatypes::DataType::Date64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Date64Type, + >($values); + $e + } + $crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Second) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time32SecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Millisecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time32MillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Microsecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time64MicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Nanosecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time64NanosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Second, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampSecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Millisecond, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampMillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Microsecond, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampMicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Nanosecond, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampNanosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::YearMonth) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalYearMonthType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::DayTime) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalDayTimeType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::MonthDayNano) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalMonthDayNanoType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Second) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationSecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Millisecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationMillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Microsecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationMicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Nanosecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationNanosecondType, + >($values); + $e + } + $($p => $fallback,)* + } + }; +} + /// Force downcast of an [`Array`], such as an [`ArrayRef`], to /// [`PrimitiveArray`], panic'ing on failure. /// @@ -53,6 +260,98 @@ where .expect("Unable to downcast to primitive array") } +/// Downcast an [`Array`] to a [`DictionaryArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow::downcast_dictionary_array; +/// # use arrow::array::{Array, StringArray}; +/// # use arrow::datatypes::DataType; +/// # use arrow::array::as_string_array; +/// +/// fn print_strings(array: &dyn Array) { +/// downcast_dictionary_array!( +/// array => match array.values().data_type() { +/// DataType::Utf8 => { +/// for v in array.downcast_dict::().unwrap() { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported dictionary value type {}", t), +/// }, +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +#[macro_export] +macro_rules! downcast_dictionary_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + downcast_dictionary_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + match $values.data_type() { + $crate::datatypes::DataType::Dictionary(k, _) => match k.as_ref() { + $crate::datatypes::DataType::Int8 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int8Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int16 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int16Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int32 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int32Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int64 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int64Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt8 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt8Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt16 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt16Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt32 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt32Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt64 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt64Type, + >($values); + $e + }, + k => unreachable!("unsupported dictionary key type: {}", k) + } + $($p => $fallback,)* + } + } +} + /// Force downcast of an [`Array`], such as an [`ArrayRef`] to /// [`DictionaryArray`], panic'ing on failure. /// diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 81be3a1d172..52664a17544 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -22,8 +22,6 @@ use std::sync::Arc; use num::Zero; -use TimeUnit::*; - use crate::array::*; use crate::buffer::{buffer_bin_and, Buffer, MutableBuffer}; use crate::datatypes::*; @@ -31,6 +29,7 @@ use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; use crate::util::bit_iterator::{BitIndexIterator, BitSliceIterator}; use crate::util::bit_util; +use crate::{downcast_dictionary_array, downcast_primitive_array}; /// If the filter selects more than this fraction of rows, use /// [`SlicesIterator`] to copy ranges of values. Otherwise iterate @@ -40,27 +39,6 @@ use crate::util::bit_util; /// const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; -macro_rules! downcast_filter { - ($type: ty, $values: expr, $filter: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a primitive array"); - - Ok(Arc::new(filter_primitive::<$type>(&values, $filter))) - }}; -} - -macro_rules! downcast_dict_filter { - ($type: ty, $values: expr, $filter: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a dictionary array"); - Ok(Arc::new(filter_dict::<$type>(values, $filter))) - }}; -} - /// An iterator of `(usize, usize)` each representing an interval /// `[start, end)` whose slots of a [BooleanArray] are true. Each /// interval corresponds to a contiguous region of memory to be @@ -358,92 +336,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result Ok(new_empty_array(values.data_type())), IterationStrategy::All => Ok(make_array(values.data().slice(0, predicate.count))), // actually filter - _ => match values.data_type() { + _ => downcast_primitive_array! { + values => Ok(Arc::new(filter_primitive(values, predicate))), DataType::Boolean => { let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(filter_boolean(values, predicate))) } - DataType::Int8 => { - downcast_filter!(Int8Type, values, predicate) - } - DataType::Int16 => { - downcast_filter!(Int16Type, values, predicate) - } - DataType::Int32 => { - downcast_filter!(Int32Type, values, predicate) - } - DataType::Int64 => { - downcast_filter!(Int64Type, values, predicate) - } - DataType::UInt8 => { - downcast_filter!(UInt8Type, values, predicate) - } - DataType::UInt16 => { - downcast_filter!(UInt16Type, values, predicate) - } - DataType::UInt32 => { - downcast_filter!(UInt32Type, values, predicate) - } - DataType::UInt64 => { - downcast_filter!(UInt64Type, values, predicate) - } - DataType::Float32 => { - downcast_filter!(Float32Type, values, predicate) - } - DataType::Float64 => { - downcast_filter!(Float64Type, values, predicate) - } - DataType::Date32 => { - downcast_filter!(Date32Type, values, predicate) - } - DataType::Date64 => { - downcast_filter!(Date64Type, values, predicate) - } - DataType::Time32(Second) => { - downcast_filter!(Time32SecondType, values, predicate) - } - DataType::Time32(Millisecond) => { - downcast_filter!(Time32MillisecondType, values, predicate) - } - DataType::Time64(Microsecond) => { - downcast_filter!(Time64MicrosecondType, values, predicate) - } - DataType::Time64(Nanosecond) => { - downcast_filter!(Time64NanosecondType, values, predicate) - } - DataType::Timestamp(Second, _) => { - downcast_filter!(TimestampSecondType, values, predicate) - } - DataType::Timestamp(Millisecond, _) => { - downcast_filter!(TimestampMillisecondType, values, predicate) - } - DataType::Timestamp(Microsecond, _) => { - downcast_filter!(TimestampMicrosecondType, values, predicate) - } - DataType::Timestamp(Nanosecond, _) => { - downcast_filter!(TimestampNanosecondType, values, predicate) - } - DataType::Interval(IntervalUnit::YearMonth) => { - downcast_filter!(IntervalYearMonthType, values, predicate) - } - DataType::Interval(IntervalUnit::DayTime) => { - downcast_filter!(IntervalDayTimeType, values, predicate) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - downcast_filter!(IntervalMonthDayNanoType, values, predicate) - } - DataType::Duration(TimeUnit::Second) => { - downcast_filter!(DurationSecondType, values, predicate) - } - DataType::Duration(TimeUnit::Millisecond) => { - downcast_filter!(DurationMillisecondType, values, predicate) - } - DataType::Duration(TimeUnit::Microsecond) => { - downcast_filter!(DurationMicrosecondType, values, predicate) - } - DataType::Duration(TimeUnit::Nanosecond) => { - downcast_filter!(DurationNanosecondType, values, predicate) - } DataType::Utf8 => { let values = values .as_any() @@ -458,19 +356,10 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result(values, predicate))) } - DataType::Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => downcast_dict_filter!(Int8Type, values, predicate), - DataType::Int16 => downcast_dict_filter!(Int16Type, values, predicate), - DataType::Int32 => downcast_dict_filter!(Int32Type, values, predicate), - DataType::Int64 => downcast_dict_filter!(Int64Type, values, predicate), - DataType::UInt8 => downcast_dict_filter!(UInt8Type, values, predicate), - DataType::UInt16 => downcast_dict_filter!(UInt16Type, values, predicate), - DataType::UInt32 => downcast_dict_filter!(UInt32Type, values, predicate), - DataType::UInt64 => downcast_dict_filter!(UInt64Type, values, predicate), - t => { - unimplemented!("Filter not supported for dictionary key type {:?}", t) - } - }, + DataType::Dictionary(_, _) => downcast_dictionary_array! { + values => Ok(Arc::new(filter_dict(values, predicate))), + t => unimplemented!("Filter not supported for dictionary type {:?}", t) + } _ => { // fallback to using MutableArrayData let mut mutable = MutableArrayData::new( diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 3272c84549f..19eb1b17ca2 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -26,30 +26,11 @@ use crate::compute::util::{ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -use crate::{array::*, buffer::buffer_bin_and}; +use crate::{ + array::*, buffer::buffer_bin_and, downcast_dictionary_array, downcast_primitive_array, +}; use num::{ToPrimitive, Zero}; -use TimeUnit::*; - -macro_rules! downcast_take { - ($type: ty, $values: expr, $indices: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a primitive array"); - Ok(Arc::new(take_primitive::<$type, _>(&values, $indices)?)) - }}; -} - -macro_rules! downcast_dict_take { - ($type: ty, $values: expr, $indices: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a dictionary array"); - Ok(Arc::new(take_dict::<$type, _>(values, $indices)?)) - }}; -} /// Take elements by index from [Array], creating a new [Array] from those indexes. /// @@ -141,7 +122,9 @@ where })? } } - match values.data_type() { + + downcast_primitive_array! { + values => Ok(Arc::new(take_primitive(values, indices)?)), DataType::Boolean => { let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(take_boolean(values, indices)?)) @@ -151,61 +134,6 @@ where values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(take_decimal128(decimal_values, indices)?)) } - DataType::Int8 => downcast_take!(Int8Type, values, indices), - DataType::Int16 => downcast_take!(Int16Type, values, indices), - DataType::Int32 => downcast_take!(Int32Type, values, indices), - DataType::Int64 => downcast_take!(Int64Type, values, indices), - DataType::UInt8 => downcast_take!(UInt8Type, values, indices), - DataType::UInt16 => downcast_take!(UInt16Type, values, indices), - DataType::UInt32 => downcast_take!(UInt32Type, values, indices), - DataType::UInt64 => downcast_take!(UInt64Type, values, indices), - DataType::Float32 => downcast_take!(Float32Type, values, indices), - DataType::Float64 => downcast_take!(Float64Type, values, indices), - DataType::Date32 => downcast_take!(Date32Type, values, indices), - DataType::Date64 => downcast_take!(Date64Type, values, indices), - DataType::Time32(Second) => downcast_take!(Time32SecondType, values, indices), - DataType::Time32(Millisecond) => { - downcast_take!(Time32MillisecondType, values, indices) - } - DataType::Time64(Microsecond) => { - downcast_take!(Time64MicrosecondType, values, indices) - } - DataType::Time64(Nanosecond) => { - downcast_take!(Time64NanosecondType, values, indices) - } - DataType::Timestamp(Second, _) => { - downcast_take!(TimestampSecondType, values, indices) - } - DataType::Timestamp(Millisecond, _) => { - downcast_take!(TimestampMillisecondType, values, indices) - } - DataType::Timestamp(Microsecond, _) => { - downcast_take!(TimestampMicrosecondType, values, indices) - } - DataType::Timestamp(Nanosecond, _) => { - downcast_take!(TimestampNanosecondType, values, indices) - } - DataType::Interval(IntervalUnit::YearMonth) => { - downcast_take!(IntervalYearMonthType, values, indices) - } - DataType::Interval(IntervalUnit::DayTime) => { - downcast_take!(IntervalDayTimeType, values, indices) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - downcast_take!(IntervalMonthDayNanoType, values, indices) - } - DataType::Duration(TimeUnit::Second) => { - downcast_take!(DurationSecondType, values, indices) - } - DataType::Duration(TimeUnit::Millisecond) => { - downcast_take!(DurationMillisecondType, values, indices) - } - DataType::Duration(TimeUnit::Microsecond) => { - downcast_take!(DurationMicrosecondType, values, indices) - } - DataType::Duration(TimeUnit::Nanosecond) => { - downcast_take!(DurationNanosecondType, values, indices) - } DataType::Utf8 => { let values = values .as_any() @@ -271,17 +199,10 @@ where Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef) } - DataType::Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => downcast_dict_take!(Int8Type, values, indices), - DataType::Int16 => downcast_dict_take!(Int16Type, values, indices), - DataType::Int32 => downcast_dict_take!(Int32Type, values, indices), - DataType::Int64 => downcast_dict_take!(Int64Type, values, indices), - DataType::UInt8 => downcast_dict_take!(UInt8Type, values, indices), - DataType::UInt16 => downcast_dict_take!(UInt16Type, values, indices), - DataType::UInt32 => downcast_dict_take!(UInt32Type, values, indices), - DataType::UInt64 => downcast_dict_take!(UInt64Type, values, indices), - t => unimplemented!("Take not supported for dictionary key type {:?}", t), - }, + DataType::Dictionary(_, _) => downcast_dictionary_array! { + values => Ok(Arc::new(take_dict(values, indices)?)), + t => unimplemented!("Take not supported for dictionary type {:?}", t) + } DataType::Binary => { let values = values .as_any() @@ -314,7 +235,7 @@ where Ok(new_null_array(&DataType::Null, indices.len())) } } - t => unimplemented!("Take not supported for data type {:?}", t), + t => unimplemented!("Take not supported for data type {:?}", t) } }