Add overflow-checking variants of arithmetic scalar dyn kernels #2713
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation | ||
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. | ||
|
||
use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; | ||
use std::ops::{Div, Neg, Rem}; | ||
|
||
use num::{One, Zero}; | ||
|
||
|
@@ -32,7 +32,9 @@ use crate::buffer::Buffer; | |
use crate::buffer::MutableBuffer; | ||
use crate::compute::kernels::arity::unary; | ||
use crate::compute::util::combine_option_bitmap; | ||
use crate::compute::{binary, binary_opt, try_binary, try_unary, unary_dyn}; | ||
use crate::compute::{ | ||
binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn, | ||
}; | ||
use crate::datatypes::{ | ||
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, | ||
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, | ||
|
@@ -834,12 +836,39 @@ where | |
/// Add every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead. | ||
/// | ||
/// This returns an `Err` when the input array is not supported for the add operation. | ||
pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: Add<Output = T::Native>, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar)) | ||
} | ||
|
||
/// Add every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant, | ||
/// use `add_scalar_dyn` instead. | ||
/// | ||
/// As this kernel has branching costs and also prevents LLVM from vectorising it correctly, | ||
/// it is usually much slower than the non-checking variant. | ||
pub fn add_scalar_checked_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious: do we have benchmark to track how much slower There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it is anything like the non-scalar kernels, it is about 10x slower. Aside from the branching costs, it prevents LLVM from vectorising it correctly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. I wonder if we should point that out in the doc of this method, in case it's not obvious to the users. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, it will be much slower. As by default (ansi-mode disabled) in our case, non-checked kernels will be used. So most of time users will use faster one, except they have special need to use checked kernels. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, I'm going to add a few lines mentioning that. |
||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value + scalar) | ||
try_unary_dyn::<_, T>(array, |value| { | ||
value.add_checked(scalar).ok_or_else(|| { | ||
ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value)) | ||
}) | ||
}) | ||
.map(|a| Arc::new(a) as ArrayRef) | ||
} | ||
|
||
/// Perform `left - right` operation on two arrays. If either left or right value is null | ||
|
@@ -937,16 +966,40 @@ where | |
/// Subtract every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead. | ||
pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef> | ||
where | ||
T: datatypes::ArrowNumericType, | ||
T::Native: Add<Output = T::Native> | ||
+ Sub<Output = T::Native> | ||
+ Mul<Output = T::Native> | ||
+ Div<Output = T::Native> | ||
+ Zero, | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar)) | ||
} | ||
|
||
/// Subtract every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant, | ||
/// use `subtract_scalar_dyn` instead. | ||
pub fn subtract_scalar_checked_dyn<T>( | ||
array: &dyn Array, | ||
scalar: T::Native, | ||
) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value - scalar) | ||
try_unary_dyn::<_, T>(array, |value| { | ||
value.sub_checked(scalar).ok_or_else(|| { | ||
ArrowError::CastError(format!( | ||
"Overflow: subtracting {:?} from {:?}", | ||
scalar, value | ||
)) | ||
}) | ||
}) | ||
.map(|a| Arc::new(a) as ArrayRef) | ||
} | ||
|
||
/// Perform `-` operation on an array. If value is null then the result is also null. | ||
|
@@ -1065,18 +1118,40 @@ where | |
/// Multiply every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead. | ||
pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: Add<Output = T::Native> | ||
+ Sub<Output = T::Native> | ||
+ Mul<Output = T::Native> | ||
+ Div<Output = T::Native> | ||
+ Rem<Output = T::Native> | ||
+ Zero | ||
+ One, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar)) | ||
} | ||
|
||
/// Multiply every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant, | ||
/// use `multiply_scalar_dyn` instead. | ||
pub fn multiply_scalar_checked_dyn<T>( | ||
array: &dyn Array, | ||
scalar: T::Native, | ||
) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value * scalar) | ||
try_unary_dyn::<_, T>(array, |value| { | ||
value.mul_checked(scalar).ok_or_else(|| { | ||
ArrowError::CastError(format!( | ||
"Overflow: multiplying {:?} by {:?}", | ||
value, scalar | ||
)) | ||
}) | ||
}) | ||
.map(|a| Arc::new(a) as ArrayRef) | ||
} | ||
|
||
/// Perform `left % right` operation on two arrays. If either left or right value is null | ||
|
@@ -1223,15 +1298,48 @@ where | |
/// result is also null. If the scalar is zero then the result of this operation will be | ||
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type | ||
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead. | ||
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: Div<Output = T::Native> + Zero, | ||
T::Native: ArrowNativeTypeOp + Zero, | ||
{ | ||
if divisor.is_zero() { | ||
return Err(ArrowError::DivideByZero); | ||
} | ||
unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor)) | ||
} | ||
|
||
/// Divide every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. If the scalar is zero then the result of this operation will be | ||
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type | ||
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant, | ||
/// use `divide_scalar_dyn` instead. | ||
pub fn divide_scalar_checked_dyn<T>( | ||
array: &dyn Array, | ||
divisor: T::Native, | ||
) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp + Zero, | ||
{ | ||
if divisor.is_zero() { | ||
return Err(ArrowError::DivideByZero); | ||
} | ||
unary_dyn::<_, T>(array, |value| value / divisor) | ||
|
||
try_unary_dyn::<_, T>(array, |value| { | ||
value.div_checked(divisor).ok_or_else(|| { | ||
ArrowError::CastError(format!( | ||
"Overflow: dividing {:?} by {:?}", | ||
value, divisor | ||
)) | ||
}) | ||
}) | ||
.map(|a| Arc::new(a) as ArrayRef) | ||
} | ||
|
||
#[cfg(test)] | ||
|
@@ -2222,6 +2330,55 @@ mod tests { | |
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_add_scalar_dyn_wrapping_overflow() { | ||
let a = Int32Array::from(vec![i32::MAX, i32::MIN]); | ||
|
||
let wrapped = add_scalar_dyn::<Int32Type>(&a, 1).unwrap(); | ||
let expected = | ||
Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; | ||
assert_eq!(&expected, &wrapped); | ||
|
||
let overflow = add_scalar_checked_dyn::<Int32Type>(&a, 1); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_subtract_scalar_dyn_wrapping_overflow() { | ||
let a = Int32Array::from(vec![-2]); | ||
|
||
let wrapped = subtract_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap(); | ||
let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; | ||
assert_eq!(&expected, &wrapped); | ||
|
||
let overflow = subtract_scalar_checked_dyn::<Int32Type>(&a, i32::MAX); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_mul_scalar_dyn_wrapping_overflow() { | ||
let a = Int32Array::from(vec![10]); | ||
|
||
let wrapped = multiply_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap(); | ||
let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; | ||
assert_eq!(&expected, &wrapped); | ||
|
||
let overflow = multiply_scalar_checked_dyn::<Int32Type>(&a, i32::MAX); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_div_scalar_dyn_wrapping_overflow() { | ||
let a = Int32Array::from(vec![i32::MIN]); | ||
|
||
let wrapped = divide_scalar_dyn::<Int32Type>(&a, -1).unwrap(); | ||
let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; | ||
assert_eq!(&expected, &wrapped); | ||
|
||
let overflow = divide_scalar_checked_dyn::<Int32Type>(&a, -1); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_div_opt_overflow_division_by_zero() { | ||
let a = Int32Array::from(vec![i32::MIN]); | ||
|
Review comment: Perhaps explain a bit when it will return `Err`.
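One way to spell out the `Err` cases, taking the checked divide kernel as the example: a minimal std-only sketch, assuming nulls are modelled as `Option<i32>`. The `DivError` enum and the `divide_scalar_checked` function are hypothetical stand-ins for `ArrowError` and `divide_scalar_checked_dyn`, and the sketch leaves out the unsupported-input-array error that the doc comment on `add_scalar_dyn` mentions.

```rust
#[derive(Debug, PartialEq)]
enum DivError {
    DivideByZero,
    Overflow(String),
}

fn divide_scalar_checked(
    values: &[Option<i32>],
    divisor: i32,
) -> Result<Vec<Option<i32>>, DivError> {
    // Case 1: a zero divisor is rejected up front, before any element is read.
    if divisor == 0 {
        return Err(DivError::DivideByZero);
    }
    values
        .iter()
        .map(|slot| match slot {
            // Nulls pass through untouched and never raise an error.
            None => Ok(None),
            // Case 2: for `i32` with a non-zero divisor, `checked_div`
            // returns `None` only for `i32::MIN / -1`.
            Some(value) => value.checked_div(divisor).map(Some).ok_or_else(|| {
                DivError::Overflow(format!(
                    "Overflow: dividing {:?} by {:?}",
                    value, divisor
                ))
            }),
        })
        .collect()
}
```

In this sketch a null element never triggers the overflow error, even when the divisor is `-1`; only a non-null `i32::MIN` does.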