Skip to content

Commit

Permalink
Add overflow-checking variant of sum kernel (#2822)
Browse files Browse the repository at this point in the history
* Define overflow-checking behavior of sum kernels

* Add sum_checked.

* Add sum_array_checked.
  • Loading branch information
viirya committed Oct 5, 2022
1 parent 4eb9908 commit e79ba40
Showing 1 changed file with 122 additions and 7 deletions.
129 changes: 122 additions & 7 deletions arrow/src/compute/kernels/aggregate.rs
Expand Up @@ -17,14 +17,19 @@

//! Defines aggregations over Arrow arrays.

use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_schema::ArrowError;
use multiversion::multiversion;
use std::ops::Add;
#[allow(unused_imports)]
use std::ops::{Add, Deref};

use crate::array::{
as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
use crate::datatypes::native_op::ArrowNativeTypeOp;
use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};
use crate::error::Result;
use crate::util::bit_iterator::BitIndexIterator;

/// Generic test for NaN, the optimizer should be able to remove this for integer types.
Expand Down Expand Up @@ -162,10 +167,13 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
}

/// Returns the sum of values in the array.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `sum_array_checked` instead.
pub fn sum_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
match array.data_type() {
DataType::Dictionary(_, _) => {
Expand All @@ -180,7 +188,7 @@ where
.into_iter()
.fold(T::default_value(), |accumulator, value| {
if let Some(value) = value {
accumulator + value
accumulator.add_wrapping(value)
} else {
accumulator
}
Expand All @@ -192,6 +200,42 @@ where
}
}

/// Returns the overflow-checked sum of the values in the array.
///
/// A dictionary-typed input is summed by iterating it through its accessor;
/// every other data type is viewed as a primitive array and handed to
/// `sum_checked`.
///
/// For a dictionary input, returns `Ok(None)` when the null count equals the
/// array length (which includes a zero-length array).
///
/// # Errors
///
/// Returns an `Err` as soon as an addition overflows. For a
/// non-overflow-checking variant, use `sum_array` instead.
pub fn sum_array_checked<T, A: ArrayAccessor<Item = T::Native>>(
    array: A,
) -> Result<Option<T::Native>>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp,
{
    // Anything that is not a dictionary takes the primitive-array kernel.
    if !matches!(array.data_type(), DataType::Dictionary(_, _)) {
        return sum_checked::<T>(as_primitive_array(&array));
    }

    // Nothing to add up when every slot is null.
    if array.null_count() == array.len() {
        return Ok(None);
    }

    // Null slots are skipped; the first overflowing addition aborts the loop
    // through `?`.
    let mut total = T::default_value();
    for value in ArrayIter::new(array).flatten() {
        total = total.add_checked(value)?;
    }

    Ok(Some(total))
}

/// Returns the min of values in the array of `ArrowNumericType` type, or dictionary
/// array with value of `ArrowNumericType` type.
pub fn min_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
Expand Down Expand Up @@ -239,11 +283,14 @@ where
/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `sum_checked` instead.
#[cfg(not(feature = "simd"))]
pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
let null_count = array.null_count();

Expand All @@ -256,7 +303,7 @@ where
match array.data().null_buffer() {
None => {
let sum = data.iter().fold(T::default_value(), |accumulator, value| {
accumulator + *value
accumulator.add_wrapping(*value)
});

Some(sum)
Expand All @@ -274,7 +321,7 @@ where
let mut index_mask = 1;
chunk.iter().for_each(|value| {
if (mask & index_mask) != 0 {
sum = sum + *value;
sum = sum.add_wrapping(*value);
}
index_mask <<= 1;
});
Expand All @@ -284,7 +331,7 @@ where

remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
sum = sum + *value;
sum = sum.add_wrapping(*value);
}
});

Expand All @@ -293,6 +340,54 @@ where
}
}

/// Returns the overflow-checked sum of the values in the primitive array.
///
/// Returns `Ok(None)` if the array is empty or only contains null values.
///
/// # Errors
///
/// Returns an `Err` as soon as an addition overflows. For a
/// non-overflow-checking variant, use `sum` instead.
pub fn sum_checked<T>(array: &PrimitiveArray<T>) -> Result<Option<T::Native>>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp,
{
    let nulls = array.null_count();
    if nulls == array.len() {
        return Ok(None);
    }

    let values: &[T::Native] = array.values();
    let mut total = T::default_value();

    if let Some(validity) = array.data().null_buffer() {
        // A null buffer is present: only visit the indices reported as valid.
        try_for_each_valid_idx(
            array.len(),
            array.offset(),
            nulls,
            Some(validity.deref()),
            |idx| {
                // SAFETY: `idx` is produced by `try_for_each_valid_idx`, which is
                // driven with `array.len()`. This relies on that helper only
                // yielding indices below the length it was given.
                // NOTE(review): the helper is defined outside this file; confirm.
                let value = unsafe { array.value_unchecked(idx) };
                total = total.add_checked(value)?;
                Ok::<_, ArrowError>(())
            },
        )?;
    } else {
        // No null buffer: every slot participates in the sum.
        for value in values {
            total = total.add_checked(*value)?;
        }
    }

    Ok(Some(total))
}

#[cfg(feature = "simd")]
mod simd {
use super::is_nan;
Expand Down Expand Up @@ -638,6 +733,9 @@ mod simd {
/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
///
/// This doesn't detect overflow in release mode by default. Once overflowing, the result will
/// wrap around. For an overflow-checking variant, use `sum_checked` instead.
#[cfg(feature = "simd")]
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
Expand Down Expand Up @@ -1216,4 +1314,21 @@ mod tests {
let actual = max_binary(sliced_input);
assert_eq!(actual, expected);
}

/// The non-checked kernels wrap around on overflow: `i32::MAX + 1` becomes
/// `i32::MIN`.
#[test]
#[cfg(not(feature = "simd"))]
fn test_sum_overflow() {
    let overflowing = Int32Array::from(vec![i32::MAX, 1]);

    let primitive_sum = sum(&overflowing).unwrap();
    assert_eq!(primitive_sum, i32::MIN);

    let accessor_sum = sum_array::<Int32Type, _>(&overflowing).unwrap();
    assert_eq!(accessor_sum, i32::MIN);
}

/// The checked kernels must report `i32::MAX + 1` as an error instead of
/// wrapping around.
#[test]
fn test_sum_checked_overflow() {
    let overflowing = Int32Array::from(vec![i32::MAX, 1]);

    let primitive_result = sum_checked(&overflowing);
    primitive_result.expect_err("overflow should be detected");

    let accessor_result = sum_array_checked::<Int32Type, _>(&overflowing);
    accessor_result.expect_err("overflow should be detected");
}
}

0 comments on commit e79ba40

Please sign in to comment.