Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overflow-checking variants of arithmetic scalar dyn kernels #2713

Merged
merged 4 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
192 changes: 171 additions & 21 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.

use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
use std::ops::{Div, Neg, Rem};

use num::{One, Zero};

Expand All @@ -32,7 +32,7 @@ use crate::buffer::Buffer;
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::util::combine_option_bitmap;
use crate::compute::{binary, try_binary, try_unary, unary_dyn};
use crate::compute::{binary, try_binary, try_unary, try_unary_dyn, unary_dyn};
use crate::datatypes::{
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
Expand Down Expand Up @@ -834,12 +834,34 @@ where
/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps explain a bit when it will return Err

pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value + scalar)
unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar))
}

/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `add_scalar_dyn` instead.
pub fn add_scalar_checked_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious: do we have benchmark to track how much slower add_scalar_checked_dyn is comparing to add_scalar_dyn?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is anything like the non-scalar kernels, it is about 10x slower. Aside from the branching costs, it prevents LLVM from vectorising it correctly

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I wonder if we should point that out in the doc of this method, in case it's not obvious to the users.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, it will be much slower. As by default (ansi-mode disabled) in our case, non-checked kernels will be used. So most of time users will use faster one, except they have special need to use checked kernels.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I'm going to add a few lines mentioning that.

where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary_dyn::<_, T>(array, |value| {
value.add_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value))
})
})
.map(|a| Arc::new(a) as ArrayRef)
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -937,16 +959,40 @@ where
/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead.
pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Zero,
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value - scalar)
unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar))
}

/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `subtract_scalar_dyn` instead.
pub fn subtract_scalar_checked_dyn<T>(
array: &dyn Array,
scalar: T::Native,
) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary_dyn::<_, T>(array, |value| {
value.sub_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: subtracting {:?} from {:?}",
scalar, value
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
}

/// Perform `-` operation on an array. If value is null then the result is also null.
Expand Down Expand Up @@ -1065,18 +1111,40 @@ where
/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead.
pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ One,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value * scalar)
unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar))
}

/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `multiply_scalar_dyn` instead.
pub fn multiply_scalar_checked_dyn<T>(
array: &dyn Array,
scalar: T::Native,
) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary_dyn::<_, T>(array, |value| {
value.mul_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: multiplying {:?} by {:?}",
value, scalar
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
}

/// Perform `left % right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -1197,15 +1265,48 @@ where
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead.
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Div<Output = T::Native> + Zero,
T::Native: ArrowNativeTypeOp + Zero,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}
unary_dyn::<_, T>(array, |value| value / divisor)
unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor))
}

/// Divide every value in an array by a scalar. If any value in the array is null then the
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
/// use `divide_scalar_dyn` instead.
pub fn divide_scalar_checked_dyn<T>(
array: &dyn Array,
divisor: T::Native,
) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}

try_unary_dyn::<_, T>(array, |value| {
value.div_checked(divisor).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: dividing {:?} by {:?}",
value, divisor
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
}

#[cfg(test)]
Expand Down Expand Up @@ -2195,4 +2296,53 @@ mod tests {
let overflow = multiply_scalar_checked(&a, i32::MAX);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_add_scalar_dyn_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MAX, i32::MIN]);

let wrapped = add_scalar_dyn::<Int32Type>(&a, 1).unwrap();
let expected =
Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef;
assert_eq!(&expected, &wrapped);

let overflow = add_scalar_checked_dyn::<Int32Type>(&a, 1);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_subtract_scalar_dyn_wrapping_overflow() {
let a = Int32Array::from(vec![-2]);

let wrapped = subtract_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef;
assert_eq!(&expected, &wrapped);

let overflow = subtract_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_mul_scalar_dyn_wrapping_overflow() {
let a = Int32Array::from(vec![10]);

let wrapped = multiply_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef;
assert_eq!(&expected, &wrapped);

let overflow = multiply_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_div_scalar_dyn_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MIN]);

let wrapped = divide_scalar_dyn::<Int32Type>(&a, -1).unwrap();
let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef;
assert_eq!(&expected, &wrapped);

let overflow = divide_scalar_checked_dyn::<Int32Type>(&a, -1);
overflow.expect_err("overflow should be detected");
}
}
43 changes: 41 additions & 2 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ where
Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) })
}

/// A helper function that applies an unary function to a dictionary array with primitive value type.
/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
Expand All @@ -138,7 +138,22 @@ where
Ok(Arc::new(new_dict))
}

/// Applies an unary function to an array with primitive values.
/// A helper function that applies a fallible unary function to a dictionary array with primitive value type.
fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
let data = array.data().clone().into_builder().child_data(vec![values]);

let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
}

/// Applies an infallible unary function to an array with primitive values.
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
Expand All @@ -162,6 +177,30 @@ where
}
}

/// Applies a fallible unary function to an array with primitive values.
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
downcast_dictionary_array! {
array => try_unary_dict::<_, F, T>(array, op),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm how do we know the dictionary value type matches T?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, there is no type-bound for the dictionary value type. Just do a simple test. At runtime downcast_ref will fail in unary_dict. I will address it in other PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally as the op is provided by users, I suppose that users know dictionary value is same type as the scalar. But it is good to return a meaningful Err instead of runtime panic. I will do it in a follow-up.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow up sounds fine to me. Perhaps we can just check the type here:

    downcast_dictionary_array! {
        array => if array.values().data_type() == &T::DATA_TYPE {
            try_unary_dict::<_, F, T>(array, op)
        } else {
            // throw error
        },
        t => {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right, actually I thought to handle it at try_unary_dict. But this fix looks okay as try_unary_dict is currently used here and not public. I may fix at try_unary_dict at another followup.

t => {
if t == &T::DATA_TYPE {
Ok(Arc::new(try_unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)?))
} else {
Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on array of type {}",
t
)))
}
}
}
}

/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting
/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the
/// corresponding index in the result will also be null
Expand Down