Skip to content

Commit

Permalink
Add overflow-checking variants of arithmetic scalar dyn kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Sep 12, 2022
1 parent be33fb3 commit 54d44d2
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 23 deletions.
192 changes: 171 additions & 21 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -22,7 +22,7 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.

use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
use std::ops::{Div, Neg, Rem};

use num::{One, Zero};

Expand All @@ -32,7 +32,7 @@ use crate::buffer::Buffer;
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::util::combine_option_bitmap;
use crate::compute::{binary, try_binary, try_unary, unary_dyn};
use crate::compute::{binary, try_binary, try_unary, try_unary_dyn, unary_dyn};
use crate::datatypes::{
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
Expand Down Expand Up @@ -834,12 +834,34 @@ where
/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use [`add_scalar_checked_dyn`] instead.
///
/// # Errors
///
/// Returns whatever error `unary_dyn` reports for `array`; the per-element closure
/// yields a plain value rather than a `Result`.
pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value + scalar)
unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar))
}

/// Add a scalar to every value in an array. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant,
/// use [`add_scalar_dyn`] instead.
///
/// # Errors
///
/// Returns `ArrowError::CastError` when `add_checked` reports no result for a value
/// (i.e. the addition overflowed), and propagates any error returned by `try_unary_dyn`.
pub fn add_scalar_checked_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp,
{
    // `try_unary_dyn` already yields a `Result<ArrayRef>`, so its result is returned
    // directly rather than being wrapped in a second `Arc`.
    try_unary_dyn::<_, T>(array, |value| {
        value.add_checked(scalar).ok_or_else(|| {
            ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value))
        })
    })
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -937,16 +959,40 @@ where
/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use [`subtract_scalar_checked_dyn`] instead.
///
/// # Errors
///
/// Returns whatever error `unary_dyn` reports for `array`; the per-element closure
/// yields a plain value rather than a `Result`.
pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Zero,
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value - scalar)
unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar))
}

/// Subtract a scalar from every value in an array. If any value in the array is null then
/// the result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant,
/// use [`subtract_scalar_dyn`] instead.
///
/// # Errors
///
/// Returns `ArrowError::CastError` when `sub_checked` reports no result for a value
/// (i.e. the subtraction overflowed), and propagates any error returned by `try_unary_dyn`.
pub fn subtract_scalar_checked_dyn<T>(
    array: &dyn Array,
    scalar: T::Native,
) -> Result<ArrayRef>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp,
{
    // `try_unary_dyn` already yields a `Result<ArrayRef>`, so its result is returned
    // directly rather than being wrapped in a second `Arc`.
    try_unary_dyn::<_, T>(array, |value| {
        value.sub_checked(scalar).ok_or_else(|| {
            ArrowError::CastError(format!(
                "Overflow: subtracting {:?} from {:?}",
                scalar, value
            ))
        })
    })
}

/// Perform `-` operation on an array. If value is null then the result is also null.
Expand Down Expand Up @@ -1065,18 +1111,40 @@ where
/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use [`multiply_scalar_checked_dyn`] instead.
///
/// # Errors
///
/// Returns whatever error `unary_dyn` reports for `array`; the per-element closure
/// yields a plain value rather than a `Result`.
pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ One,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value * scalar)
unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar))
}

/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant,
/// use [`multiply_scalar_dyn`] instead.
///
/// # Errors
///
/// Returns `ArrowError::CastError` when `mul_checked` reports no result for a value
/// (i.e. the multiplication overflowed), and propagates any error returned by
/// `try_unary_dyn`.
pub fn multiply_scalar_checked_dyn<T>(
    array: &dyn Array,
    scalar: T::Native,
) -> Result<ArrayRef>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp,
{
    // `try_unary_dyn` already yields a `Result<ArrayRef>`, so its result is returned
    // directly rather than being wrapped in a second `Arc`.
    try_unary_dyn::<_, T>(array, |value| {
        value.mul_checked(scalar).ok_or_else(|| {
            ArrowError::CastError(format!(
                "Overflow: multiplying {:?} by {:?}",
                value, scalar
            ))
        })
    })
}

/// Perform `left % right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -1197,15 +1265,48 @@ where
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use [`divide_scalar_checked_dyn`] instead.
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Div<Output = T::Native> + Zero,
T::Native: ArrowNativeTypeOp + Zero,
{
// Reject a zero divisor once, up front, before any element is processed.
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}
unary_dyn::<_, T>(array, |value| value / divisor)
unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor))
}

/// Divide every value in an array by a scalar. If any value in the array is null then the
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant,
/// use [`divide_scalar_dyn`] instead.
///
/// # Errors
///
/// * `ArrowError::DivideByZero` when `divisor` is zero.
/// * `ArrowError::CastError` when `div_checked` reports no result for a value
///   (i.e. the division overflowed).
/// * Any error returned by `try_unary_dyn`.
pub fn divide_scalar_checked_dyn<T>(
    array: &dyn Array,
    divisor: T::Native,
) -> Result<ArrayRef>
where
    T: ArrowNumericType,
    T::Native: ArrowNativeTypeOp + Zero,
{
    // Reject a zero divisor once, up front, before any element is processed.
    if divisor.is_zero() {
        return Err(ArrowError::DivideByZero);
    }

    // `try_unary_dyn` already yields a `Result<ArrayRef>`, so its result is returned
    // directly rather than being wrapped in a second `Arc`.
    try_unary_dyn::<_, T>(array, |value| {
        value.div_checked(divisor).ok_or_else(|| {
            ArrowError::CastError(format!(
                "Overflow: dividing {:?} by {:?}",
                value, divisor
            ))
        })
    })
}

