Skip to content

Commit

Permalink
Overflow-checking variant of arithmetic scalar kernels (#2650)
Browse files Browse the repository at this point in the history
* Overflow-checking variant of arithmetic scalar kernels

* Remove division scalar change for now.
  • Loading branch information
viirya committed Sep 12, 2022
1 parent e646ae8 commit e1f8ed8
Showing 1 changed file with 118 additions and 17 deletions.
135 changes: 118 additions & 17 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::buffer::Buffer;
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::util::combine_option_bitmap;
use crate::compute::{binary, try_binary, unary_dyn};
use crate::compute::{binary, try_binary, try_unary, unary_dyn};
use crate::datatypes::{
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
Expand Down Expand Up @@ -797,15 +797,38 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {

/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `add_scalar_checked` instead.
pub fn add_scalar<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
Ok(unary(array, |value| value + scalar))
Ok(unary(array, |value| value.add_wrapping(scalar)))
}

/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `add_scalar` instead.
pub fn add_scalar_checked<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary(array, |value| {
value.add_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value))
})
})
}

/// Add every value in an array by a scalar. If any value in the array is null then the
Expand Down Expand Up @@ -874,19 +897,41 @@ pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {

/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `subtract_scalar_checked` instead.
pub fn subtract_scalar<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Zero,
T::Native: ArrowNativeTypeOp + Zero,
{
Ok(unary(array, |value| value - scalar))
Ok(unary(array, |value| value.sub_wrapping(scalar)))
}

/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `subtract_scalar` instead.
pub fn subtract_scalar_checked<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
{
try_unary(array, |value| {
value.sub_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: subtracting {:?} from {:?}",
scalar, value
))
})
})
}

/// Subtract every value in an array by a scalar. If any value in the array is null then the
Expand Down Expand Up @@ -980,21 +1025,41 @@ pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {

/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `multiply_scalar_checked` instead.
pub fn multiply_scalar<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ One,
T::Native: ArrowNativeTypeOp + Zero + One,
{
Ok(unary(array, |value| value * scalar))
Ok(unary(array, |value| value.mul_wrapping(scalar)))
}

/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `multiply_scalar` instead.
pub fn multiply_scalar_checked<T>(
array: &PrimitiveArray<T>,
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
{
try_unary(array, |value| {
value.mul_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: multiplying {:?} by {:?}",
value, scalar,
))
})
})
}

/// Multiply every value in an array by a scalar. If any value in the array is null then the
Expand Down Expand Up @@ -2094,4 +2159,40 @@ mod tests {
let overflow = divide_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_add_scalar_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MAX, i32::MIN]);

let wrapped = add_scalar(&a, 1);
let expected = Int32Array::from(vec![-2147483648, -2147483647]);
assert_eq!(expected, wrapped.unwrap());

let overflow = add_scalar_checked(&a, 1);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_subtract_scalar_wrapping_overflow() {
let a = Int32Array::from(vec![-2]);

let wrapped = subtract_scalar(&a, i32::MAX);
let expected = Int32Array::from(vec![i32::MAX]);
assert_eq!(expected, wrapped.unwrap());

let overflow = subtract_scalar_checked(&a, i32::MAX);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_mul_scalar_wrapping_overflow() {
let a = Int32Array::from(vec![10]);

let wrapped = multiply_scalar(&a, i32::MAX);
let expected = Int32Array::from(vec![-10]);
assert_eq!(expected, wrapped.unwrap());

let overflow = multiply_scalar_checked(&a, i32::MAX);
overflow.expect_err("overflow should be detected");
}
}

0 comments on commit e1f8ed8

Please sign in to comment.