Skip to content

Commit

Permalink
Add overflow-checking variant for add kernel and explicitly define ov…
Browse files Browse the repository at this point in the history
…erflow behavior for add
  • Loading branch information
viirya committed Sep 3, 2022
1 parent 2b2c15b commit 154f8a5
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 4 deletions.
86 changes: 82 additions & 4 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ use crate::compute::unary_dyn;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes;
use crate::datatypes::{
ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType,
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type,
Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType,
};
use crate::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
Expand Down Expand Up @@ -103,6 +104,51 @@ where
Ok(PrimitiveArray::<LT>::from(data))
}

/// Applies a fallible binary operation element-wise to two primitive arrays.
///
/// This is the checked counterpart of `math_op`: instead of returning a plain value, `op`
/// returns `None` to signal that the result is not representable (e.g. on overflow).
///
/// A slot is null in the output whenever it is null in either input; `op` is not invoked
/// for such slots.
///
/// # Errors
///
/// Returns `ArrowError::ComputeError` if the two arrays differ in length, or if `op`
/// returns `None` for any pair of non-null values.
fn math_checked_op<LT, RT, F>(
    left: &PrimitiveArray<LT>,
    right: &PrimitiveArray<RT>,
    op: F,
) -> Result<PrimitiveArray<LT>>
where
    LT: ArrowNumericType,
    RT: ArrowNumericType,
    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
    if left.len() != right.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform math operation on arrays of different length".to_string(),
        ));
    }

    // Collecting into a `Result` stops at the first pair for which `op` reports failure.
    let values = ArrayIter::new(left)
        .zip(ArrayIter::new(right))
        .map(|pair| match pair {
            (Some(l), Some(r)) => op(l, r)
                .map(Some)
                .ok_or_else(|| ArrowError::ComputeError("Overflow happened".to_string())),
            // A null on either side produces a null result.
            _ => Ok(None),
        })
        .collect::<Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>>>()?;

    Ok(PrimitiveArray::<LT>::from_iter(values))
}

/// Helper function for operations where a valid `0` on the right array should
/// result in an [ArrowError::DivideByZero], namely the division and modulo operations
///
Expand Down Expand Up @@ -760,15 +806,34 @@ where

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This doesn't detect overflow. For integer types an overflowing result wraps around;
/// floating-point types are added with plain `+`.
/// For an overflow-checking variant, use `add_checked` instead.
///
/// # Errors
///
/// Propagates any error returned by `math_op`.
pub fn add<T>(
    left: &PrimitiveArray<T>,
    right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
    T: ArrowNumericType,
    // `ArrowNativeTypeOp` already requires `Add<Output = Self>`, so no separate `Add`
    // bound is needed here.
    T::Native: ArrowNativeTypeOp,
{
    math_op(left, right, |a, b| a.wrapping_add_if_applied(b))
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant,
/// use `add` instead.
///
/// # Errors
///
/// Returns an error if the arrays differ in length or if adding any pair of non-null
/// values overflows.
pub fn add_checked<T>(
    left: &PrimitiveArray<T>,
    right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp,
{
    math_checked_op(left, right, |a, b| a.checked_add_if_applied(b))
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -2019,4 +2084,17 @@ mod tests {
let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]);
assert_eq!(expected, actual);
}

#[test]
fn test_primitive_add_wrapping_overflow() {
    let lhs = Int32Array::from(vec![i32::MAX, i32::MIN]);
    let rhs = Int32Array::from(vec![1, 1]);

    // `add` wraps: MAX + 1 becomes MIN, while MIN + 1 does not overflow at all.
    let wrapped = add(&lhs, &rhs).unwrap();
    assert_eq!(Int32Array::from(vec![i32::MIN, i32::MIN + 1]), wrapped);

    // `add_checked` must reject the same input because the first slot overflows.
    add_checked(&lhs, &rhs).expect_err("overflow should be detected");
}
}
39 changes: 39 additions & 0 deletions arrow/src/datatypes/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use super::DataType;
use half::f16;
use std::ops::Add;

mod private {
pub trait Sealed {}
Expand Down Expand Up @@ -114,6 +115,44 @@ pub trait ArrowPrimitiveType: 'static {
}
}

/// Trait for ArrowNativeType to provide overflow-aware operations.
///
/// Both methods have default bodies built on plain `+`; an implementing type may override
/// them to supply real overflow handling.
pub trait ArrowNativeTypeOp: ArrowNativeType + Add<Output = Self> {
/// Adds `rhs` to `self`, with `None` reserved for signalling overflow.
///
/// The default implementation never reports overflow: it always returns
/// `Some(self + rhs)`.
fn checked_add_if_applied(self, rhs: Self) -> Option<Self> {
Some(self + rhs)
}

/// Adds `rhs` to `self` without reporting overflow.
///
/// The default implementation is plain `self + rhs`.
fn wrapping_add_if_applied(self, rhs: Self) -> Self {
self + rhs
}
}

/// Implements `ArrowNativeTypeOp` for a type that provides inherent `checked_add` and
/// `wrapping_add` methods, forwarding each trait method to its inherent counterpart.
macro_rules! native_type_op {
    ($t:ty) => {
        impl ArrowNativeTypeOp for $t {
            // `None` on overflow, as reported by the inherent `checked_add`.
            fn checked_add_if_applied(self, rhs: Self) -> Option<Self> {
                self.checked_add(rhs)
            }

            // Wraps around on overflow, as done by the inherent `wrapping_add`.
            fn wrapping_add_if_applied(self, rhs: Self) -> Self {
                self.wrapping_add(rhs)
            }
        }
    };
}

// Integer types: overflow handling is forwarded to the inherent `checked_add` /
// `wrapping_add` methods by the `native_type_op!` macro.
native_type_op!(i8);
native_type_op!(i16);
native_type_op!(i32);
native_type_op!(i64);
native_type_op!(u8);
native_type_op!(u16);
native_type_op!(u32);
native_type_op!(u64);

// Floating-point types: empty impls, so the trait's default methods (plain `+`) apply.
impl ArrowNativeTypeOp for f16 {}
impl ArrowNativeTypeOp for f32 {}
impl ArrowNativeTypeOp for f64 {}

impl private::Sealed for i8 {}
impl ArrowNativeType for i8 {
#[inline]
Expand Down

0 comments on commit 154f8a5

Please sign in to comment.