-
Notifications
You must be signed in to change notification settings - Fork 658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add overflow-checking variants of arithmetic scalar dyn kernels #2713
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation | ||
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. | ||
|
||
use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; | ||
use std::ops::{Div, Neg, Rem}; | ||
|
||
use num::{One, Zero}; | ||
|
||
|
@@ -32,7 +32,7 @@ use crate::buffer::Buffer; | |
use crate::buffer::MutableBuffer; | ||
use crate::compute::kernels::arity::unary; | ||
use crate::compute::util::combine_option_bitmap; | ||
use crate::compute::{binary, try_binary, try_unary, unary_dyn}; | ||
use crate::compute::{binary, try_binary, try_unary, try_unary_dyn, unary_dyn}; | ||
use crate::datatypes::{ | ||
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, | ||
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, | ||
|
@@ -834,12 +834,34 @@ where | |
/// Add every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead. | ||
pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: Add<Output = T::Native>, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value + scalar) | ||
unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar)) | ||
} | ||
|
||
/// Add every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, | ||
/// use `add_scalar_dyn` instead. | ||
pub fn add_scalar_checked_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious: do we have benchmark to track how much slower There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it is anything like the non-scalar kernels, it is about 10x slower. Aside from the branching costs, it prevents LLVM from vectorising it correctly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. I wonder if we should point that out in the doc of this method, in case it's not obvious to the users. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, it will be much slower. As by default (ansi-mode disabled) in our case, non-checked kernels will be used. So most of time users will use faster one, except they have special need to use checked kernels. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, I'm going to add a few lines mentioning that. |
||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
try_unary_dyn::<_, T>(array, |value| { | ||
value.add_checked(scalar).ok_or_else(|| { | ||
ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value)) | ||
}) | ||
}) | ||
.map(|a| Arc::new(a) as ArrayRef) | ||
} | ||
|
||
/// Perform `left - right` operation on two arrays. If either left or right value is null | ||
|
@@ -937,16 +959,40 @@ where | |
/// Subtract every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead. | ||
pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef> | ||
where | ||
T: datatypes::ArrowNumericType, | ||
T::Native: Add<Output = T::Native> | ||
+ Sub<Output = T::Native> | ||
+ Mul<Output = T::Native> | ||
+ Div<Output = T::Native> | ||
+ Zero, | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value - scalar) | ||
unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar)) | ||
} | ||
|
||
/// Subtract every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, | ||
/// use `subtract_scalar_dyn` instead. | ||
pub fn subtract_scalar_checked_dyn<T>( | ||
array: &dyn Array, | ||
scalar: T::Native, | ||
) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
try_unary_dyn::<_, T>(array, |value| { | ||
value.sub_checked(scalar).ok_or_else(|| { | ||
ArrowError::CastError(format!( | ||
"Overflow: subtracting {:?} from {:?}", | ||
scalar, value | ||
)) | ||
}) | ||
}) | ||
.map(|a| Arc::new(a) as ArrayRef) | ||
} | ||
|
||
/// Perform `-` operation on an array. If value is null then the result is also null. | ||
|
@@ -1065,18 +1111,40 @@ where | |
/// Multiply every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead. | ||
pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: Add<Output = T::Native> | ||
+ Sub<Output = T::Native> | ||
+ Mul<Output = T::Native> | ||
+ Div<Output = T::Native> | ||
+ Rem<Output = T::Native> | ||
+ Zero | ||
+ One, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
unary_dyn::<_, T>(array, |value| value * scalar) | ||
unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar)) | ||
} | ||
|
||
/// Subtract every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. The given array must be a `PrimitiveArray` of the type same as | ||
/// the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, | ||
/// use `multiply_scalar_dyn` instead. | ||
pub fn multiply_scalar_checked_dyn<T>( | ||
array: &dyn Array, | ||
scalar: T::Native, | ||
) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
try_unary_dyn::<_, T>(array, |value| { | ||
value.mul_checked(scalar).ok_or_else(|| { | ||
ArrowError::CastError(format!( | ||
"Overflow: multiplying {:?} by {:?}", | ||
value, scalar | ||
)) | ||
}) | ||
}) | ||
.map(|a| Arc::new(a) as ArrayRef) | ||
} | ||
|
||
/// Perform `left % right` operation on two arrays. If either left or right value is null | ||
|
@@ -1197,15 +1265,48 @@ where | |
/// result is also null. If the scalar is zero then the result of this operation will be | ||
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type | ||
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This doesn't detect overflow. Once overflowing, the result will wrap around. | ||
/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead. | ||
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: Div<Output = T::Native> + Zero, | ||
T::Native: ArrowNativeTypeOp + Zero, | ||
{ | ||
if divisor.is_zero() { | ||
return Err(ArrowError::DivideByZero); | ||
} | ||
unary_dyn::<_, T>(array, |value| value / divisor) | ||
unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor)) | ||
} | ||
|
||
/// Divide every value in an array by a scalar. If any value in the array is null then the | ||
/// result is also null. If the scalar is zero then the result of this operation will be | ||
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type | ||
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. | ||
/// | ||
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, | ||
/// use `divide_scalar_dyn` instead. | ||
pub fn divide_scalar_checked_dyn<T>( | ||
array: &dyn Array, | ||
divisor: T::Native, | ||
) -> Result<ArrayRef> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp + Zero, | ||
{ | ||
if divisor.is_zero() { | ||
return Err(ArrowError::DivideByZero); | ||
} | ||
|
||
try_unary_dyn::<_, T>(array, |value| { | ||
value.div_checked(divisor).ok_or_else(|| { | ||
ArrowError::CastError(format!( | ||
"Overflow: dividing {:?} by {:?}", | ||
value, divisor | ||
)) | ||
}) | ||
}) | ||
.map(|a| Arc::new(a) as ArrayRef) | ||
} | ||
|
||
#[cfg(test)] | ||
|
@@ -2195,4 +2296,53 @@ mod tests { | |
let overflow = multiply_scalar_checked(&a, i32::MAX); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_add_scalar_dyn_wrapping_overflow() { | ||
let a = Int32Array::from(vec![i32::MAX, i32::MIN]); | ||
|
||
let wrapped = add_scalar_dyn::<Int32Type>(&a, 1).unwrap(); | ||
let expected = | ||
Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; | ||
assert_eq!(&expected, &wrapped); | ||
|
||
let overflow = add_scalar_checked_dyn::<Int32Type>(&a, 1); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_subtract_scalar_dyn_wrapping_overflow() { | ||
let a = Int32Array::from(vec![-2]); | ||
|
||
let wrapped = subtract_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap(); | ||
let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; | ||
assert_eq!(&expected, &wrapped); | ||
|
||
let overflow = subtract_scalar_checked_dyn::<Int32Type>(&a, i32::MAX); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_mul_scalar_dyn_wrapping_overflow() { | ||
let a = Int32Array::from(vec![10]); | ||
|
||
let wrapped = multiply_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap(); | ||
let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; | ||
assert_eq!(&expected, &wrapped); | ||
|
||
let overflow = multiply_scalar_checked_dyn::<Int32Type>(&a, i32::MAX); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
|
||
#[test] | ||
fn test_primitive_div_scalar_dyn_wrapping_overflow() { | ||
let a = Int32Array::from(vec![i32::MIN]); | ||
|
||
let wrapped = divide_scalar_dyn::<Int32Type>(&a, -1).unwrap(); | ||
let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; | ||
assert_eq!(&expected, &wrapped); | ||
|
||
let overflow = divide_scalar_checked_dyn::<Int32Type>(&a, -1); | ||
overflow.expect_err("overflow should be detected"); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,7 +123,7 @@ where | |
Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) }) | ||
} | ||
|
||
/// A helper function that applies an unary function to a dictionary array with primitive value type. | ||
/// A helper function that applies an infallible unary function to a dictionary array with primitive value type. | ||
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef> | ||
where | ||
K: ArrowNumericType, | ||
|
@@ -138,7 +138,22 @@ where | |
Ok(Arc::new(new_dict)) | ||
} | ||
|
||
/// Applies an unary function to an array with primitive values. | ||
/// A helper function that applies a fallible unary function to a dictionary array with primitive value type. | ||
fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef> | ||
where | ||
K: ArrowNumericType, | ||
T: ArrowPrimitiveType, | ||
F: Fn(T::Native) -> Result<T::Native>, | ||
{ | ||
let dict_values = array.values().as_any().downcast_ref().unwrap(); | ||
let values = try_unary::<T, F, T>(dict_values, op)?.into_data(); | ||
let data = array.data().clone().into_builder().child_data(vec![values]); | ||
|
||
let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into(); | ||
Ok(Arc::new(new_dict)) | ||
} | ||
|
||
/// Applies an infallible unary function to an array with primitive values. | ||
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef> | ||
where | ||
T: ArrowPrimitiveType, | ||
|
@@ -162,6 +177,30 @@ where | |
} | ||
} | ||
|
||
/// Applies a fallible unary function to an array with primitive values. | ||
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef> | ||
where | ||
T: ArrowPrimitiveType, | ||
F: Fn(T::Native) -> Result<T::Native>, | ||
{ | ||
downcast_dictionary_array! { | ||
array => try_unary_dict::<_, F, T>(array, op), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm how do we know the dictionary value type matches There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, there is no type-bound for the dictionary value type. Just do a simple test. At runtime There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normally as the op is provided by users, I suppose that users know dictionary value is same type as the scalar. But it is good to return a meaningful Err instead of runtime panic. I will do it in a follow-up. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Follow up sounds fine to me. Perhaps we can just check the type here: downcast_dictionary_array! {
array => if array.values().data_type() == &T::DATA_TYPE {
try_unary_dict::<_, F, T>(array, op)
} else {
// throw error
},
t => { There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, right, actually I thought to handle it at |
||
t => { | ||
if t == &T::DATA_TYPE { | ||
Ok(Arc::new(try_unary::<T, F, T>( | ||
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(), | ||
op, | ||
)?)) | ||
} else { | ||
Err(ArrowError::NotYetImplemented(format!( | ||
"Cannot perform unary operation on array of type {}", | ||
t | ||
))) | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting | ||
/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the | ||
/// corresponding index in the result will also be null | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
perhaps explain a bit when it will return
Err