Skip to content

Commit

Permalink
Add typed dictionary (#2136) (#2297)
Browse files Browse the repository at this point in the history
* Add typed dictionary (#2136)

* Review feedback
  • Loading branch information
tustvold committed Aug 5, 2022
1 parent b6eaf22 commit 4a3919b
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 7 deletions.
125 changes: 119 additions & 6 deletions arrow/src/array/array_dictionary.rs
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::array::{ArrayAccessor, ArrayIter};
use std::any::Any;
use std::fmt;
use std::iter::IntoIterator;
Expand Down Expand Up @@ -234,6 +235,28 @@ impl<K: ArrowPrimitiveType> DictionaryArray<K> {
.expect("Dictionary index not usize")
})
}

/// Downcast this dictionary to a [`TypedDictionaryArray`]
///
/// Returns `None` if the dictionary values cannot be downcast to `V`
///
/// ```
/// use arrow::array::{Array, ArrayAccessor, DictionaryArray, StringArray};
/// use arrow::datatypes::Int32Type;
///
/// let orig = [Some("a"), Some("b"), None];
/// let dictionary = DictionaryArray::<Int32Type>::from_iter(orig);
/// let typed = dictionary.downcast_dict::<StringArray>().unwrap();
/// assert_eq!(typed.value(0), "a");
/// assert_eq!(typed.value(1), "b");
/// assert!(typed.is_null(2));
/// ```
pub fn downcast_dict<V: 'static>(&self) -> Option<TypedDictionaryArray<'_, K, V>> {
    // Pair the successfully downcast values with a borrow of this dictionary
    self.values
        .as_any()
        .downcast_ref::<V>()
        .map(|values| TypedDictionaryArray {
            dictionary: self,
            values,
        })
}
}

/// Constructs a `DictionaryArray` from an array data reference.
Expand Down Expand Up @@ -302,9 +325,7 @@ impl<T: ArrowPrimitiveType> From<DictionaryArray<T>> for ArrayData {
/// format!("{:?}", array)
/// );
/// ```
impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<Option<&'a str>>
for DictionaryArray<T>
{
impl<'a, T: ArrowDictionaryKeyType> FromIterator<Option<&'a str>> for DictionaryArray<T> {
fn from_iter<I: IntoIterator<Item = Option<&'a str>>>(iter: I) -> Self {
let it = iter.into_iter();
let (lower, _) = it.size_hint();
Expand Down Expand Up @@ -342,9 +363,7 @@ impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<Option<&'a
/// format!("{:?}", array)
/// );
/// ```
impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<&'a str>
for DictionaryArray<T>
{
impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray<T> {
fn from_iter<I: IntoIterator<Item = &'a str>>(iter: I) -> Self {
let it = iter.into_iter();
let (lower, _) = it.size_hint();
Expand Down Expand Up @@ -385,6 +404,100 @@ impl<T: ArrowPrimitiveType> fmt::Debug for DictionaryArray<T> {
}
}

/// A strongly-typed wrapper around a [`DictionaryArray`] that implements [`ArrayAccessor`]
/// allowing fast access to its elements
///
/// The wrapper stores only two shared references, so it is `Copy` and `Clone`
/// for every key type `K` and values type `V`.
///
/// ```
/// use arrow::array::{ArrayIter, DictionaryArray, StringArray};
/// use arrow::datatypes::Int32Type;
///
/// let orig = ["a", "b", "a", "b"];
/// let dictionary = DictionaryArray::<Int32Type>::from_iter(orig);
///
/// // `TypedDictionaryArray` allows you to access the values directly
/// let typed = dictionary.downcast_dict::<StringArray>().unwrap();
///
/// for (maybe_val, orig) in typed.into_iter().zip(orig) {
///     assert_eq!(maybe_val.unwrap(), orig)
/// }
/// ```
pub struct TypedDictionaryArray<'a, K: ArrowPrimitiveType, V> {
    /// The dictionary array
    dictionary: &'a DictionaryArray<K>,
    /// The values of the dictionary
    values: &'a V,
}

// `Clone` and `Copy` are implemented by hand: `#[derive(Copy, Clone)]` would add
// `K: Copy + Clone` and `V: Copy + Clone` bounds to the generated impls, even
// though only references to `K`- and `V`-typed data are stored.
impl<'a, K: ArrowPrimitiveType, V> Clone for TypedDictionaryArray<'a, K, V> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, K: ArrowPrimitiveType, V> Copy for TypedDictionaryArray<'a, K, V> {}

/// Debug output wraps the debug representation of the underlying dictionary
/// in `TypedDictionaryArray(..)` and terminates it with a newline
impl<'a, K: ArrowPrimitiveType, V> fmt::Debug for TypedDictionaryArray<'a, K, V> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let dictionary = self.dictionary;
        writeln!(f, "TypedDictionaryArray({:?})", dictionary)
    }
}

impl<'a, K: ArrowPrimitiveType, V> TypedDictionaryArray<'a, K, V> {
    /// Returns the keys of the wrapped [`DictionaryArray`]
    ///
    /// The returned reference borrows from the dictionary, not from this wrapper
    pub fn keys(&self) -> &'a PrimitiveArray<K> {
        let Self { dictionary, .. } = *self;
        dictionary.keys()
    }

    /// Returns the typed values of this [`TypedDictionaryArray`]
    ///
    /// The returned reference borrows from the dictionary, not from this wrapper
    pub fn values(&self) -> &'a V {
        let Self { values, .. } = *self;
        values
    }
}