#[cfg(test)]
Expand Down Expand Up @@ -2195,4 +2296,53 @@ mod tests {
let overflow = multiply_scalar_checked(&a, i32::MAX);
overflow.expect_err("overflow should be detected");
}

/// `add_scalar_dyn` wraps around on overflow while `add_scalar_checked_dyn` reports it.
#[test]
fn test_primitive_add_scalar_dyn_wrapping_overflow() {
    let input = Int32Array::from(vec![i32::MAX, i32::MIN]);

    // MAX + 1 wraps to MIN; MIN + 1 stays in range.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN + 1]));
    let wrapped = add_scalar_dyn::<Int32Type>(&input, 1).unwrap();
    assert_eq!(&expected, &wrapped);

    add_scalar_checked_dyn::<Int32Type>(&input, 1)
        .expect_err("overflow should be detected");
}

/// `subtract_scalar_dyn` wraps around on overflow while `subtract_scalar_checked_dyn`
/// reports it.
#[test]
fn test_primitive_subtract_scalar_dyn_wrapping_overflow() {
    let input = Int32Array::from(vec![-2]);
    let subtrahend = i32::MAX;

    // -2 - MAX is one below MIN, which wraps around to MAX.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![i32::MAX]));
    let wrapped = subtract_scalar_dyn::<Int32Type>(&input, subtrahend).unwrap();
    assert_eq!(&expected, &wrapped);

    subtract_scalar_checked_dyn::<Int32Type>(&input, subtrahend)
        .expect_err("overflow should be detected");
}

/// `multiply_scalar_dyn` wraps around on overflow while `multiply_scalar_checked_dyn`
/// reports it.
#[test]
fn test_primitive_mul_scalar_dyn_wrapping_overflow() {
    let input = Int32Array::from(vec![10]);
    let factor = i32::MAX;

    // 10 * MAX wraps around to -10 in 32-bit two's complement arithmetic.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![-10]));
    let wrapped = multiply_scalar_dyn::<Int32Type>(&input, factor).unwrap();
    assert_eq!(&expected, &wrapped);

    multiply_scalar_checked_dyn::<Int32Type>(&input, factor)
        .expect_err("overflow should be detected");
}

/// `divide_scalar_dyn` wraps around on overflow while `divide_scalar_checked_dyn`
/// reports it.
#[test]
fn test_primitive_div_scalar_dyn_wrapping_overflow() {
    let input = Int32Array::from(vec![i32::MIN]);
    let divisor = -1;

    // MIN / -1 is not representable as an i32; the wrapping result is MIN itself.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![i32::MIN]));
    let wrapped = divide_scalar_dyn::<Int32Type>(&input, divisor).unwrap();
    assert_eq!(&expected, &wrapped);

    divide_scalar_checked_dyn::<Int32Type>(&input, divisor)
        .expect_err("overflow should be detected");
}
}
43 changes: 41 additions & 2 deletions arrow/src/compute/kernels/arity.rs
Expand Up @@ -123,7 +123,7 @@ where
Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) })
}

/// A helper function that applies an unary function to a dictionary array with primitive value type.
/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
Expand All @@ -138,7 +138,22 @@ where
Ok(Arc::new(new_dict))
}

/// Applies an unary function to an array with primitive values.
/// A helper function that applies a fallible unary function to a dictionary array with primitive value type.
fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
let data = array.data().clone().into_builder().child_data(vec![values]);

let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
}

/// Applies an infallible unary function to an array with primitive values.
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
Expand All @@ -162,6 +177,30 @@ where
}
}

/// Applies a fallible unary function to an array with primitive values.
///
/// Dictionary arrays are forwarded to `try_unary_dict`. Any other array must have the
/// data type `T::DATA_TYPE`, in which case `op` is run over it with `try_unary`;
/// otherwise an `ArrowError::NotYetImplemented` is returned.
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> Result<T::Native>,
{
    downcast_dictionary_array! {
        array => try_unary_dict::<_, F, T>(array, op),
        other => {
            // Guard: only a primitive array of exactly type `T` is supported here.
            if other != &T::DATA_TYPE {
                return Err(ArrowError::NotYetImplemented(format!(
                    "Cannot perform unary operation on array of type {}",
                    other
                )));
            }

            let primitive = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
            let mapped = try_unary::<T, F, T>(primitive, op)?;
            Ok(Arc::new(mapped) as ArrayRef)
        }
    }
}

/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting
/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the
/// corresponding index in the result will also be null
Expand Down

0 comments on commit 54d44d2

Please sign in to comment.