From 45fb919928e1903be0eb4de3af1966c56c6f6c71 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Aug 2022 14:42:08 -0700 Subject: [PATCH] Add sum_dyn to calculate sum for dictionary array (#2566) * Add sum_dyn * Add null values test case --- arrow/src/compute/kernels/aggregate.rs | 63 ++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index 12ead669f79..c8d0443c470 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -21,10 +21,10 @@ use multiversion::multiversion; use std::ops::Add; use crate::array::{ - Array, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, - PrimitiveArray, + as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray, + GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; -use crate::datatypes::{ArrowNativeType, ArrowNumericType}; +use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType}; /// Generic test for NaN, the optimizer should be able to remove this for integer types. #[inline] @@ -185,6 +185,37 @@ pub fn min_string(array: &GenericStringArray) -> Option<& } /// Returns the sum of values in the array. +pub fn sum_dyn>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: Add, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let iter = ArrayIter::new(array); + let sum = iter + .into_iter() + .fold(T::default_value(), |accumulator, value| { + if let Some(value) = value { + accumulator + value + } else { + accumulator + } + }); + + Some(sum) + } + _ => sum::(as_primitive_array(&array)), + } +} + +/// Returns the sum of values in the primitive array. /// /// Returns `None` if the array is empty or only contains null values. #[cfg(not(feature = "simd"))] @@ -583,7 +614,7 @@ mod simd { } } -/// Returns the sum of values in the array. +/// Returns the sum of values in the primitive array. /// /// Returns `None` if the array is empty or only contains null values. #[cfg(feature = "simd")] @@ -625,6 +656,7 @@ mod tests { use super::*; use crate::array::*; use crate::compute::add; + use crate::datatypes::{Int32Type, Int8Type}; #[test] fn test_primitive_array_sum() { @@ -1003,4 +1035,27 @@ mod tests { assert_eq!(Some(true), min_boolean(&a)); assert_eq!(Some(true), max_boolean(&a)); } + + #[test] + fn test_sum_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(39, sum_dyn::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(15, sum_dyn::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(26, sum_dyn::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert!(sum_dyn::(array).is_none()); + } }