From f85ed977cd6c33a7616986f086e4b18e2ad0fbe4 Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Wed, 28 Sep 2022 12:20:12 +0100
Subject: [PATCH] Add DictionaryArray::with_values (#2797)

---
 arrow-array/src/array/dictionary_array.rs | 54 +++++++++++++++++++++++
 arrow/src/compute/kernels/arity.rs        | 14 ++----
 2 files changed, 58 insertions(+), 10 deletions(-)

diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs
index 35d243fde9a..557ab65c40a 100644
--- a/arrow-array/src/array/dictionary_array.rs
+++ b/arrow-array/src/array/dictionary_array.rs
@@ -344,6 +344,60 @@ impl<K: ArrowPrimitiveType> DictionaryArray<K> {
             values,
         })
     }
+
+    /// Returns a new dictionary with the same keys as the current instance
+    /// but with a different set of dictionary values
+    ///
+    /// This can be used to perform an operation on the values of a dictionary
+    ///
+    /// # Panics
+    ///
+    /// Panics if `values` has a length less than the current values
+    ///
+    /// ```
+    /// use arrow_array::builder::PrimitiveDictionaryBuilder;
+    /// use arrow_array::{Int8Array, Int64Array, ArrayAccessor};
+    /// use arrow_array::types::{Int32Type, Int8Type};
+    ///
+    /// // Construct a Dict(Int32, Int8)
+    /// let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int8Type>::with_capacity(2, 200);
+    /// for i in 0..100 {
+    ///     builder.append(i % 2).unwrap();
+    /// }
+    ///
+    /// let dictionary = builder.finish();
+    ///
+    /// // Perform a widening cast of dictionary values
+    /// let typed_dictionary = dictionary.downcast_dict::<Int8Array>().unwrap();
+    /// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64);
+    ///
+    /// // Create a Dict(Int32, Int64)
+    /// let new = dictionary.with_values(&values);
+    ///
+    /// // Verify values are as expected
+    /// let new_typed = new.downcast_dict::<Int64Array>().unwrap();
+    /// for i in 0..100 {
+    ///     assert_eq!(new_typed.value(i), (i % 2) as i64)
+    /// }
+    /// ```
+    ///
+    pub fn with_values(&self, values: &dyn Array) -> Self {
+        assert!(values.len() >= self.values.len());
+
+        let builder = self
+            .data
+            .clone()
+            .into_builder()
+            .data_type(DataType::Dictionary(
+                Box::new(K::DATA_TYPE),
+                Box::new(values.data_type().clone()),
+            ))
+            .child_data(vec![values.data().clone()]);
+
+        // SAFETY:
+        // Offsets were valid before and verified length is greater than or equal
+        Self::from(unsafe { builder.build_unchecked() })
+    }
 }
 
 /// Constructs a `DictionaryArray` from an array data reference.
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index cb5184c0e9d..11ae5a204c5 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -76,11 +76,8 @@ where
     F: Fn(T::Native) -> T::Native,
 {
     let dict_values = array.values().as_any().downcast_ref().unwrap();
-    let values = unary::<T, F, T>(dict_values, op).into_data();
-    let data = array.data().clone().into_builder().child_data(vec![values]);
-
-    let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
-    Ok(Arc::new(new_dict))
+    let values = unary::<T, F, T>(dict_values, op);
+    Ok(Arc::new(array.with_values(&values)))
 }
 
 /// A helper function that applies a fallible unary function to a dictionary array with primitive value type.
@@ -98,11 +95,8 @@ where
     }
 
     let dict_values = array.values().as_any().downcast_ref().unwrap();
-    let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
-    let data = array.data().clone().into_builder().child_data(vec![values]);
-
-    let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
-    Ok(Arc::new(new_dict))
+    let values = try_unary::<T, F, T>(dict_values, op)?;
+    Ok(Arc::new(array.with_values(&values)))
 }
 
 /// Applies an infallible unary function to an array with primitive values.