Add overflow-checking variant for primitive arithmetic kernels and explicitly define overflow behavior #2643

Merged
merged 11 commits into from Sep 4, 2022
4 changes: 2 additions & 2 deletions arrow/benches/arithmetic_kernels.rs
@@ -55,13 +55,13 @@ fn bench_multiply(arr_a: &ArrayRef, arr_b: &ArrayRef) {
fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) {
let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
criterion::black_box(divide(arr_a, arr_b).unwrap());
criterion::black_box(divide_checked(arr_a, arr_b).unwrap());
}

fn bench_divide_unchecked(arr_a: &ArrayRef, arr_b: &ArrayRef) {
let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
criterion::black_box(divide_unchecked(arr_a, arr_b).unwrap());
criterion::black_box(divide(arr_a, arr_b).unwrap());
}

fn bench_divide_scalar(array: &ArrayRef, divisor: f32) {
246 changes: 228 additions & 18 deletions arrow/src/compute/kernels/arithmetic.rs
@@ -35,8 +35,9 @@ use crate::compute::unary_dyn;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes;
use crate::datatypes::{
ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType,
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type,
Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType,
};
use crate::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
@@ -103,6 +104,99 @@ where
Ok(PrimitiveArray::<LT>::from(data))
}

/// This is similar to `math_op` in that it performs the given operation between two input primitive arrays.
/// However, the given operation may return `None` when overflow is detected. In that case, this function
/// returns an `Err`.
fn math_checked_op<LT, RT, F>(
left: &PrimitiveArray<LT>,
right: &PrimitiveArray<RT>,
op: F,
) -> Result<PrimitiveArray<LT>>
where
LT: ArrowNumericType,
RT: ArrowNumericType,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform math operation on arrays of different length".to_string(),
));
}

let left_iter = ArrayIter::new(left);
let right_iter = ArrayIter::new(right);

let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = left_iter
.into_iter()
.zip(right_iter.into_iter())
.map(|(l, r)| {
if let (Some(l), Some(r)) = (l, r) {
let result = op(l, r);
if let Some(r) = result {
Ok(Some(r))
} else {
// Overflow
Err(ArrowError::ComputeError("Overflow happened".to_string()))
Contributor

Perhaps we could print the problematic values
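
One way to act on this suggestion is to build the error message from the operands that overflowed. The following is a minimal sketch using only the standard library, with plain `Option<i32>` slices standing in for `PrimitiveArray` and `String` standing in for `ArrowError`. The function name `add_checked_verbose` and the message wording are hypothetical and not part of this PR.

```rust
/// Element-wise checked addition over two nullable `i32` slices.
/// A null (`None`) on either side yields a null result; overflow yields an
/// error message that names the offending operands.
/// NOTE: hypothetical helper for illustration; not the arrow API.
fn add_checked_verbose(
    left: &[Option<i32>],
    right: &[Option<i32>],
) -> Result<Vec<Option<i32>>, String> {
    if left.len() != right.len() {
        return Err(
            "Cannot perform math operation on arrays of different length".to_string(),
        );
    }

    left.iter()
        .zip(right.iter())
        .map(|(l, r)| match (l, r) {
            // Both values are valid: attempt the checked operation.
            (Some(l), Some(r)) => l
                .checked_add(*r)
                .map(Some)
                .ok_or_else(|| format!("Overflow happened on: {} + {}", l, r)),
            // At least one side is null, so the result is null.
            _ => Ok(None),
        })
        // Collecting into `Result` stops at the first error.
        .collect()
}
```

Calling it with `[Some(i32::MAX)]` and `[Some(1)]` would return an error whose text contains `2147483647 + 1`, which is usually enough to locate the bad row.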

}
} else {
Ok(None)
}
})
.collect();

let values = values?;

Ok(PrimitiveArray::<LT>::from_iter(values))
}

/// This is similar to `math_checked_op`, but only for the divide operation.
fn math_checked_divide<LT, RT, F>(
Contributor

What is the difference between this function and math_checked_divide_op and why do we need both of them?

Member Author

Eventually I hope we can have just one. Currently math_checked_divide_op is used by divide_dyn, and I want to limit the scope of this change to the primitive kernels only.
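
For context, the new helper has to turn a `None` from a checked division into one of two errors, because integer `checked_div` returns `None` both for a zero divisor and for overflow. A minimal standard-library sketch of that classification follows; `DivError` and `classify_checked_div` are hypothetical names used only for illustration, and the PR itself reports both cases through `ArrowError::ComputeError`.

```rust
/// The two ways a checked integer division can fail.
/// NOTE: hypothetical type for illustration; not the arrow error type.
#[derive(Debug, PartialEq)]
enum DivError {
    DivideByZero,
    Overflow,
}

/// Divide `l` by `r`, reporting why the division failed when it does.
fn classify_checked_div(l: i32, r: i32) -> Result<i32, DivError> {
    match l.checked_div(r) {
        Some(q) => Ok(q),
        // `checked_div` returns `None` for a zero divisor...
        None if r == 0 => Err(DivError::DivideByZero),
        // ...and otherwise only for `i32::MIN / -1`, which overflows.
        None => Err(DivError::Overflow),
    }
}
```

Here `classify_checked_div(15, 0)` yields `DivideByZero`, while `classify_checked_div(i32::MIN, -1)` yields `Overflow`.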

left: &PrimitiveArray<LT>,
right: &PrimitiveArray<RT>,
op: F,
) -> Result<PrimitiveArray<LT>>
where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform math operation on arrays of different length".to_string(),
));
}

let left_iter = ArrayIter::new(left);
let right_iter = ArrayIter::new(right);

let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = left_iter
.into_iter()
.zip(right_iter.into_iter())
.map(|(l, r)| {
if let (Some(l), Some(r)) = (l, r) {
let result = op(l, r);
if let Some(r) = result {
Ok(Some(r))
} else {
if r.is_zero() {
Err(ArrowError::ComputeError("DivideByZero".to_string()))
} else {
// Overflow
Err(ArrowError::ComputeError("Overflow happened".to_string()))
Contributor

Same here

}
}
} else {
Ok(None)
}
})
.collect();

let values = values?;

Ok(PrimitiveArray::<LT>::from_iter(values))
}

/// Helper function for operations where a valid `0` on the right array should
/// result in an [ArrowError::DivideByZero], namely the division and modulo operations
///
@@ -760,15 +854,34 @@ where

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This doesn't detect overflow. On overflow, the result wraps around.
/// For an overflow-checking variant, use `add_checked` instead.
pub fn add<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a.wrapping_add_if_applied(b))
Contributor

👌
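
The `_if_applied` suffix suggests that wrapping and checking only apply to types where overflow is meaningful. The definition of `ArrowNativeTypeOp` is not shown in this diff, so the following is only a sketch of that idea under the assumption that integers delegate to the standard `wrapping_*`/`checked_*` methods and floats fall through to plain `+`. The trait `SketchNativeOp` and its method names are hypothetical.

```rust
/// Hypothetical stand-in for a native-type arithmetic trait.
/// NOTE: this is an illustration, not the definition used by arrow.
trait SketchNativeOp: Copy {
    /// Returns `None` on overflow where overflow is possible.
    fn sketch_checked_add(self, rhs: Self) -> Option<Self>;
    /// Wraps around on overflow where overflow is possible.
    fn sketch_wrapping_add(self, rhs: Self) -> Self;
}

impl SketchNativeOp for i32 {
    fn sketch_checked_add(self, rhs: Self) -> Option<Self> {
        self.checked_add(rhs)
    }

    fn sketch_wrapping_add(self, rhs: Self) -> Self {
        self.wrapping_add(rhs)
    }
}

impl SketchNativeOp for f32 {
    // Floats saturate to infinity instead of overflowing, so the "checked"
    // variant always succeeds in this sketch.
    fn sketch_checked_add(self, rhs: Self) -> Option<Self> {
        Some(self + rhs)
    }

    fn sketch_wrapping_add(self, rhs: Self) -> Self {
        self + rhs
    }
}
```

With this shape, a kernel can be written once against the trait and still behave sensibly for both integer and floating point arrays.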

}

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This detects overflow and returns an `Err` when it occurs. For a non-overflow-checking variant,
/// use `add` instead.
pub fn add_checked<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a + b)
math_checked_op(left, right, |a, b| a.checked_add_if_applied(b))
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
@@ -856,15 +969,34 @@ where

/// Perform `left - right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This doesn't detect overflow. On overflow, the result wraps around.
/// For an overflow-checking variant, use `subtract_checked` instead.
pub fn subtract<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Sub<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a.wrapping_sub_if_applied(b))
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This detects overflow and returns an `Err` when it occurs. For a non-overflow-checking variant,
/// use `subtract` instead.
pub fn subtract_checked<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a - b)
math_checked_op(left, right, |a, b| a.checked_sub_if_applied(b))
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
@@ -933,15 +1065,34 @@ where

/// Perform `left * right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This doesn't detect overflow. On overflow, the result wraps around.
/// For an overflow-checking variant, use `multiply_check` instead.
pub fn multiply<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Mul<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a.wrapping_mul_if_applied(b))
}

/// Perform `left * right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
/// This detects overflow and returns an `Err` when it occurs. For a non-overflow-checking variant,
/// use `multiply` instead.
pub fn multiply_check<T>(
Contributor

Suggested change
pub fn multiply_check<T>(
pub fn multiply_checked<T>(

left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a * b)
math_checked_op(left, right, |a, b| a.checked_mul_if_applied(b))
}

/// Perform `left * right` operation on two arrays. If either left or right value is null
@@ -1013,18 +1164,21 @@ where
/// Perform `left / right` operation on two arrays. If either left or right value is null
/// then the result is also null. If any right hand value is zero then the result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
pub fn divide<T>(
///
/// When the `simd` feature is not enabled, this detects overflow and returns an `Err` when it occurs.
Contributor

What happens when SIMD is enabled?

Member Author

Got signal: 8, SIGFPE: erroneous arithmetic operation. This is the original behavior.

Contributor

Interesting, that would imply rust division is always checked 🤔

Contributor

Yup - and LLVM cannot vectorize it correctly - https://rust.godbolt.org/z/T8eTGM8zn
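
The "always checked" observation can be reproduced with the standard library alone: plain integer `/` panics on `i32::MIN / -1` and on a zero divisor in every build profile, while `wrapping_div` and `checked_div` give the non-panicking alternatives. The sketch below assumes the default `panic = "unwind"` setting so that `catch_unwind` can observe the panic; the helper names are hypothetical.

```rust
use std::panic;

/// Plain integer division. The operands arrive at run time, so the compiler
/// cannot reject the overflowing case at compile time.
fn plain_div(a: i32, b: i32) -> i32 {
    a / b
}

/// Returns `true` if evaluating `a / b` panics.
/// Assumes the program is built with unwinding panics (the default).
fn plain_div_panics(a: i32, b: i32) -> bool {
    panic::catch_unwind(|| plain_div(a, b)).is_err()
}

/// The wrapping variant never panics for a non-zero divisor:
/// `i32::MIN / -1` wraps back to `i32::MIN`.
fn wrapping_min_div_neg_one() -> i32 {
    i32::MIN.wrapping_div(-1)
}
```

`plain_div_panics(i32::MIN, -1)` and `plain_div_panics(1, 0)` both report a panic, which is the per-element branch that gets in the way of vectorizing integer division.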

/// For a non-overflow-checking variant, use `divide` instead.
pub fn divide_checked<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Div<Output = T::Native> + Zero + One,
T::Native: ArrowNativeTypeOp + Zero + One,
{
#[cfg(feature = "simd")]
return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, b| a / b);
#[cfg(not(feature = "simd"))]
return math_checked_divide_op(left, right, |a, b| a / b);
return math_checked_divide(left, right, |a, b| a.checked_div_if_applied(b));
}

/// Perform `left / right` operation on two arrays. If either left or right value is null
@@ -1042,15 +1196,18 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
/// Perform `left / right` operation on two arrays without checking for division by zero.
/// The result of dividing by zero follows normal floating point rules.
Contributor

We should update the doc

Member Author

Do you mean "The result of dividing by zero follows normal floating point rules"? I think this is not changed? It will panic as usual.

Contributor

@HaoYang670 Sep 4, 2022

Do you mean "The result of dividing by zero follows normal floating point rules"?

Yes. But why does it follow normal floating point rules here? The function now seems to support other numeric types (T: datatypes::ArrowNumericType).

I think this is not changed? It will panic as usual.

Nope, floats never panic. Dividing a float by zero gives inf or NaN. https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=972d301e807f9a6cfd2ba644b763b86c

Maybe it is better to document the different behaviour between float and other types:

for float, dividing by zero follows the normal floating point rules,
for other numeric types, dividing by zero will panic,
...

Member Author

Oh, I see. Yea, let me update the doc.
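
The difference raised in this thread can be checked with two small standard-library functions: float division by zero follows IEEE 754 and produces an infinity or NaN, while integer division by zero has no representable result. The names `float_div` and `int_div` are hypothetical, and `checked_div` is used on the integer side so the sketch reports the failure as `None` instead of panicking.

```rust
/// Float division never panics; a zero divisor gives an infinity or NaN.
fn float_div(a: f32, b: f32) -> f32 {
    a / b
}

/// Integer division has no value for a zero divisor. `checked_div` reports
/// that as `None`, whereas plain `/` would panic.
fn int_div(a: i32, b: i32) -> Option<i32> {
    a.checked_div(b)
}
```

For example, `float_div(1.0, 0.0)` is positive infinity and `float_div(0.0, 0.0)` is NaN, while `int_div(1, 0)` is `None`.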

/// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this
pub fn divide_unchecked<T>(
///
/// This doesn't detect overflow. On overflow, the result wraps around.
/// For an overflow-checking variant, use `divide_checked` instead.
pub fn divide<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowFloatNumericType,
T::Native: Div<Output = T::Native>,
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_op(left, right, |a, b| a / b)
math_op(left, right, |a, b| a.wrapping_div_if_applied(b))
}

/// Modulus every value in an array by a scalar. If any value in the array is null then the
@@ -1769,7 +1926,7 @@ mod tests {
fn test_primitive_array_divide_with_nulls() {
let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]);
let b = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9), None, None]);
let c = divide(&a, &b).unwrap();
let c = divide_checked(&a, &b).unwrap();
assert_eq!(3, c.value(0));
assert!(c.is_null(1));
assert_eq!(1, c.value(2));
@@ -1854,7 +2011,7 @@ mod tests {
let b = b.slice(8, 6);
let b = b.as_any().downcast_ref::<Int32Array>().unwrap();

let c = divide(a, b).unwrap();
let c = divide_checked(a, b).unwrap();
assert_eq!(6, c.len());
assert_eq!(3, c.value(0));
assert!(c.is_null(1));
@@ -1922,7 +2079,7 @@ mod tests {
fn test_primitive_array_divide_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
divide(&a, &b).unwrap();
divide_checked(&a, &b).unwrap();
}

#[test]
@@ -2019,4 +2176,57 @@ mod tests {
let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]);
assert_eq!(expected, actual);
}

#[test]
fn test_primitive_add_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
let b = Int32Array::from(vec![1, 1]);

let wrapped = add(&a, &b);
let expected = Int32Array::from(vec![-2147483648, -2147483647]);
assert_eq!(expected, wrapped.unwrap());

let overflow = add_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_subtract_wrapping_overflow() {
let a = Int32Array::from(vec![-2]);
let b = Int32Array::from(vec![i32::MAX]);

let wrapped = subtract(&a, &b);
let expected = Int32Array::from(vec![i32::MAX]);
assert_eq!(expected, wrapped.unwrap());

let overflow = subtract_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_mul_wrapping_overflow() {
let a = Int32Array::from(vec![10]);
let b = Int32Array::from(vec![i32::MAX]);

let wrapped = multiply(&a, &b);
let expected = Int32Array::from(vec![-10]);
assert_eq!(expected, wrapped.unwrap());

let overflow = multiply_check(&a, &b);
overflow.expect_err("overflow should be detected");
}

#[test]
#[cfg(not(feature = "simd"))]
fn test_primitive_div_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MIN]);
let b = Int32Array::from(vec![-1]);

let wrapped = divide(&a, &b);
let expected = Int32Array::from(vec![-2147483648]);
assert_eq!(expected, wrapped.unwrap());

let overflow = divide_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}
}