Skip to content

Commit

Permalink
Add DictionaryArray::with_values (apache#2797)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Sep 28, 2022
1 parent 7639f28 commit f85ed97
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
54 changes: 54 additions & 0 deletions arrow-array/src/array/dictionary_array.rs
Expand Up @@ -344,6 +344,60 @@ impl<K: ArrowPrimitiveType> DictionaryArray<K> {
values,
})
}

/// Returns a new dictionary with the same keys as the current instance
/// but with a different set of dictionary values
///
/// This can be used to perform an operation on the values of a dictionary
///
/// # Panics
///
/// Panics if `values` has a length less than the current values
///
/// ```
/// use arrow_array::builder::PrimitiveDictionaryBuilder;
/// use arrow_array::{Int8Array, Int64Array, ArrayAccessor};
/// use arrow_array::types::{Int32Type, Int8Type};
///
/// // Construct a Dict(Int32, Int8)
/// let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int8Type>::with_capacity(2, 200);
/// for i in 0..100 {
/// builder.append(i % 2).unwrap();
/// }
///
/// let dictionary = builder.finish();
///
/// // Perform a widening cast of dictionary values
/// let typed_dictionary = dictionary.downcast_dict::<Int8Array>().unwrap();
/// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64);
///
/// // Create a Dict(Int32,
/// let new = dictionary.with_values(&values);
///
/// // Verify values are as expected
/// let new_typed = new.downcast_dict::<Int64Array>().unwrap();
/// for i in 0..100 {
/// assert_eq!(new_typed.value(i), (i % 2) as i64)
/// }
/// ```
///
pub fn with_values(&self, values: &dyn Array) -> Self {
assert!(values.len() >= self.values.len());

let builder = self
.data
.clone()
.into_builder()
.data_type(DataType::Dictionary(
Box::new(K::DATA_TYPE),
Box::new(values.data_type().clone()),
))
.child_data(vec![values.data().clone()]);

// SAFETY:
// Offsets were valid before and verified length is greater than or equal
Self::from(unsafe { builder.build_unchecked() })
}
}

/// Constructs a `DictionaryArray` from an array data reference.
Expand Down
14 changes: 4 additions & 10 deletions arrow/src/compute/kernels/arity.rs
Expand Up @@ -76,11 +76,8 @@ where
F: Fn(T::Native) -> T::Native,
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = unary::<T, F, T>(dict_values, op).into_data();
let data = array.data().clone().into_builder().child_data(vec![values]);

let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
let values = unary::<T, F, T>(dict_values, op);
Ok(Arc::new(array.with_values(&values)))
}

/// A helper function that applies a fallible unary function to a dictionary array with primitive value type.
Expand All @@ -98,11 +95,8 @@ where
}

let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
let data = array.data().clone().into_builder().child_data(vec![values]);

let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
let values = try_unary::<T, F, T>(dict_values, op)?;
Ok(Arc::new(array.with_values(&values)))
}

/// Applies an infallible unary function to an array with primitive values.
Expand Down

0 comments on commit f85ed97

Please sign in to comment.