Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sum_dyn to calculate sum for dictionary array #2566

Merged
merged 2 commits into from Aug 24, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
53 changes: 49 additions & 4 deletions arrow/src/compute/kernels/aggregate.rs
Expand Up @@ -21,10 +21,10 @@ use multiversion::multiversion;
use std::ops::Add;

use crate::array::{
Array, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait,
PrimitiveArray,
as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
use crate::datatypes::{ArrowNativeType, ArrowNumericType};
use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};

/// Generic test for NaN, the optimizer should be able to remove this for integer types.
#[inline]
Expand Down Expand Up @@ -185,6 +185,37 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
}

/// Computes the sum of the non-null values exposed by an [`ArrayAccessor`].
///
/// Dictionary-encoded input is summed by walking the accessor with an
/// [`ArrayIter`]; any other data type is viewed as a primitive array and
/// handed to [`sum`].
///
/// Returns `None` for a dictionary array whose null count equals its length
/// (this includes a zero-length array). For non-dictionary input the result
/// is whatever [`sum`] returns.
pub fn sum_dyn<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
    T: ArrowNumericType,
    T::Native: Add<Output = T::Native>,
{
    // Anything that is not dictionary-encoded goes through the primitive kernel.
    if !matches!(array.data_type(), DataType::Dictionary(_, _)) {
        return sum::<T>(as_primitive_array(&array));
    }

    // Nothing to add up when every slot is null (also true when `len() == 0`).
    if array.null_count() == array.len() {
        return None;
    }

    // Null slots are skipped; accumulation starts from `T::default_value()`.
    let mut total = T::default_value();
    for value in ArrayIter::new(array).flatten() {
        total = total + value;
    }

    Some(total)
}

/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
#[cfg(not(feature = "simd"))]
Expand Down Expand Up @@ -583,7 +614,7 @@ mod simd {
}
}

/// Returns the sum of values in the array.
/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
#[cfg(feature = "simd")]
Expand Down Expand Up @@ -625,6 +656,7 @@ mod tests {
use super::*;
use crate::array::*;
use crate::compute::add;
use crate::datatypes::{Int32Type, Int8Type};

#[test]
fn test_primitive_array_sum() {
Expand Down Expand Up @@ -1003,4 +1035,17 @@ mod tests {
assert_eq!(Some(true), min_boolean(&a));
assert_eq!(Some(true), max_boolean(&a));
}

#[test]
fn test_sum_dyn() {
    // Dictionary input: keys 2, 3 and 4 pick 12, 13 and 14 out of the values,
    // so the expected total is 39.
    let dict_values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
    let dict_keys = Int8Array::from_iter_values([2_i8, 3, 4]);
    let dictionary = DictionaryArray::try_new(&dict_keys, &dict_values).unwrap();
    let typed_dictionary = dictionary.downcast_dict::<Int8Array>().unwrap();
    let dictionary_total = sum_dyn::<Int8Type, _>(typed_dictionary).unwrap();
    assert_eq!(39, dictionary_total);

    // Plain primitive input takes the non-dictionary path: 1 + 2 + 3 + 4 + 5.
    let primitive = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let primitive_total = sum_dyn::<Int32Type, _>(&primitive).unwrap();
    assert_eq!(15, primitive_total);
}
}