diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 1c28c989524..2e6c2876f58 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -22,7 +22,7 @@ //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. -use std::ops::{Div, Neg, Rem}; +use std::ops::{Div, Neg}; use num::{One, Zero}; @@ -182,7 +182,7 @@ fn simd_checked_modulus( right: T::Simd, ) -> Result where - T::Native: One + Zero, + T::Native: ArrowNativeTypeOp + One, { let zero = T::init(T::Native::zero()); let one = T::init(T::Native::one()); @@ -305,7 +305,7 @@ fn simd_checked_divide_op( ) -> Result> where T: ArrowNumericType, - T::Native: One + Zero, + T::Native: ArrowNativeTypeOp, SI: Fn(Option, T::Simd, T::Simd) -> Result, SC: Fn(T::Native, T::Native) -> T::Native, { @@ -1305,7 +1305,7 @@ pub fn modulus( ) -> Result> where T: ArrowNumericType, - T::Native: Rem + Zero + One, + T::Native: ArrowNativeTypeOp + One, { #[cfg(feature = "simd")] return simd_checked_divide_op(&left, &right, simd_checked_modulus::, |a, b| { @@ -1316,7 +1316,7 @@ where if b.is_zero() { Err(ArrowError::DivideByZero) } else { - Ok(a % b) + Ok(a.mod_wrapping(b)) } }); } @@ -1511,13 +1511,13 @@ pub fn modulus_scalar( ) -> Result> where T: ArrowNumericType, - T::Native: Rem + Zero, + T::Native: ArrowNativeTypeOp, { if modulo.is_zero() { return Err(ArrowError::DivideByZero); } - Ok(unary(array, |a| a % modulo)) + Ok(unary(array, |a| a.mod_wrapping(modulo))) } /// Divide every value in an array by a scalar. If any value in the array is null then the @@ -2120,7 +2120,7 @@ mod tests { } #[test] - fn test_primitive_array_modulus() { + fn test_int_array_modulus() { let a = Int32Array::from(vec![15, 15, 8, 1, 9]); let b = Int32Array::from(vec![5, 6, 8, 9, 1]); let c = modulus(&a, &b).unwrap(); @@ -2131,6 +2131,34 @@ mod tests { assert_eq!(0, c.value(4)); } + #[test] + #[should_panic( + expected = "called `Result::unwrap()` on an `Err` value: DivideByZero" + )] + fn test_int_array_modulus_divide_by_zero() { + let a = Int32Array::from(vec![1]); + let b = Int32Array::from(vec![0]); + modulus(&a, &b).unwrap(); + } + + #[test] + #[cfg(not(feature = "simd"))] + fn test_int_array_modulus_overflow_wrapping() { + let a = Int32Array::from(vec![i32::MIN]); + let b = Int32Array::from(vec![-1]); + let result = modulus(&a, &b).unwrap(); + assert_eq!(0, result.value(0)) + } + + #[test] + #[cfg(feature = "simd")] + #[should_panic(expected = "attempt to calculate the remainder with overflow")] + fn test_int_array_modulus_overflow_panic() { + let a = Int32Array::from(vec![i32::MIN]); + let b = Int32Array::from(vec![-1]); + let _ = modulus(&a, &b).unwrap(); + } + #[test] fn test_primitive_array_divide_scalar() { let a = Int32Array::from(vec![15, 14, 9, 8, 1]); @@ -2193,7 +2221,7 @@ mod tests { } #[test] - fn test_primitive_array_modulus_scalar() { + fn test_int_array_modulus_scalar() { let a = Int32Array::from(vec![15, 14, 9, 8, 1]); let b = 3; let c = modulus_scalar(&a, b).unwrap(); @@ -2202,7 +2230,7 @@ mod tests { } #[test] - fn test_primitive_array_modulus_scalar_sliced() { + fn test_int_array_modulus_scalar_sliced() { let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); let a = a.slice(1, 4); let a = as_primitive_array(&a); @@ -2211,6 +2239,22 @@ mod tests { assert_eq!(actual, expected); } + #[test] + #[should_panic( + expected = "called `Result::unwrap()` on an `Err` value: DivideByZero" + )] + fn test_int_array_modulus_scalar_divide_by_zero() { + let a = Int32Array::from(vec![1]); + modulus_scalar(&a, 0).unwrap(); + } + + #[test] + fn test_int_array_modulus_scalar_overflow_wrapping() { + let a = Int32Array::from(vec![i32::MIN]); + let result = modulus_scalar(&a, -1).unwrap(); + assert_eq!(0, result.value(0)) + } + #[test] fn test_primitive_array_divide_sliced() { let a = Int32Array::from(vec![0, 0, 0, 15, 15, 8, 1, 9, 0]); diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index 6ab82688e52..654b939500a 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -26,7 +26,7 @@ pub(crate) mod native_op { use super::ArrowNativeType; use crate::error::{ArrowError, Result}; use num::Zero; - use std::ops::{Add, Div, Mul, Sub}; + use std::ops::{Add, Div, Mul, Rem, Sub}; /// Trait for ArrowNativeType to provide overflow-checking and non-overflow-checking /// variants for arithmetic operations. For floating point types, this provides some @@ -44,6 +44,7 @@ pub(crate) mod native_op { + Sub + Mul + Div + + Rem + Zero { fn add_checked(self, rhs: Self) -> Result { @@ -81,6 +82,18 @@ pub(crate) mod native_op { fn div_wrapping(self, rhs: Self) -> Self { self / rhs } + + fn mod_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(self % rhs) + } + } + + fn mod_wrapping(self, rhs: Self) -> Self { + self % rhs + } } } @@ -142,6 +155,23 @@ macro_rules! native_type_op { fn div_wrapping(self, rhs: Self) -> Self { self.wrapping_div(rhs) } + + fn mod_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + self.checked_rem(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} % {:?}", + self, rhs + )) + }) + } + } + + fn mod_wrapping(self, rhs: Self) -> Self { + self.wrapping_rem(rhs) + } } }; }