diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index c215e23953e..083defdde7d 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -17,14 +17,19 @@ //! Defines aggregations over Arrow arrays. +use arrow_data::bit_iterator::try_for_each_valid_idx; +use arrow_schema::ArrowError; use multiversion::multiversion; -use std::ops::Add; +#[allow(unused_imports)] +use std::ops::{Add, Deref}; use crate::array::{ as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; +use crate::datatypes::native_op::ArrowNativeTypeOp; use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType}; +use crate::error::Result; use crate::util::bit_iterator::BitIndexIterator; /// Generic test for NaN, the optimizer should be able to remove this for integer types. @@ -162,10 +167,13 @@ pub fn min_string(array: &GenericStringArray) -> Option<& } /// Returns the sum of values in the array. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `sum_array_checked` instead. pub fn sum_array>(array: A) -> Option where T: ArrowNumericType, - T::Native: Add, + T::Native: ArrowNativeTypeOp, { match array.data_type() { DataType::Dictionary(_, _) => { @@ -180,7 +188,7 @@ where .into_iter() .fold(T::default_value(), |accumulator, value| { if let Some(value) = value { - accumulator + value + accumulator.add_wrapping(value) } else { accumulator } @@ -192,6 +200,42 @@ where } } +/// Returns the sum of values in the array. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `sum_array` instead. +pub fn sum_array_checked>( + array: A, +) -> Result> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return Ok(None); + } + + let iter = ArrayIter::new(array); + let sum = + iter.into_iter() + .try_fold(T::default_value(), |accumulator, value| { + if let Some(value) = value { + accumulator.add_checked(value) + } else { + Ok(accumulator) + } + })?; + + Ok(Some(sum)) + } + _ => sum_checked::(as_primitive_array(&array)), + } +} + /// Returns the min of values in the array of `ArrowNumericType` type, or dictionary /// array with value of `ArrowNumericType` type. pub fn min_array>(array: A) -> Option @@ -239,11 +283,14 @@ where /// Returns the sum of values in the primitive array. /// /// Returns `None` if the array is empty or only contains null values. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `sum_checked` instead. #[cfg(not(feature = "simd"))] pub fn sum(array: &PrimitiveArray) -> Option where T: ArrowNumericType, - T::Native: Add, + T::Native: ArrowNativeTypeOp, { let null_count = array.null_count(); @@ -256,7 +303,7 @@ where match array.data().null_buffer() { None => { let sum = data.iter().fold(T::default_value(), |accumulator, value| { - accumulator + *value + accumulator.add_wrapping(*value) }); Some(sum) @@ -274,7 +321,7 @@ where let mut index_mask = 1; chunk.iter().for_each(|value| { if (mask & index_mask) != 0 { - sum = sum + *value; + sum = sum.add_wrapping(*value); } index_mask <<= 1; }); @@ -284,7 +331,7 @@ where remainder.iter().enumerate().for_each(|(i, value)| { if remainder_bits & (1 << i) != 0 { - sum = sum + *value; + sum = sum.add_wrapping(*value); } }); @@ -293,6 +340,54 @@ where } } +/// Returns the sum of values in the primitive array. +/// +/// Returns `Ok(None)` if the array is empty or only contains null values. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `sum` instead. +pub fn sum_checked(array: &PrimitiveArray) -> Result> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + let null_count = array.null_count(); + + if null_count == array.len() { + return Ok(None); + } + + let data: &[T::Native] = array.values(); + + match array.data().null_buffer() { + None => { + let sum = data + .iter() + .try_fold(T::default_value(), |accumulator, value| { + accumulator.add_checked(*value) + })?; + + Ok(Some(sum)) + } + Some(buffer) => { + let mut sum = T::default_value(); + + try_for_each_valid_idx( + array.len(), + array.offset(), + null_count, + Some(buffer.deref()), + |idx| { + unsafe { sum = sum.add_checked(array.value_unchecked(idx))? }; + Ok::<_, ArrowError>(()) + }, + )?; + + Ok(Some(sum)) + } + } +} + #[cfg(feature = "simd")] mod simd { use super::is_nan; @@ -638,6 +733,9 @@ mod simd { /// Returns the sum of values in the primitive array. /// /// Returns `None` if the array is empty or only contains null values. +/// +/// This doesn't detect overflow in release mode by default. Once overflowing, the result will +/// wrap around. For an overflow-checking variant, use `sum_checked` instead. #[cfg(feature = "simd")] pub fn sum(array: &PrimitiveArray) -> Option where @@ -1216,4 +1314,21 @@ mod tests { let actual = max_binary(sliced_input); assert_eq!(actual, expected); } + + #[test] + #[cfg(not(feature = "simd"))] + fn test_sum_overflow() { + let a = Int32Array::from(vec![i32::MAX, 1]); + + assert_eq!(sum(&a).unwrap(), -2147483648); + assert_eq!(sum_array::(&a).unwrap(), -2147483648); + } + + #[test] + fn test_sum_checked_overflow() { + let a = Int32Array::from(vec![i32::MAX, 1]); + + sum_checked(&a).expect_err("overflow should be detected"); + sum_array_checked::(&a).expect_err("overflow should be detected"); + } }