From 0fdbfff6df1df963cf07fa2d57ec8ce8e1693daf Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Fri, 2 Sep 2022 12:08:47 +0100 Subject: [PATCH 1/3] Add downcast macros (#2635) --- arrow/src/array/cast.rs | 285 ++++++++++++++++++++++++++++ arrow/src/compute/kernels/filter.rs | 125 +----------- arrow/src/compute/kernels/take.rs | 101 ++-------- 3 files changed, 303 insertions(+), 208 deletions(-) diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs index fd57133bc84..bac97c23c9c 100644 --- a/arrow/src/array/cast.rs +++ b/arrow/src/array/cast.rs @@ -20,6 +20,207 @@ use crate::array::*; use crate::datatypes::*; +/// Downcast an [`Array`] to a [`PrimitiveArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow::downcast_primitive_array; +/// # use arrow::array::Array; +/// # use arrow::datatypes::DataType; +/// # use arrow::array::as_string_array; +/// +/// fn print_primitive(array: &dyn Array) { +/// downcast_primitive_array!( +/// array => { +/// for v in array { +/// println!("{:?}", v); +/// } +/// } +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +#[macro_export] +macro_rules! downcast_primitive_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr)*) => { + downcast_primitive_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr)*) => { + match $values.data_type() { + $crate::datatypes::DataType::Int8 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int8Type, + >($values); + $e + } + $crate::datatypes::DataType::Int16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int16Type, + >($values); + $e + } + $crate::datatypes::DataType::Int32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int32Type, + >($values); + $e + } + $crate::datatypes::DataType::Int64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int64Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt8 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt8Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt16Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt32Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt64Type, + >($values); + $e + } + $crate::datatypes::DataType::Float32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float32Type, + >($values); + $e + } + $crate::datatypes::DataType::Float64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float64Type, + >($values); + $e + } + $crate::datatypes::DataType::Date32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Date32Type, + >($values); + $e + } + $crate::datatypes::DataType::Date64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Date64Type, + >($values); + $e + } + $crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Second) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time32SecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Millisecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time32MillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Microsecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time64MicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Nanosecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time64NanosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Second, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampSecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Millisecond, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampMillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Microsecond, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampMicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Nanosecond, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampNanosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::YearMonth) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalYearMonthType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::DayTime) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalDayTimeType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::MonthDayNano) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalMonthDayNanoType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Second) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationSecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Millisecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationMillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Microsecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationMicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Nanosecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationNanosecondType, + >($values); + $e + } + $($p => $fallback,)* + } + }; +} + /// Force downcast of an [`Array`], such as an [`ArrayRef`], to /// [`PrimitiveArray`], panic'ing on failure. /// @@ -53,6 +254,90 @@ where .expect("Unable to downcast to primitive array") } +/// Downcast an [`Array`] to a [`DictionaryArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow::downcast_dict_array; +/// # use arrow::array::Array; +/// # use arrow::datatypes::DataType; +/// # use arrow::array::as_string_array; +/// +/// fn print_keys(array: &dyn Array) { +/// downcast_dict_array!( +/// array => { +/// for v in array.keys() { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +#[macro_export] +macro_rules! downcast_dict_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr)*) => { + downcast_dict_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr)*) => { + match $values.data_type() { + $crate::datatypes::DataType::Dictionary(k, _) => match k.as_ref() { + $crate::datatypes::DataType::Int8 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int8Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int16 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int16Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int32 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int32Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int64 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int64Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt8 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt8Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt16 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt16Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt32 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt32Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt64 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt64Type, + >($values); + $e + }, + k => unreachable!("unsupported dictionary key type: {}", k) + } + $($p => $fallback,)* + } + } +} + /// Force downcast of an [`Array`], such as an [`ArrayRef`] to /// [`DictionaryArray`], panic'ing on failure. /// diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 81be3a1d172..12d527bcde8 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -22,8 +22,6 @@ use std::sync::Arc; use num::Zero; -use TimeUnit::*; - use crate::array::*; use crate::buffer::{buffer_bin_and, Buffer, MutableBuffer}; use crate::datatypes::*; @@ -31,6 +29,7 @@ use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; use crate::util::bit_iterator::{BitIndexIterator, BitSliceIterator}; use crate::util::bit_util; +use crate::{downcast_dict_array, downcast_primitive_array}; /// If the filter selects more than this fraction of rows, use /// [`SlicesIterator`] to copy ranges of values. Otherwise iterate @@ -40,27 +39,6 @@ use crate::util::bit_util; /// const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; -macro_rules! downcast_filter { - ($type: ty, $values: expr, $filter: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a primitive array"); - - Ok(Arc::new(filter_primitive::<$type>(&values, $filter))) - }}; -} - -macro_rules! downcast_dict_filter { - ($type: ty, $values: expr, $filter: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a dictionary array"); - Ok(Arc::new(filter_dict::<$type>(values, $filter))) - }}; -} - /// An iterator of `(usize, usize)` each representing an interval /// `[start, end)` whose slots of a [BooleanArray] are true. Each /// interval corresponds to a contiguous region of memory to be @@ -358,92 +336,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result Ok(new_empty_array(values.data_type())), IterationStrategy::All => Ok(make_array(values.data().slice(0, predicate.count))), // actually filter - _ => match values.data_type() { + _ => downcast_primitive_array! { + values => Ok(Arc::new(filter_primitive(values, predicate))), DataType::Boolean => { let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(filter_boolean(values, predicate))) } - DataType::Int8 => { - downcast_filter!(Int8Type, values, predicate) - } - DataType::Int16 => { - downcast_filter!(Int16Type, values, predicate) - } - DataType::Int32 => { - downcast_filter!(Int32Type, values, predicate) - } - DataType::Int64 => { - downcast_filter!(Int64Type, values, predicate) - } - DataType::UInt8 => { - downcast_filter!(UInt8Type, values, predicate) - } - DataType::UInt16 => { - downcast_filter!(UInt16Type, values, predicate) - } - DataType::UInt32 => { - downcast_filter!(UInt32Type, values, predicate) - } - DataType::UInt64 => { - downcast_filter!(UInt64Type, values, predicate) - } - DataType::Float32 => { - downcast_filter!(Float32Type, values, predicate) - } - DataType::Float64 => { - downcast_filter!(Float64Type, values, predicate) - } - DataType::Date32 => { - downcast_filter!(Date32Type, values, predicate) - } - DataType::Date64 => { - downcast_filter!(Date64Type, values, predicate) - } - DataType::Time32(Second) => { - downcast_filter!(Time32SecondType, values, predicate) - } - DataType::Time32(Millisecond) => { - downcast_filter!(Time32MillisecondType, values, predicate) - } - DataType::Time64(Microsecond) => { - downcast_filter!(Time64MicrosecondType, values, predicate) - } - DataType::Time64(Nanosecond) => { - downcast_filter!(Time64NanosecondType, values, predicate) - } - DataType::Timestamp(Second, _) => { - downcast_filter!(TimestampSecondType, values, predicate) - } - DataType::Timestamp(Millisecond, _) => { - downcast_filter!(TimestampMillisecondType, values, predicate) - } - DataType::Timestamp(Microsecond, _) => { - downcast_filter!(TimestampMicrosecondType, values, predicate) - } - DataType::Timestamp(Nanosecond, _) => { - downcast_filter!(TimestampNanosecondType, values, predicate) - } - DataType::Interval(IntervalUnit::YearMonth) => { - downcast_filter!(IntervalYearMonthType, values, predicate) - } - DataType::Interval(IntervalUnit::DayTime) => { - downcast_filter!(IntervalDayTimeType, values, predicate) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - downcast_filter!(IntervalMonthDayNanoType, values, predicate) - } - DataType::Duration(TimeUnit::Second) => { - downcast_filter!(DurationSecondType, values, predicate) - } - DataType::Duration(TimeUnit::Millisecond) => { - downcast_filter!(DurationMillisecondType, values, predicate) - } - DataType::Duration(TimeUnit::Microsecond) => { - downcast_filter!(DurationMicrosecondType, values, predicate) - } - DataType::Duration(TimeUnit::Nanosecond) => { - downcast_filter!(DurationNanosecondType, values, predicate) - } DataType::Utf8 => { let values = values .as_any() @@ -458,19 +356,10 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result(values, predicate))) } - DataType::Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => downcast_dict_filter!(Int8Type, values, predicate), - DataType::Int16 => downcast_dict_filter!(Int16Type, values, predicate), - DataType::Int32 => downcast_dict_filter!(Int32Type, values, predicate), - DataType::Int64 => downcast_dict_filter!(Int64Type, values, predicate), - DataType::UInt8 => downcast_dict_filter!(UInt8Type, values, predicate), - DataType::UInt16 => downcast_dict_filter!(UInt16Type, values, predicate), - DataType::UInt32 => downcast_dict_filter!(UInt32Type, values, predicate), - DataType::UInt64 => downcast_dict_filter!(UInt64Type, values, predicate), - t => { - unimplemented!("Filter not supported for dictionary key type {:?}", t) - } - }, + DataType::Dictionary(_, _) => downcast_dict_array! { + values => Ok(Arc::new(filter_dict(values, predicate))), + t => unimplemented!("Filter not supported for dictionary type {:?}", t) + } _ => { // fallback to using MutableArrayData let mut mutable = MutableArrayData::new( diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 3272c84549f..4a2f7a8bb88 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -26,30 +26,11 @@ use crate::compute::util::{ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -use crate::{array::*, buffer::buffer_bin_and}; +use crate::{ + array::*, buffer::buffer_bin_and, downcast_dict_array, downcast_primitive_array, +}; use num::{ToPrimitive, Zero}; -use TimeUnit::*; - -macro_rules! downcast_take { - ($type: ty, $values: expr, $indices: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a primitive array"); - Ok(Arc::new(take_primitive::<$type, _>(&values, $indices)?)) - }}; -} - -macro_rules! downcast_dict_take { - ($type: ty, $values: expr, $indices: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a dictionary array"); - Ok(Arc::new(take_dict::<$type, _>(values, $indices)?)) - }}; -} /// Take elements by index from [Array], creating a new [Array] from those indexes. /// @@ -141,7 +122,9 @@ where })? } } - match values.data_type() { + + downcast_primitive_array! { + values => Ok(Arc::new(take_primitive(values, indices)?)), DataType::Boolean => { let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(take_boolean(values, indices)?)) @@ -151,61 +134,6 @@ where values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(take_decimal128(decimal_values, indices)?)) } - DataType::Int8 => downcast_take!(Int8Type, values, indices), - DataType::Int16 => downcast_take!(Int16Type, values, indices), - DataType::Int32 => downcast_take!(Int32Type, values, indices), - DataType::Int64 => downcast_take!(Int64Type, values, indices), - DataType::UInt8 => downcast_take!(UInt8Type, values, indices), - DataType::UInt16 => downcast_take!(UInt16Type, values, indices), - DataType::UInt32 => downcast_take!(UInt32Type, values, indices), - DataType::UInt64 => downcast_take!(UInt64Type, values, indices), - DataType::Float32 => downcast_take!(Float32Type, values, indices), - DataType::Float64 => downcast_take!(Float64Type, values, indices), - DataType::Date32 => downcast_take!(Date32Type, values, indices), - DataType::Date64 => downcast_take!(Date64Type, values, indices), - DataType::Time32(Second) => downcast_take!(Time32SecondType, values, indices), - DataType::Time32(Millisecond) => { - downcast_take!(Time32MillisecondType, values, indices) - } - DataType::Time64(Microsecond) => { - downcast_take!(Time64MicrosecondType, values, indices) - } - DataType::Time64(Nanosecond) => { - downcast_take!(Time64NanosecondType, values, indices) - } - DataType::Timestamp(Second, _) => { - downcast_take!(TimestampSecondType, values, indices) - } - DataType::Timestamp(Millisecond, _) => { - downcast_take!(TimestampMillisecondType, values, indices) - } - DataType::Timestamp(Microsecond, _) => { - downcast_take!(TimestampMicrosecondType, values, indices) - } - DataType::Timestamp(Nanosecond, _) => { - downcast_take!(TimestampNanosecondType, values, indices) - } - DataType::Interval(IntervalUnit::YearMonth) => { - downcast_take!(IntervalYearMonthType, values, indices) - } - DataType::Interval(IntervalUnit::DayTime) => { - downcast_take!(IntervalDayTimeType, values, indices) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - downcast_take!(IntervalMonthDayNanoType, values, indices) - } - DataType::Duration(TimeUnit::Second) => { - downcast_take!(DurationSecondType, values, indices) - } - DataType::Duration(TimeUnit::Millisecond) => { - downcast_take!(DurationMillisecondType, values, indices) - } - DataType::Duration(TimeUnit::Microsecond) => { - downcast_take!(DurationMicrosecondType, values, indices) - } - DataType::Duration(TimeUnit::Nanosecond) => { - downcast_take!(DurationNanosecondType, values, indices) - } DataType::Utf8 => { let values = values .as_any() @@ -271,17 +199,10 @@ where Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef) } - DataType::Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => downcast_dict_take!(Int8Type, values, indices), - DataType::Int16 => downcast_dict_take!(Int16Type, values, indices), - DataType::Int32 => downcast_dict_take!(Int32Type, values, indices), - DataType::Int64 => downcast_dict_take!(Int64Type, values, indices), - DataType::UInt8 => downcast_dict_take!(UInt8Type, values, indices), - DataType::UInt16 => downcast_dict_take!(UInt16Type, values, indices), - DataType::UInt32 => downcast_dict_take!(UInt32Type, values, indices), - DataType::UInt64 => downcast_dict_take!(UInt64Type, values, indices), - t => unimplemented!("Take not supported for dictionary key type {:?}", t), - }, + DataType::Dictionary(_, _) => downcast_dict_array! { + values => Ok(Arc::new(take_dict(values, indices)?)), + t => unimplemented!("Take not supported for dictionary type {:?}", t) + } DataType::Binary => { let values = values .as_any() @@ -314,7 +235,7 @@ where Ok(new_null_array(&DataType::Null, indices.len())) } } - t => unimplemented!("Take not supported for data type {:?}", t), + t => unimplemented!("Take not supported for data type {:?}", t) } } From 370d712d0b7304765976951f9de31732b6a08b44 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Fri, 2 Sep 2022 14:35:08 +0100 Subject: [PATCH 2/3] Add Float16 support and trailing commas --- arrow/src/array/cast.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs index bac97c23c9c..05b4a683a5c 100644 --- a/arrow/src/array/cast.rs +++ b/arrow/src/array/cast.rs @@ -48,11 +48,11 @@ use crate::datatypes::*; /// #[macro_export] macro_rules! downcast_primitive_array { - ($values:ident => $e:expr, $($p:pat => $fallback:expr)*) => { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { downcast_primitive_array!($values => {$e} $($p => $fallback)*) }; - ($values:ident => $e:block $($p:pat => $fallback:expr)*) => { + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { match $values.data_type() { $crate::datatypes::DataType::Int8 => { let $values = $crate::array::as_primitive_array::< @@ -102,6 +102,12 @@ macro_rules! downcast_primitive_array { >($values); $e } + $crate::datatypes::DataType::Float16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float16Type, + >($values); + $e + } $crate::datatypes::DataType::Float32 => { let $values = $crate::array::as_primitive_array::< $crate::datatypes::Float32Type, @@ -276,11 +282,11 @@ where /// ``` #[macro_export] macro_rules! downcast_dict_array { - ($values:ident => $e:expr, $($p:pat => $fallback:expr)*) => { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { downcast_dict_array!($values => {$e} $($p => $fallback)*) }; - ($values:ident => $e:block $($p:pat => $fallback:expr)*) => { + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { match $values.data_type() { $crate::datatypes::DataType::Dictionary(k, _) => match k.as_ref() { $crate::datatypes::DataType::Int8 => { From b6992d88d01d915704d321346fe1ee9ff6b1fbdb Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Fri, 2 Sep 2022 15:04:01 +0100 Subject: [PATCH 3/3] Review feedback --- arrow/src/array/cast.rs | 24 ++++++++++++++++-------- arrow/src/compute/kernels/filter.rs | 4 ++-- arrow/src/compute/kernels/take.rs | 4 ++-- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs index 05b4a683a5c..2b68cbbe642 100644 --- a/arrow/src/array/cast.rs +++ b/arrow/src/array/cast.rs @@ -264,15 +264,23 @@ where /// a number of subsequent patterns to match the data type /// /// ``` -/// # use arrow::downcast_dict_array; -/// # use arrow::array::Array; +/// # use arrow::downcast_dictionary_array; +/// # use arrow::array::{Array, StringArray}; /// # use arrow::datatypes::DataType; /// # use arrow::array::as_string_array; /// -/// fn print_keys(array: &dyn Array) { -/// downcast_dict_array!( -/// array => { -/// for v in array.keys() { +/// fn print_strings(array: &dyn Array) { +/// downcast_dictionary_array!( +/// array => match array.values().data_type() { +/// DataType::Utf8 => { +/// for v in array.downcast_dict::().unwrap() { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported dictionary value type {}", t), +/// }, +/// DataType::Utf8 => { +/// for v in as_string_array(array) { /// println!("{:?}", v); /// } /// } @@ -281,9 +289,9 @@ where /// } /// ``` #[macro_export] -macro_rules! downcast_dict_array { +macro_rules! downcast_dictionary_array { ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { - downcast_dict_array!($values => {$e} $($p => $fallback)*) + downcast_dictionary_array!($values => {$e} $($p => $fallback)*) }; ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 12d527bcde8..52664a17544 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -29,7 +29,7 @@ use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; use crate::util::bit_iterator::{BitIndexIterator, BitSliceIterator}; use crate::util::bit_util; -use crate::{downcast_dict_array, downcast_primitive_array}; +use crate::{downcast_dictionary_array, downcast_primitive_array}; /// If the filter selects more than this fraction of rows, use /// [`SlicesIterator`] to copy ranges of values. Otherwise iterate @@ -356,7 +356,7 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result(values, predicate))) } - DataType::Dictionary(_, _) => downcast_dict_array! { + DataType::Dictionary(_, _) => downcast_dictionary_array! { values => Ok(Arc::new(filter_dict(values, predicate))), t => unimplemented!("Filter not supported for dictionary type {:?}", t) } diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 4a2f7a8bb88..19eb1b17ca2 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -27,7 +27,7 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::util::bit_util; use crate::{ - array::*, buffer::buffer_bin_and, downcast_dict_array, downcast_primitive_array, + array::*, buffer::buffer_bin_and, downcast_dictionary_array, downcast_primitive_array, }; use num::{ToPrimitive, Zero}; @@ -199,7 +199,7 @@ where Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef) } - DataType::Dictionary(_, _) => downcast_dict_array! { + DataType::Dictionary(_, _) => downcast_dictionary_array! { values => Ok(Arc::new(take_dict(values, indices)?)), t => unimplemented!("Take not supported for dictionary type {:?}", t) }