diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 775fc53611e..3660fbf52e7 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -921,15 +921,34 @@ where /// Perform `left - right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `subtract_checked` instead. pub fn subtract( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Sub, + T::Native: ArrowNativeTypeOp, +{ + math_op(left, right, |a, b| a.wrapping_sub_if_applied(b)) +} + +/// Perform `left - right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `subtract` instead. +pub fn subtract_checked( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result> +where + T: datatypes::ArrowNumericType, + T::Native: ArrowNativeTypeOp, { - math_op(left, right, |a, b| a - b) + math_checked_op(left, right, |a, b| a.checked_sub_if_applied(b)) } /// Perform `left - right` operation on two arrays. If either left or right value is null @@ -998,15 +1017,34 @@ where /// Perform `left * right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `multiply_check` instead. pub fn multiply( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Mul, + T::Native: ArrowNativeTypeOp, { - math_op(left, right, |a, b| a * b) + math_op(left, right, |a, b| a.wrapping_mul_if_applied(b)) +} + +/// Perform `left * right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `multiply` instead. +pub fn multiply_check( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result> +where + T: datatypes::ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + math_checked_op(left, right, |a, b| a.checked_mul_if_applied(b)) } /// Perform `left * right` operation on two arrays. If either left or right value is null @@ -1078,18 +1116,21 @@ where /// Perform `left / right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. -pub fn divide( +/// +/// When `simd` feature is not enabled. This detects overflow and returns an `Err` for that. +/// For an non-overflow-checking variant, use `divide` instead. +pub fn divide_checked( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Div + Zero + One, + T::Native: ArrowNativeTypeOp + Zero + One, { #[cfg(feature = "simd")] return simd_checked_divide_op(&left, &right, simd_checked_divide::, |a, b| a / b); #[cfg(not(feature = "simd"))] - return math_checked_divide_op(left, right, |a, b| a / b); + return math_checked_op(left, right, |a, b| a.checked_div_if_applied(b)); } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1107,15 +1148,18 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// Perform `left / right` operation on two arrays without checking for division by zero. /// The result of dividing by zero follows normal floating point rules. /// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this -pub fn divide_unchecked( +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `divide_checked` instead. +pub fn divide( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where - T: datatypes::ArrowFloatNumericType, - T::Native: Div, + T: datatypes::ArrowNumericType, + T::Native: ArrowNativeTypeOp, { - math_op(left, right, |a, b| a / b) + math_op(left, right, |a, b| a.wrapping_div_if_applied(b)) } /// Modulus every value in an array by a scalar. If any value in the array is null then the @@ -2097,4 +2141,43 @@ mod tests { let overflow = add_checked(&a, &b); overflow.expect_err("overflow should be detected"); } + + #[test] + fn test_primitive_subtract_wrapping_overflow() { + let a = Int32Array::from(vec![-2]); + let b = Int32Array::from(vec![i32::MAX]); + + let wrapped = subtract(&a, &b); + let expected = Int32Array::from(vec![i32::MAX]); + assert_eq!(expected, wrapped.unwrap()); + + let overflow = subtract_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_mul_wrapping_overflow() { + let a = Int32Array::from(vec![10]); + let b = Int32Array::from(vec![i32::MAX]); + + let wrapped = multiply(&a, &b); + let expected = Int32Array::from(vec![-10]); + assert_eq!(expected, wrapped.unwrap()); + + let overflow = multiply_check(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_div_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MIN]); + let b = Int32Array::from(vec![-1]); + + let wrapped = divide(&a, &b); + let expected = Int32Array::from(vec![-2147483648]); + assert_eq!(expected, wrapped.unwrap()); + + let overflow = divide_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } } diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index 2163184137a..e31998e5ebb 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -17,7 +17,7 @@ use super::DataType; use half::f16; -use std::ops::Add; +use std::ops::{Add, Div, Mul, Sub}; mod private { pub trait Sealed {} @@ -116,7 +116,13 @@ pub trait ArrowPrimitiveType: 'static { } /// Trait for ArrowNativeType to provide overflow-aware operations. -pub trait ArrowNativeTypeOp: ArrowNativeType + Add { +pub trait ArrowNativeTypeOp: + ArrowNativeType + + Add + + Sub + + Mul + + Div +{ fn checked_add_if_applied(self, rhs: Self) -> Option { Some(self + rhs) } @@ -124,6 +130,30 @@ pub trait ArrowNativeTypeOp: ArrowNativeType + Add { fn wrapping_add_if_applied(self, rhs: Self) -> Self { self + rhs } + + fn checked_sub_if_applied(self, rhs: Self) -> Option { + Some(self + rhs) + } + + fn wrapping_sub_if_applied(self, rhs: Self) -> Self { + self + rhs + } + + fn checked_mul_if_applied(self, rhs: Self) -> Option { + Some(self * rhs) + } + + fn wrapping_mul_if_applied(self, rhs: Self) -> Self { + self * rhs + } + + fn checked_div_if_applied(self, rhs: Self) -> Option { + Some(self / rhs) + } + + fn wrapping_div_if_applied(self, rhs: Self) -> Self { + self / rhs + } } macro_rules! native_type_op { @@ -136,6 +166,30 @@ macro_rules! native_type_op { fn wrapping_add_if_applied(self, rhs: Self) -> Self { self.wrapping_add(rhs) } + + fn checked_sub_if_applied(self, rhs: Self) -> Option { + self.checked_sub(rhs) + } + + fn wrapping_sub_if_applied(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + fn checked_mul_if_applied(self, rhs: Self) -> Option { + self.checked_mul(rhs) + } + + fn wrapping_mul_if_applied(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + fn checked_div_if_applied(self, rhs: Self) -> Option { + self.checked_div(rhs) + } + + fn wrapping_div_if_applied(self, rhs: Self) -> Self { + self.wrapping_div(rhs) + } } }; }