Skip to content

Commit

Permalink
Add sum_dyn to calculate sum for dictionary array (#2566)
Browse files Browse the repository at this point in the history
* Add sum_dyn

* Add null values test case
  • Loading branch information
viirya committed Aug 24, 2022
1 parent b34adcc commit 45fb919
Showing 1 changed file with 59 additions and 4 deletions.
63 changes: 59 additions & 4 deletions arrow/src/compute/kernels/aggregate.rs
Expand Up @@ -21,10 +21,10 @@ use multiversion::multiversion;
use std::ops::Add;

use crate::array::{
Array, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait,
PrimitiveArray,
as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
use crate::datatypes::{ArrowNativeType, ArrowNumericType};
use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};

/// Generic test for NaN; the optimizer should be able to remove this for integer types.
#[inline]
Expand Down Expand Up @@ -185,6 +185,37 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
}

/// Returns the sum of the values exposed by `array`.
///
/// Arrays whose data type is `DataType::Dictionary` are read element by element
/// through the accessor and their non-null values are added together, starting
/// from `T::default_value()`. Arrays of any other data type are viewed as a
/// primitive array and handed to `sum`.
///
/// Returns `None` if the array is empty or only contains null values.
pub fn sum_dyn<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
    T: ArrowNumericType,
    T::Native: Add<Output = T::Native>,
{
    // Everything that is not dictionary-encoded goes straight to the primitive kernel.
    if !matches!(array.data_type(), DataType::Dictionary(_, _)) {
        return sum::<T>(as_primitive_array(&array));
    }

    // No valid slots at all (this also covers a zero-length array): nothing to add up.
    if array.null_count() == array.len() {
        return None;
    }

    // `flatten` drops the null slots, leaving only the values to accumulate.
    let total = ArrayIter::new(array)
        .flatten()
        .fold(T::default_value(), |running, item| running + item);

    Some(total)
}

/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
#[cfg(not(feature = "simd"))]
Expand Down Expand Up @@ -583,7 +614,7 @@ mod simd {
}
}

/// Returns the sum of values in the array.
/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
#[cfg(feature = "simd")]
Expand Down Expand Up @@ -625,6 +656,7 @@ mod tests {
use super::*;
use crate::array::*;
use crate::compute::add;
use crate::datatypes::{Int32Type, Int8Type};

#[test]
fn test_primitive_array_sum() {
Expand Down Expand Up @@ -1003,4 +1035,27 @@ mod tests {
assert_eq!(Some(true), min_boolean(&a));
assert_eq!(Some(true), max_boolean(&a));
}

#[test]
fn test_sum_dyn() {
    // Dictionary values shared by every dictionary array built in this test.
    let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);

    // Keys without nulls select 12, 13 and 14.
    let dense_keys = Int8Array::from_iter_values([2_i8, 3, 4]);
    let dense_dict = DictionaryArray::try_new(&dense_keys, &values).unwrap();
    let dense_typed = dense_dict.downcast_dict::<Int8Array>().unwrap();
    assert_eq!(Some(39), sum_dyn::<Int8Type, _>(dense_typed));

    // A plain primitive array takes the non-dictionary path.
    let primitive = Int32Array::from(vec![1, 2, 3, 4, 5]);
    assert_eq!(Some(15), sum_dyn::<Int32Type, _>(&primitive));

    // A null key is skipped, so only 12 and 14 contribute.
    let sparse_keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]);
    let sparse_dict = DictionaryArray::try_new(&sparse_keys, &values).unwrap();
    let sparse_typed = sparse_dict.downcast_dict::<Int8Array>().unwrap();
    assert_eq!(Some(26), sum_dyn::<Int8Type, _>(sparse_typed));

    // When every key is null there is no sum to report.
    let null_keys = Int8Array::from(vec![None, None, None]);
    let null_dict = DictionaryArray::try_new(&null_keys, &values).unwrap();
    let null_typed = null_dict.downcast_dict::<Int8Array>().unwrap();
    assert_eq!(None, sum_dyn::<Int8Type, _>(null_typed));
}
}

0 comments on commit 45fb919

Please sign in to comment.