Skip to content

Commit

Permalink
Add overflow-checking variant for primitive arithmetic kernels and ex…
Browse files Browse the repository at this point in the history
…plicitly define overflow behavior (#2643)

* Add overflow-checking variant for add kernel and explicitly define overflow behavior for add

* For subtract, multiply, divide

* Fix tests

* Fix different error message

* Fix typo

* Rename APIs and add more comments. Print values in error message.

* Add one more test to distinct divide_by_zero behavior on divide.

* Fix clippy

* Update divide doc with dividing by zero behavior for other numeric types.

* Hide ArrowNativeTypeOp

* Fix a typo
  • Loading branch information
viirya committed Sep 4, 2022
1 parent 4c1bb00 commit 6d86472
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 20 deletions.
4 changes: 2 additions & 2 deletions arrow/benches/arithmetic_kernels.rs
Expand Up @@ -55,13 +55,13 @@ fn bench_multiply(arr_a: &ArrayRef, arr_b: &ArrayRef) {
fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) {
let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
criterion::black_box(divide(arr_a, arr_b).unwrap());
criterion::black_box(divide_checked(arr_a, arr_b).unwrap());
}

fn bench_divide_unchecked(arr_a: &ArrayRef, arr_b: &ArrayRef) {
let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
criterion::black_box(divide_unchecked(arr_a, arr_b).unwrap());
criterion::black_box(divide(arr_a, arr_b).unwrap());
}

fn bench_divide_scalar(array: &ArrayRef, divisor: f32) {
Expand Down
262 changes: 244 additions & 18 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -35,8 +35,9 @@ use crate::compute::unary_dyn;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes;
use crate::datatypes::{
ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType,
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
native_op::ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, DataType,
Date32Type, Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType,
};
use crate::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
Expand Down Expand Up @@ -103,6 +104,106 @@ where
Ok(PrimitiveArray::<LT>::from(data))
}

/// This is similar to `math_op` as it performs given operation between two input primitive arrays.
/// But the given operation can return `None` if overflow is detected. For the case, this function
/// returns an `Err`.
fn math_checked_op<LT, RT, F>(
left: &PrimitiveArray<LT>,
right: &PrimitiveArray<RT>,
op: F,
) -> Result<PrimitiveArray<LT>>
where
LT: ArrowNumericType,
RT: ArrowNumericType,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform math operation on arrays of different length".to_string(),
));
}

let left_iter = ArrayIter::new(left);
let right_iter = ArrayIter::new(right);

let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = left_iter
.into_iter()
.zip(right_iter.into_iter())
.map(|(l, r)| {
if let (Some(l), Some(r)) = (l, r) {
let result = op(l, r);
if let Some(r) = result {
Ok(Some(r))
} else {
// Overflow
Err(ArrowError::ComputeError(format!(
"Overflow happened on: {:?}, {:?}",
l, r
)))
}
} else {
Ok(None)
}
})
.collect();

let values = values?;

Ok(PrimitiveArray::<LT>::from_iter(values))
}

/// This is similar to `math_checked_op` but just for divide op.
fn math_checked_divide<LT, RT, F>(
left: &PrimitiveArray<LT>,
right: &PrimitiveArray<RT>,
op: F,
) -> Result<PrimitiveArray<LT>>
where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform math operation on arrays of different length".to_string(),
));
}

let left_iter = ArrayIter::new(left);
let right_iter = ArrayIter::new(right);

let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = left_iter
.into_iter()
.zip(right_iter.into_iter())
.map(|(l, r)| {
if let (Some(l), Some(r)) = (l, r) {
let result = op(l, r);
if let Some(r) = result {
Ok(Some(r))
} else if r.is_zero() {
Err(ArrowError::ComputeError(format!(
"DivideByZero on: {:?}, {:?}",
l, r
)))
} else {
// Overflow
Err(ArrowError::ComputeError(format!(
"Overflow happened on: {:?}, {:?}",
l, r
)))
}
} else {
Ok(None)
}
})
.collect();

let values = values?;

Ok(PrimitiveArray::<LT>::from_iter(values))
}

/// Helper function for operations where a valid `0` on the right array should
/// result in an [ArrowError::DivideByZero], namely the division and modulo operations
///
Expand Down Expand Up @@ -760,15 +861,34 @@ where

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `add_checked` instead.
pub fn add<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a.add_wrapping(b))
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null. Once
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `add` instead.
pub fn add_checked<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a + b)
math_checked_op(left, right, |a, b| a.add_checked(b))
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -856,15 +976,34 @@ where