impl<'a, K: ArrowPrimitiveType, V: Sync> Array for TypedDictionaryArray<'a, K, V> {
fn as_any(&self) -> &dyn Any {
self.dictionary
}

fn data(&self) -> &ArrayData {
&self.dictionary.data
}

fn into_data(self) -> ArrayData {
self.dictionary.into_data()
}
}

/// Iterating a [`TypedDictionaryArray`] by value yields `Option`s of the values' item type
/// through an [`ArrayIter`] over the wrapper itself
impl<'a, K: ArrowPrimitiveType, V: Sync + Send> IntoIterator for TypedDictionaryArray<'a, K, V>
where
    &'a V: ArrayAccessor,
{
    type Item = Option<<Self as ArrayAccessor>::Item>;
    type IntoIter = ArrayIter<Self>;

    fn into_iter(self) -> Self::IntoIter {
        ArrayIter::<Self>::new(self)
    }
}

impl<'a, K, V> ArrayAccessor for TypedDictionaryArray<'a, K, V>
where
    K: ArrowPrimitiveType,
    V: Sync + Send,
    &'a V: ArrayAccessor,
{
    type Item = <&'a V as ArrayAccessor>::Item;

    /// Returns the dictionary value referenced by the key at `index`
    ///
    /// # Panics
    ///
    /// Panics with `index` as the message if the slot at `index` is not valid,
    /// and panics if the key cannot be converted to `usize`
    fn value(&self, index: usize) -> Self::Item {
        assert!(self.dictionary.is_valid(index), "{}", index);
        let key = self.dictionary.keys.value(index);
        let value_idx = key.to_usize().unwrap();
        // SAFETY: relies on the keys of a dictionary being valid indexes into its values
        unsafe { self.values.value_unchecked(value_idx) }
    }

    /// Returns the dictionary value referenced by the key at `index`
    /// without checking the validity of the slot
    ///
    /// # Safety
    ///
    /// Both the key lookup and the value lookup are unchecked; the key read at
    /// `index` is assumed to be a valid index into the values
    unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
        let value_idx = self
            .dictionary
            .keys
            .value_unchecked(index)
            .to_usize()
            .unwrap();
        // Dictionary indexes should be valid
        self.values.value_unchecked(value_idx)
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/array/mod.rs
Expand Up @@ -208,7 +208,7 @@ pub use self::array_fixed_size_list::FixedSizeListArray;
#[deprecated(note = "Please use `Decimal128Array` instead")]
pub type DecimalArray = Decimal128Array;

pub use self::array_dictionary::DictionaryArray;
pub use self::array_dictionary::{DictionaryArray, TypedDictionaryArray};
pub use self::array_list::LargeListArray;
pub use self::array_list::ListArray;
pub use self::array_map::MapArray;
Expand Down

0 comments on commit 4a3919b

Please sign in to comment.