diff --git a/arrow/src/array/array_dictionary.rs b/arrow/src/array/array_dictionary.rs
index 4f7d5f9c147..2afc7a69e0b 100644
--- a/arrow/src/array/array_dictionary.rs
+++ b/arrow/src/array/array_dictionary.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::array::{ArrayAccessor, ArrayIter};
 use std::any::Any;
 use std::fmt;
 use std::iter::IntoIterator;
@@ -234,6 +235,28 @@ impl<'a, K: ArrowPrimitiveType> DictionaryArray<K> {
                 .expect("Dictionary index not usize")
         })
     }
+
+    /// Downcast this dictionary to a [`TypedDictionaryArray`]
+    ///
+    /// ```
+    /// use arrow::array::{Array, ArrayAccessor, DictionaryArray, StringArray};
+    /// use arrow::datatypes::Int32Type;
+    ///
+    /// let orig = [Some("a"), Some("b"), None];
+    /// let dictionary = DictionaryArray::<Int32Type>::from_iter(orig);
+    /// let typed = dictionary.downcast_dict::<StringArray>().unwrap();
+    /// assert_eq!(typed.value(0), "a");
+    /// assert_eq!(typed.value(1), "b");
+    /// assert!(typed.is_null(2));
+    /// ```
+    ///
+    pub fn downcast_dict<V: 'static>(&self) -> Option<TypedDictionaryArray<'_, K, V>> {
+        let values = self.values.as_any().downcast_ref()?;
+        Some(TypedDictionaryArray {
+            dictionary: self,
+            values,
+        })
+    }
 }
 
 /// Constructs a `DictionaryArray` from an array data reference.
@@ -302,9 +325,7 @@ impl<T: ArrowPrimitiveType> From<DictionaryArray<T>> for ArrayData {
 ///     format!("{:?}", array)
 /// );
 /// ```
-impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<Option<&'a str>>
-    for DictionaryArray<T>
-{
+impl<'a, T: ArrowDictionaryKeyType> FromIterator<Option<&'a str>> for DictionaryArray<T> {
     fn from_iter<I: IntoIterator<Item = Option<&'a str>>>(iter: I) -> Self {
         let it = iter.into_iter();
         let (lower, _) = it.size_hint();
@@ -342,9 +363,7 @@ impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<Option<&'a
 ///     format!("{:?}", array)
 /// );
 /// ```
-impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<&'a str>
-    for DictionaryArray<T>
-{
+impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray<T> {
     fn from_iter<I: IntoIterator<Item = &'a str>>(iter: I) -> Self {
         let it = iter.into_iter();
         let (lower, _) = it.size_hint();
@@ -385,6 +404,114 @@ impl<T: ArrowPrimitiveType> fmt::Debug for DictionaryArray<T> {
     }
 }
 
+/// A strongly-typed wrapper around a [`DictionaryArray`] that implements [`ArrayAccessor`]
+/// allowing fast access to its elements
+///
+/// ```
+/// use arrow::array::{ArrayIter, DictionaryArray, StringArray};
+/// use arrow::datatypes::Int32Type;
+///
+/// let orig = ["a", "b", "a", "b"];
+/// let dictionary = DictionaryArray::<Int32Type>::from_iter(orig);
+///
+/// // `TypedDictionaryArray` allows you to access the values directly
+/// let typed = dictionary.downcast_dict::<StringArray>().unwrap();
+///
+/// for (maybe_val, orig) in typed.into_iter().zip(orig) {
+///     assert_eq!(maybe_val.unwrap(), orig)
+/// }
+/// ```
+pub struct TypedDictionaryArray<'a, K: ArrowPrimitiveType, V> {
+    /// The dictionary array
+    dictionary: &'a DictionaryArray<K>,
+    /// The values of the dictionary
+    values: &'a V,
+}
+
+// `Clone` and `Copy` are written by hand: `#[derive]` would also demand
+// `K: Clone` / `V: Clone`, although only references to them are stored here
+impl<'a, K: ArrowPrimitiveType, V> Clone for TypedDictionaryArray<'a, K, V> {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+
+impl<'a, K: ArrowPrimitiveType, V> Copy for TypedDictionaryArray<'a, K, V> {}
+
+impl<'a, K: ArrowPrimitiveType, V> fmt::Debug for TypedDictionaryArray<'a, K, V> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        writeln!(f, "TypedDictionaryArray({:?})", self.dictionary)
+    }
+}
+
+impl<'a, K: ArrowPrimitiveType, V> TypedDictionaryArray<'a, K, V> {
+    /// Returns the keys of this [`TypedDictionaryArray`]
+    pub fn keys(&self) -> &'a PrimitiveArray<K> {
+        self.dictionary.keys()
+    }
+
+    /// Returns the values of this [`TypedDictionaryArray`]
+    pub fn values(&self) -> &'a V {
+        self.values
+    }
+}
+
+impl<'a, K: ArrowPrimitiveType, V: Sync> Array for TypedDictionaryArray<'a, K, V> {
+    fn as_any(&self) -> &dyn Any {
+        // Expose the wrapped dictionary rather than this borrowed view
+        self.dictionary
+    }
+
+    fn data(&self) -> &ArrayData {
+        &self.dictionary.data
+    }
+
+    fn into_data(self) -> ArrayData {
+        self.dictionary.into_data()
+    }
+}
+
+impl<'a, K, V> IntoIterator for TypedDictionaryArray<'a, K, V>
+where
+    K: ArrowPrimitiveType,
+    V: Sync + Send,
+    &'a V: ArrayAccessor,
+{
+    type Item = Option<<Self as ArrayAccessor>::Item>;
+    type IntoIter = ArrayIter<Self>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        ArrayIter::new(self)
+    }
+}
+
+impl<'a, K, V> ArrayAccessor for TypedDictionaryArray<'a, K, V>
+where
+    K: ArrowPrimitiveType,
+    V: Sync + Send,
+    &'a V: ArrayAccessor,
+{
+    type Item = <&'a V as ArrayAccessor>::Item;
+
+    fn value(&self, index: usize) -> Self::Item {
+        // Panics, with the index as the message, when the slot is null
+        assert!(self.dictionary.is_valid(index), "{}", index);
+        let value_idx = self.dictionary.keys.value(index).to_usize().unwrap();
+        // SAFETY: the slot is non-null (asserted above), and the key of a
+        // non-null slot is expected to be a valid index into `values`
+        unsafe { self.values.value_unchecked(value_idx) }
+    }
+
+    unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
+        let val = self.dictionary.keys.value_unchecked(index);
+        let value_idx = val.to_usize().unwrap();
+        // NOTE(review): neither `index` nor the decoded key is checked here;
+        // callers must pass an in-bounds `index` whose key is a valid index
+        // into `values` -- confirm this also holds for null slots
+        self.values.value_unchecked(value_idx)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 3785af85aff..4a766774159 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -208,7 +208,7 @@ pub use self::array_fixed_size_list::FixedSizeListArray;
 #[deprecated(note = "Please use `Decimal128Array` instead")]
 pub type DecimalArray = Decimal128Array;
 
-pub use self::array_dictionary::DictionaryArray;
+pub use self::array_dictionary::{DictionaryArray, TypedDictionaryArray};
 pub use self::array_list::LargeListArray;
 pub use self::array_list::ListArray;
 pub use self::array_map::MapArray;