From 62053a801e0a8e6b22778314c19a37929e96b76a Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 6 Jul 2022 13:28:10 -0700
Subject: [PATCH] Support DictionaryArray in unary kernel (#1990)

* Init

* More

* Fix clippy

* Apply on dictionary values directly in unary_dict.

* Fix clippy

* Avoid validate when constructing new dictionary array
---
 arrow/src/compute/kernels/arity.rs | 181 ++++++++++++++++++++++++++++-
 1 file changed, 177 insertions(+), 4 deletions(-)

diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 60a0cb77fe2..5135218168f 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -17,9 +17,14 @@
 
 //! Defines kernels suitable to perform operations to primitive arrays.
 
-use crate::array::{Array, ArrayData, PrimitiveArray};
+use crate::array::{Array, ArrayData, ArrayRef, DictionaryArray, PrimitiveArray};
 use crate::buffer::Buffer;
-use crate::datatypes::ArrowPrimitiveType;
+use crate::datatypes::{
+    ArrowNumericType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type,
+    Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use crate::error::{ArrowError, Result};
+use std::sync::Arc;
 
 #[inline]
 fn into_primitive_array_data<I: ArrowPrimitiveType, O: ArrowPrimitiveType>(
@@ -78,10 +83,128 @@ where
     PrimitiveArray::<O>::from(data)
 }
 
+/// A helper function that applies an unary function to a dictionary array with primitive value type.
+#[allow(clippy::redundant_closure)]
+fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
+where
+    K: ArrowNumericType,
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> T::Native,
+{
+    let dict_values = array
+        .values()
+        .as_any()
+        .downcast_ref::<PrimitiveArray<T>>()
+        .unwrap();
+
+    let values = dict_values
+        .iter()
+        .map(|v| v.map(|value| op(value)))
+        .collect::<PrimitiveArray<T>>();
+
+    let keys = array.keys();
+
+    let mut data = ArrayData::builder(array.data_type().clone())
+        .len(keys.len())
+        .add_buffer(keys.data().buffers()[0].clone())
+        .add_child_data(values.data().clone());
+
+    match keys.data().null_buffer() {
+        Some(buffer) if keys.data().null_count() > 0 => {
+            data = data
+                .null_bit_buffer(Some(buffer.clone()))
+                .null_count(keys.data().null_count());
+        }
+        _ => data = data.null_count(0),
+    }
+
+    let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
+    Ok(Arc::new(new_dict))
+}
+
+/// Applies an unary function to an array with primitive values.
+pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> T::Native,
+{
+    match array.data_type() {
+        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+            DataType::Int8 => unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<Int8Type>>()
+                    .unwrap(),
+                op,
+            ),
+            DataType::Int16 => unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<Int16Type>>()
+                    .unwrap(),
+                op,
+            ),
+            DataType::Int32 => unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<Int32Type>>()
+                    .unwrap(),
+                op,
+            ),
+            DataType::Int64 => unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<Int64Type>>()
+                    .unwrap(),
+                op,
+            ),
+            DataType::UInt8 => unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<UInt8Type>>()
+                    .unwrap(),
+                op,
+            ),
+            DataType::UInt16 => unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<UInt16Type>>()
+                    .unwrap(),
+                op,
+            ),
+            DataType::UInt32 => unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<UInt32Type>>()
+                    .unwrap(),
+                op,
+            ),
+            DataType::UInt64 => unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<UInt64Type>>()
+                    .unwrap(),
+                op,
+            ),
+            t => Err(ArrowError::NotYetImplemented(format!(
+                "Cannot perform unary operation on dictionary array of key type {}.",
+                t
+            ))),
+        },
+        _ => Ok(Arc::new(unary::<T, F, T>(
+            array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
+            op,
+        ))),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::array::{as_primitive_array, Float64Array};
+    use crate::array::{
+        as_primitive_array, Float64Array, PrimitiveBuilder, PrimitiveDictionaryBuilder,
+    };
+    use crate::datatypes::{Float64Type, Int32Type, Int8Type};
 
     #[test]
     fn test_unary_f64_slice() {
@@ -93,6 +216,56 @@ mod tests {
         assert_eq!(
             result,
             Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
-        )
+        );
+
+        let result = unary_dyn::<_, Float64Type>(input_slice, |n| n + 1.0).unwrap();
+
+        assert_eq!(
+            result.as_any().downcast_ref::<Float64Array>().unwrap(),
+            &Float64Array::from(vec![None, Some(7.8), None, Some(8.2)])
+        );
+    }
+
+    #[test]
+    fn test_unary_dict_and_unary_dyn() {
+        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
+        let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
+        let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
+        builder.append(5).unwrap();
+        builder.append(6).unwrap();
+        builder.append(7).unwrap();
+        builder.append(8).unwrap();
+        builder.append_null().unwrap();
+        builder.append(9).unwrap();
+        let dictionary_array = builder.finish();
+
+        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
+        let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
+        let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
+        builder.append(6).unwrap();
+        builder.append(7).unwrap();
+        builder.append(8).unwrap();
+        builder.append(9).unwrap();
+        builder.append_null().unwrap();
+        builder.append(10).unwrap();
+        let expected = builder.finish();
+
+        let result = unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
+        assert_eq!(
+            result
+                .as_any()
+                .downcast_ref::<DictionaryArray<Int8Type>>()
+                .unwrap(),
+            &expected
+        );
+
+        let result = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
+        assert_eq!(
+            result
+                .as_any()
+                .downcast_ref::<DictionaryArray<Int8Type>>()
+                .unwrap(),
+            &expected
+        );
     }
 }