Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overflow-checking variant of sum kernel #2822

Merged
merged 3 commits into from Oct 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
129 changes: 122 additions & 7 deletions arrow/src/compute/kernels/aggregate.rs
Expand Up @@ -17,14 +17,19 @@

//! Defines aggregations over Arrow arrays.

use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_schema::ArrowError;
use multiversion::multiversion;
use std::ops::Add;
#[allow(unused_imports)]
use std::ops::{Add, Deref};

use crate::array::{
as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
use crate::datatypes::native_op::ArrowNativeTypeOp;
use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};
use crate::error::Result;
use crate::util::bit_iterator::BitIndexIterator;

/// Generic test for NaN, the optimizer should be able to remove this for integer types.
Expand Down Expand Up @@ -162,10 +167,13 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
}

/// Returns the sum of values in the array.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `sum_array_checked` instead.
pub fn sum_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
match array.data_type() {
DataType::Dictionary(_, _) => {
Expand All @@ -180,7 +188,7 @@ where
.into_iter()
.fold(T::default_value(), |accumulator, value| {
if let Some(value) = value {
accumulator + value
accumulator.add_wrapping(value)
} else {
accumulator
}
Expand All @@ -192,6 +200,42 @@ where
}
}

/// Sums the values of `array`, reporting overflow as an error.
///
/// Dictionary arrays are summed value by value through an [`ArrayIter`]; every other
/// data type is forwarded to `sum_checked` after viewing the input as a primitive array.
///
/// Returns `Ok(None)` when a dictionary array has no valid (non-null) values, which
/// includes the empty array.
///
/// For a non-overflow-checking variant, use `sum_array` instead.
///
/// # Errors
///
/// Returns the `Err` produced by `add_checked` as soon as an addition fails; the
/// remaining values are not visited.
pub fn sum_array_checked<T, A: ArrayAccessor<Item = T::Native>>(
    array: A,
) -> Result<Option<T::Native>>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp,
{
    // Anything that is not a dictionary goes straight to the primitive kernel.
    if !matches!(array.data_type(), DataType::Dictionary(_, _)) {
        return sum_checked::<T>(as_primitive_array(&array));
    }

    // Every slot is null (or there are no slots at all): nothing to add up.
    if array.null_count() == array.len() {
        return Ok(None);
    }

    // Null slots come out of the iterator as `None` and are dropped by `flatten`.
    let mut total = T::default_value();
    for value in ArrayIter::new(array).flatten() {
        total = total.add_checked(value)?;
    }

    Ok(Some(total))
}

/// Returns the min of values in the array of `ArrowNumericType` type, or dictionary
/// array with value of `ArrowNumericType` type.
pub fn min_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
Expand Down Expand Up @@ -239,11 +283,14 @@ where
/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `sum_checked` instead.
#[cfg(not(feature = "simd"))]
pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
let null_count = array.null_count();

Expand All @@ -256,7 +303,7 @@ where
match array.data().null_buffer() {
None => {
let sum = data.iter().fold(T::default_value(), |accumulator, value| {
accumulator + *value
accumulator.add_wrapping(*value)
});

Some(sum)
Expand All @@ -274,7 +321,7 @@ where
let mut index_mask = 1;
chunk.iter().for_each(|value| {
if (mask & index_mask) != 0 {
sum = sum + *value;
sum = sum.add_wrapping(*value);
}
index_mask <<= 1;
});
Expand All @@ -284,7 +331,7 @@ where

remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
sum = sum + *value;
sum = sum.add_wrapping(*value);
}
});

Expand All @@ -293,6 +340,54 @@ where
}
}

/// Returns the sum of values in the primitive array.
///
/// Returns `Ok(None)` if the array is empty or only contains null values.
///
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant,
/// use `sum` instead.
///
/// # Errors
///
/// Returns the `Err` produced by `add_checked` as soon as adding a value to the running
/// total fails; the remaining values are not visited.
pub fn sum_checked<T>(array: &PrimitiveArray<T>) -> Result<Option<T::Native>>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp,
{
    let null_count = array.null_count();

    // Covers both the empty array (0 == 0) and the all-null array.
    if null_count == array.len() {
        return Ok(None);
    }

    match array.data().null_buffer() {
        // No null buffer: every slot is valid, so fold over the raw values directly.
        None => {
            let sum = array
                .values()
                .iter()
                .try_fold(T::default_value(), |accumulator, value| {
                    accumulator.add_checked(*value)
                })?;

            Ok(Some(sum))
        }
        // Null buffer present: only visit the indices reported as valid.
        Some(buffer) => {
            let mut sum = T::default_value();

            try_for_each_valid_idx(
                array.len(),
                array.offset(),
                null_count,
                Some(buffer.deref()),
                |idx| {
                    // SAFETY: `try_for_each_valid_idx` is called with `array.len()` as its
                    // length, so the indices handed to this callback are expected to be
                    // below `array.len()`.
                    // NOTE(review): confirm this against the helper's documented contract.
                    let value = unsafe { array.value_unchecked(idx) };
                    sum = sum.add_checked(value)?;
                    Ok::<_, ArrowError>(())
                },
            )?;

            Ok(Some(sum))
        }
    }
}

#[cfg(feature = "simd")]
mod simd {
use super::is_nan;
Expand Down Expand Up @@ -638,6 +733,9 @@ mod simd {
/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
///
/// This doesn't detect overflow in release mode by default. Once overflowing, the result will
/// wrap around. For an overflow-checking variant, use `sum_checked` instead.
#[cfg(feature = "simd")]
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
Expand Down Expand Up @@ -1216,4 +1314,21 @@ mod tests {
let actual = max_binary(sliced_input);
assert_eq!(actual, expected);
}

#[test]
#[cfg(not(feature = "simd"))]
fn test_sum_overflow() {
    // `i32::MAX + 1` is expected to wrap around to `i32::MIN` in the non-checked kernels.
    let input = Int32Array::from(vec![i32::MAX, 1]);
    let wrapped = Some(i32::MIN);

    assert_eq!(sum(&input), wrapped);
    assert_eq!(sum_array::<Int32Type, _>(&input), wrapped);
}

#[test]
fn test_sum_checked_overflow() {
    // `i32::MAX + 1` does not fit in an `i32`, so both checked kernels must report an error.
    let input = Int32Array::from(vec![i32::MAX, 1]);

    let primitive_result = sum_checked(&input);
    primitive_result.expect_err("overflow should be detected");

    let accessor_result = sum_array_checked::<Int32Type, _>(&input);
    accessor_result.expect_err("overflow should be detected");
}
}