Add overflow-checking variants of arithmetic scalar dyn kernels #2713

Merged
merged 4 commits into from
Sep 14, 2022
199 changes: 178 additions & 21 deletions arrow/src/compute/kernels/arithmetic.rs
@@ -22,7 +22,7 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.

use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
use std::ops::{Div, Neg, Rem};

use num::{One, Zero};

@@ -32,7 +32,9 @@ use crate::buffer::Buffer;
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::util::combine_option_bitmap;
use crate::compute::{binary, binary_opt, try_binary, try_unary, unary_dyn};
use crate::compute::{
binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
};
use crate::datatypes::{
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
@@ -834,12 +836,39 @@ where
/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead.
Member

perhaps explain a bit when it will return Err

///
/// This returns an `Err` when the type of the input array is not supported by the add operation.
pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar))
}

/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` if one occurs. For a non-overflow-checking variant,
/// use `add_scalar_dyn` instead.
///
/// Because this kernel pays a branching cost on every value and also prevents LLVM from
/// vectorising the loop, it is usually much slower than the non-checking variant.
pub fn add_scalar_checked_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
Member

curious: do we have a benchmark to track how much slower add_scalar_checked_dyn is compared to add_scalar_dyn?

Contributor

If it is anything like the non-scalar kernels, it is about 10x slower. Aside from the branching costs, the overflow check prevents LLVM from vectorising the loop correctly.

Member

I see. I wonder if we should point that out in the doc of this method, in case it's not obvious to the users.

Member Author

Yea, it will be much slower. In our case the non-checked kernels are used by default (ANSI mode disabled), so most of the time users will get the faster one unless they have a special need for the checked kernels.

Member Author

Yea, I'm going to add a few lines mentioning that.

where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value + scalar)
try_unary_dyn::<_, T>(array, |value| {
value.add_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value))
})
})
.map(|a| Arc::new(a) as ArrayRef)
}
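
To make the difference between the wrapping and checked variants concrete, here is a minimal std-only sketch. It is not the arrow-rs implementation: `add_scalar_wrapping` and `add_scalar_checked` are hypothetical names, a `Vec<Option<i32>>` stands in for a nullable `Int32Array`, and a `String` stands in for `ArrowError`.

```rust
/// Hypothetical stand-in for a nullable Int32 array: `None` models a null slot.
type NullableI32 = Vec<Option<i32>>;

/// Wrapping variant: never fails; an overflowing sum silently wraps around.
fn add_scalar_wrapping(values: &[Option<i32>], scalar: i32) -> NullableI32 {
    values
        .iter()
        .map(|&v| v.map(|x| x.wrapping_add(scalar)))
        .collect()
}

/// Checked variant: nulls pass through, and the first overflow aborts with an error.
fn add_scalar_checked(values: &[Option<i32>], scalar: i32) -> Result<NullableI32, String> {
    values
        .iter()
        .map(|&v| match v {
            None => Ok(None),
            Some(x) => x
                .checked_add(scalar)
                .map(Some)
                .ok_or_else(|| format!("Overflow: adding {:?} to {:?}", scalar, x)),
        })
        .collect()
}
```

In this sketch the checked version has to inspect every result and can return early, which is the kind of per-value branch the review discussion above associates with the slowdown.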

/// Perform `left - right` operation on two arrays. If either left or right value is null
@@ -937,16 +966,40 @@ where
/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead.
pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Zero,
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar))
}

/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` if one occurs. For a non-overflow-checking variant,
/// use `subtract_scalar_dyn` instead.
pub fn subtract_scalar_checked_dyn<T>(
array: &dyn Array,
scalar: T::Native,
) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value - scalar)
try_unary_dyn::<_, T>(array, |value| {
value.sub_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: subtracting {:?} from {:?}",
scalar, value
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
}

/// Perform `-` operation on an array. If value is null then the result is also null.
@@ -1065,18 +1118,40 @@ where
/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead.
pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ One,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar))
}

/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` if one occurs. For a non-overflow-checking variant,
/// use `multiply_scalar_dyn` instead.
pub fn multiply_scalar_checked_dyn<T>(
array: &dyn Array,
scalar: T::Native,
) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
unary_dyn::<_, T>(array, |value| value * scalar)
try_unary_dyn::<_, T>(array, |value| {
value.mul_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: multiplying {:?} by {:?}",
value, scalar
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
}

/// Perform `left % right` operation on two arrays. If either left or right value is null
@@ -1223,15 +1298,48 @@ where
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead.
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: Div<Output = T::Native> + Zero,
T::Native: ArrowNativeTypeOp + Zero,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}
unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor))
}

/// Divide every value in an array by a scalar. If any value in the array is null then the
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
///
/// This detects overflow and returns an `Err` if one occurs. For a non-overflow-checking variant,
/// use `divide_scalar_dyn` instead.
pub fn divide_scalar_checked_dyn<T>(
array: &dyn Array,
divisor: T::Native,
) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}
unary_dyn::<_, T>(array, |value| value / divisor)

try_unary_dyn::<_, T>(array, |value| {
value.div_checked(divisor).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: dividing {:?} by {:?}",
value, divisor
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
}
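
Division has two distinct failure modes: a zero divisor, which both variants reject up front, and the single overflowing case `i32::MIN / -1`, which only the checked variant reports. The following std-only sketch illustrates that ordering. `DivError`, `divide_scalar_wrapping` and `divide_scalar_checked` are hypothetical names, and plain slices stand in for Arrow arrays.

```rust
/// Hypothetical error type standing in for `ArrowError`.
#[derive(Debug, PartialEq)]
enum DivError {
    DivideByZero,
    Overflow(String),
}

/// Wrapping variant: rejects a zero divisor, otherwise never fails.
fn divide_scalar_wrapping(values: &[i32], divisor: i32) -> Result<Vec<i32>, DivError> {
    if divisor == 0 {
        return Err(DivError::DivideByZero);
    }
    // `wrapping_div` returns `i32::MIN` for `i32::MIN / -1` instead of panicking.
    Ok(values.iter().map(|&v| v.wrapping_div(divisor)).collect())
}

/// Checked variant: rejects a zero divisor, then reports any overflow.
fn divide_scalar_checked(values: &[i32], divisor: i32) -> Result<Vec<i32>, DivError> {
    if divisor == 0 {
        return Err(DivError::DivideByZero);
    }
    values
        .iter()
        .map(|&v| {
            // The divisor is known to be non-zero here, so `None` can only mean overflow.
            v.checked_div(divisor).ok_or_else(|| {
                DivError::Overflow(format!("Overflow: dividing {:?} by {:?}", v, divisor))
            })
        })
        .collect()
}
```

Because the zero check runs first, a `None` from `checked_div` inside the closure can only come from overflow, so mapping it to an overflow error is unambiguous in this sketch.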

#[cfg(test)]
@@ -2222,6 +2330,55 @@ mod tests {
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_add_scalar_dyn_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MAX, i32::MIN]);

let wrapped = add_scalar_dyn::<Int32Type>(&a, 1).unwrap();
let expected =
Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef;
assert_eq!(&expected, &wrapped);

let overflow = add_scalar_checked_dyn::<Int32Type>(&a, 1);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_subtract_scalar_dyn_wrapping_overflow() {
let a = Int32Array::from(vec![-2]);

let wrapped = subtract_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef;
assert_eq!(&expected, &wrapped);

let overflow = subtract_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_mul_scalar_dyn_wrapping_overflow() {
let a = Int32Array::from(vec![10]);

let wrapped = multiply_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef;
assert_eq!(&expected, &wrapped);

let overflow = multiply_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
overflow.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_div_scalar_dyn_wrapping_overflow() {
let a = Int32Array::from(vec![i32::MIN]);

let wrapped = divide_scalar_dyn::<Int32Type>(&a, -1).unwrap();
let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef;
assert_eq!(&expected, &wrapped);

let overflow = divide_scalar_checked_dyn::<Int32Type>(&a, -1);
overflow.expect_err("overflow should be detected");
}
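
The expected values in the four tests above follow directly from two's-complement wrapping on `i32`. As a cross-check that needs only the standard library, the sketch below (with the hypothetical helper name `overflow_cases`) pairs each wrapped result with the corresponding checked result, where `None` marks an overflow.

```rust
/// Returns (expression, wrapping result, checked result) for the cases used in the tests.
fn overflow_cases() -> Vec<(&'static str, i32, Option<i32>)> {
    vec![
        // Add: only the first element of `[i32::MAX, i32::MIN]` overflows.
        ("i32::MAX + 1", i32::MAX.wrapping_add(1), i32::MAX.checked_add(1)),
        ("i32::MIN + 1", i32::MIN.wrapping_add(1), i32::MIN.checked_add(1)),
        // Subtract: -2 - 2147483647 is one below `i32::MIN`, so it wraps to `i32::MAX`.
        ("-2 - i32::MAX", (-2i32).wrapping_sub(i32::MAX), (-2i32).checked_sub(i32::MAX)),
        // Multiply: 10 * (2^31 - 1) is congruent to -10 modulo 2^32.
        ("10 * i32::MAX", 10i32.wrapping_mul(i32::MAX), 10i32.checked_mul(i32::MAX)),
        // Divide: the only overflowing signed division; it wraps back to `i32::MIN`.
        ("i32::MIN / -1", i32::MIN.wrapping_div(-1), i32::MIN.checked_div(-1)),
    ]
}
```

Note that a single overflowing element is enough for a checked kernel to fail, which is why the add test expects an error even though `i32::MIN + 1` is representable.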

#[test]
fn test_primitive_div_opt_overflow_division_by_zero() {
let a = Int32Array::from(vec![i32::MIN]);
50 changes: 48 additions & 2 deletions arrow/src/compute/kernels/arity.rs
@@ -123,7 +123,7 @@ where
Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) })
}

/// A helper function that applies an unary function to a dictionary array with primitive value type.
/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
@@ -138,7 +138,22 @@ where
Ok(Arc::new(new_dict))
}

/// Applies an unary function to an array with primitive values.
/// A helper function that applies a fallible unary function to a dictionary array with primitive value type.
fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
let data = array.data().clone().into_builder().child_data(vec![values]);

let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
}
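
The idea behind `try_unary_dict` is that the fallible function only needs to run over the dictionary's values, while the keys are reused unchanged. The sketch below models that with plain vectors. `DictI32`, `try_unary_dict_sketch` and `decode` are hypothetical names, and the real `DictionaryArray` layout is considerably richer than this.

```rust
/// Hypothetical, simplified dictionary array: each key indexes into `values`,
/// and a `None` key models a null slot.
#[derive(Debug, Clone, PartialEq)]
struct DictI32 {
    keys: Vec<Option<usize>>,
    values: Vec<i32>,
}

/// Applies a fallible function to the dictionary values only; keys are copied as-is.
fn try_unary_dict_sketch<F>(dict: &DictI32, op: F) -> Result<DictI32, String>
where
    F: Fn(i32) -> Result<i32, String>,
{
    // Collecting into `Result` stops at the first `Err` and propagates it via `?`.
    let values = dict
        .values
        .iter()
        .map(|&v| op(v))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(DictI32 {
        keys: dict.keys.clone(),
        values,
    })
}

/// Expands the dictionary back into one optional value per logical slot.
fn decode(dict: &DictI32) -> Vec<Option<i32>> {
    dict.keys.iter().map(|&k| k.map(|i| dict.values[i])).collect()
}
```

One consequence visible in this sketch is that the function runs once per distinct dictionary value rather than once per row, and that a value no key refers to can still trigger an error.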

/// Applies an infallible unary function to an array with primitive values.
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
@@ -162,6 +177,37 @@ where
}
}

/// Applies a fallible unary function to an array with primitive values.
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
downcast_dictionary_array! {
array => if array.values().data_type() == &T::DATA_TYPE {
try_unary_dict::<_, F, T>(array, op)
} else {
Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on dictionary array of type {}",
array.data_type()
)))
},
t => {
if t == &T::DATA_TYPE {
Ok(Arc::new(try_unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)?))
} else {
Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on array of type {}",
t
)))
}
}
}
}

/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting
/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the
/// corresponding index in the result will also be null