/// Perform `left - right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `subtract_checked` instead.
pub fn subtract<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Sub<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a - b)
math_op(left, right, |a, b| a.sub_wrapping(b))
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `subtract` instead.
pub fn subtract_checked<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_checked_op(left, right, |a, b| a.sub_checked(b))
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -933,15 +1072,34 @@ where

/// Perform `left * right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `multiply_check` instead.
pub fn multiply<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Mul<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a * b)
math_op(left, right, |a, b| a.mul_wrapping(b))
}

/// Perform `left * right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `multiply` instead.
pub fn multiply_checked<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_checked_op(left, right, |a, b| a.mul_checked(b))
}

/// Perform `left * right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -1013,18 +1171,21 @@ where
/// Perform `left / right` operation on two arrays. If either left or right value is null
/// then the result is also null. If any right hand value is zero then the result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
pub fn divide<T>(
///
/// When `simd` feature is not enabled. This detects overflow and returns an `Err` for that.
/// For an non-overflow-checking variant, use `divide` instead.
pub fn divide_checked<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Div<Output = T::Native> + Zero + One,
T::Native: ArrowNativeTypeOp + Zero + One,
{
#[cfg(feature = "simd")]
return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, b| a / b);
#[cfg(not(feature = "simd"))]
return math_checked_divide_op(left, right, |a, b| a / b);
return math_checked_divide(left, right, |a, b| a.div_checked(b));
}

/// Perform `left / right` operation on two arrays. If either left or right value is null
Expand All @@ -1040,17 +1201,21 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
}

/// Perform `left / right` operation on two arrays without checking for division by zero.
/// The result of dividing by zero follows normal floating point rules.
/// For floating point types, the result of dividing by zero follows normal floating point
/// rules. For other numeric types, dividing by zero will panic,
/// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this
pub fn divide_unchecked<T>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `divide_checked` instead.
pub fn divide<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowFloatNumericType,
T::Native: Div<Output = T::Native>,
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a / b)
math_op(left, right, |a, b| a.div_wrapping(b))
}

/// Modulus every value in an array by a scalar. If any value in the array is null then the
Expand Down Expand Up @@ -1769,7 +1934,7 @@ mod tests {
fn test_primitive_array_divide_with_nulls() {
let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]);
let b = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9), None, None]);
let c = divide(&a, &b).unwrap();
let c = divide_checked(&a, &b).unwrap();
assert_eq!(3, c.value(0));
assert!(c.is_null(1));
assert_eq!(1, c.value(2));
Expand Down Expand Up @@ -1854,7 +2019,7 @@ mod tests {
let b = b.slice(8, 6);
let b = b.as_any().downcast_ref::<Int32Array>().unwrap();

let c = divide(a, b).unwrap();
let c = divide_checked(a, b).unwrap();
assert_eq!(6, c.len());
assert_eq!(3, c.value(0));
assert!(c.is_null(1));
Expand Down Expand Up @@ -1919,6 +2084,14 @@ mod tests {

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_primitive_array_divide_by_zero_with_checked() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
divide_checked(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "attempt to divide by zero")]
fn test_primitive_array_divide_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
Expand Down Expand Up @@ -2019,4 +2192,57 @@ mod tests {
let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]);
assert_eq!(expected, actual);
}

#[test]
fn test_primitive_add_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
let b = Int32Array::from(vec![1, 1]);

let wrapped = add(&a, &b);
let expected = Int32Array::from(vec![-2147483648, -2147483647]);
assert_eq!(expected, wrapped.unwrap());

let overflow = add_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_subtract_wrapping_overflow() {
let a = Int32Array::from(vec![-2]);
let b = Int32Array::from(vec![i32::MAX]);

let wrapped = subtract(&a, &b);
let expected = Int32Array::from(vec![i32::MAX]);
assert_eq!(expected, wrapped.unwrap());

let overflow = subtract_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_mul_wrapping_overflow() {
let a = Int32Array::from(vec![10]);
let b = Int32Array::from(vec![i32::MAX]);

let wrapped = multiply(&a, &b);
let expected = Int32Array::from(vec![-10]);
assert_eq!(expected, wrapped.unwrap());

let overflow = multiply_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}

#[test]
#[cfg(not(feature = "simd"))]
fn test_primitive_div_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MIN]);
let b = Int32Array::from(vec![-1]);

let wrapped = divide(&a, &b);
let expected = Int32Array::from(vec![-2147483648]);
assert_eq!(expected, wrapped.unwrap());

let overflow = divide_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}
}

0 comments on commit 6d86472

Please sign in to comment.