Skip to content

Commit

Permalink
Add max_dyn and min_dyn for max/min for dictionary array (#2585)
Browse files Browse the repository at this point in the history
* Add max_dyn and min_dyn

* Add a helper function

* Add NaN handling and test

* Rename to min_array, max_array and sum_array

* Rename min_max_dyn_helper
  • Loading branch information
viirya committed Aug 25, 2022
1 parent 8eea918 commit c64ca4f
Showing 1 changed file with 112 additions and 6 deletions.
118 changes: 112 additions & 6 deletions arrow/src/compute/kernels/aggregate.rs
Expand Up @@ -185,7 +185,7 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
}

/// Returns the sum of values in the array.
pub fn sum_dyn<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
pub fn sum_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
Expand Down Expand Up @@ -215,6 +215,68 @@ where
}
}

/// Returns the min of values in the array.
pub fn min_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
T: ArrowNumericType,
T::Native: ArrowNativeType,
{
min_max_array_helper::<T, A, _, _>(
array,
|a, b| (!is_nan(*a) & is_nan(*b)) || a < b,
min,
)
}

/// Returns the max of values in the array.
pub fn max_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
T: ArrowNumericType,
T::Native: ArrowNativeType,
{
min_max_array_helper::<T, A, _, _>(
array,
|a, b| (is_nan(*a) & !is_nan(*b)) || a > b,
max,
)
}

fn min_max_array_helper<T, A: ArrayAccessor<Item = T::Native>, F, M>(
array: A,
cmp: F,
m: M,
) -> Option<T::Native>
where
T: ArrowNumericType,
F: Fn(&T::Native, &T::Native) -> bool,
M: Fn(&PrimitiveArray<T>) -> Option<T::Native>,
{
match array.data_type() {
DataType::Dictionary(_, _) => {
let null_count = array.null_count();

if null_count == array.len() {
return None;
}

let mut has_value = false;
let mut n = T::default_value();
let iter = ArrayIter::new(array);
iter.into_iter().for_each(|value| {
if let Some(value) = value {
if !has_value || cmp(&value, &n) {
has_value = true;
n = value;
}
}
});

Some(n)
}
_ => m(as_primitive_array(&array)),
}
}

/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
Expand Down Expand Up @@ -656,7 +718,7 @@ mod tests {
use super::*;
use crate::array::*;
use crate::compute::add;
use crate::datatypes::{Int32Type, Int8Type};
use crate::datatypes::{Float32Type, Int32Type, Int8Type};

#[test]
fn test_primitive_array_sum() {
Expand Down Expand Up @@ -1043,19 +1105,63 @@ mod tests {

let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert_eq!(39, sum_dyn::<Int8Type, _>(array).unwrap());
assert_eq!(39, sum_array::<Int8Type, _>(array).unwrap());

let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
assert_eq!(15, sum_dyn::<Int32Type, _>(&a).unwrap());
assert_eq!(15, sum_array::<Int32Type, _>(&a).unwrap());

let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]);
let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert_eq!(26, sum_dyn::<Int8Type, _>(array).unwrap());
assert_eq!(26, sum_array::<Int8Type, _>(array).unwrap());

let keys = Int8Array::from(vec![None, None, None]);
let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert!(sum_array::<Int8Type, _>(array).is_none());
}

#[test]
fn test_max_min_dyn() {
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);

let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert_eq!(14, max_array::<Int8Type, _>(array).unwrap());

let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert_eq!(12, min_array::<Int8Type, _>(array).unwrap());

let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
assert_eq!(5, max_array::<Int32Type, _>(&a).unwrap());
assert_eq!(1, min_array::<Int32Type, _>(&a).unwrap());

let keys = Int8Array::from(vec![Some(2_i8), None, Some(7)]);
let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert_eq!(17, max_array::<Int8Type, _>(array).unwrap());
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert_eq!(12, min_array::<Int8Type, _>(array).unwrap());

let keys = Int8Array::from(vec![None, None, None]);
let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert!(sum_dyn::<Int8Type, _>(array).is_none());
assert!(max_array::<Int8Type, _>(array).is_none());
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
assert!(min_array::<Int8Type, _>(array).is_none());
}

#[test]
fn test_max_min_dyn_nan() {
let values = Float32Array::from(vec![5.0_f32, 2.0_f32, f32::NAN]);
let keys = Int8Array::from_iter_values([0_i8, 1, 2]);

let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
let array = dict_array.downcast_dict::<Float32Array>().unwrap();
assert!(max_array::<Float32Type, _>(array).unwrap().is_nan());

let array = dict_array.downcast_dict::<Float32Array>().unwrap();
assert_eq!(2.0_f32, min_array::<Float32Type, _>(array).unwrap());
}
}

0 comments on commit c64ca4f

Please sign in to comment.