diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index e6c927c2486..4e726974f66 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -391,9 +391,8 @@ where mod simd { use super::is_nan; use crate::array::{Array, PrimitiveArray}; - use crate::datatypes::ArrowNumericType; + use crate::datatypes::{ArrowNativeTypeOp, ArrowNumericType}; use std::marker::PhantomData; - use std::ops::Add; pub(super) trait SimdAggregate { type ScalarAccumulator; @@ -434,7 +433,7 @@ mod simd { impl SimdAggregate for SumAggregate where - T::Native: Add, + T::Native: ArrowNativeTypeOp, { type ScalarAccumulator = T::Native; type SimdAccumulator = T::Simd; @@ -463,7 +462,7 @@ mod simd { } fn accumulate_scalar(accumulator: &mut T::Native, value: T::Native) { - *accumulator = *accumulator + value + *accumulator = accumulator.add_wrapping(value) } fn reduce( @@ -738,7 +737,7 @@ mod simd { #[cfg(feature = "simd")] pub fn sum(array: &PrimitiveArray) -> Option where - T::Native: Add, + T::Native: ArrowNativeTypeOp, { use simd::*; diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index fe222c3d15d..e0d3077da4c 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -22,9 +22,7 @@ //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. 
-use std::ops::{Div, Neg}; - -use num::{One, Zero}; +use std::ops::Neg; use crate::array::*; #[cfg(feature = "simd")] @@ -107,7 +105,6 @@ fn math_checked_divide_op( where LT: ArrowNumericType, RT: ArrowNumericType, - RT::Native: One + Zero, F: Fn(LT::Native, RT::Native) -> Result, { try_binary(left, right, op) @@ -131,7 +128,6 @@ fn math_checked_divide_op_on_iters( ) -> Result> where T: ArrowNumericType, - T::Native: One + Zero, F: Fn(T::Native, T::Native) -> Result, { let buffer = if null_bit_buffer.is_some() { @@ -182,10 +178,10 @@ fn simd_checked_modulus( right: T::Simd, ) -> Result where - T::Native: ArrowNativeTypeOp + One, + T::Native: ArrowNativeTypeOp, { - let zero = T::init(T::Native::zero()); - let one = T::init(T::Native::one()); + let zero = T::init(T::Native::ZERO); + let one = T::init(T::Native::ONE); let right_no_invalid_zeros = match valid_mask { Some(mask) => { @@ -219,10 +215,10 @@ fn simd_checked_divide( right: T::Simd, ) -> Result where - T::Native: One + Zero, + T::Native: ArrowNativeTypeOp, { - let zero = T::init(T::Native::zero()); - let one = T::init(T::Native::one()); + let zero = T::init(T::Native::ZERO); + let one = T::init(T::Native::ONE); let right_no_invalid_zeros = match valid_mask { Some(mask) => { @@ -260,7 +256,7 @@ fn simd_checked_divide_op_remainder( ) -> Result<()> where T: ArrowNumericType, - T::Native: Zero, + T::Native: ArrowNativeTypeOp, F: Fn(T::Native, T::Native) -> T::Native, { let result_remainder = result_chunks.into_remainder(); @@ -273,7 +269,7 @@ where .enumerate() .try_for_each(|(i, (result_scalar, (left_scalar, right_scalar)))| { if valid_mask.map(|mask| mask & (1 << i) != 0).unwrap_or(true) { - if *right_scalar == T::Native::zero() { + if right_scalar.is_zero() { return Err(ArrowError::DivideByZero); } *result_scalar = op(*left_scalar, *right_scalar); @@ -648,7 +644,6 @@ fn math_divide_checked_op_dict( where K: ArrowNumericType, T: ArrowNumericType, - T::Native: One + Zero, F: Fn(T::Native, T::Native) -> Result, { 
if left.len() != right.len() { @@ -702,7 +697,6 @@ fn math_divide_safe_op_dict( where K: ArrowNumericType, T: ArrowNumericType, - T::Native: One + Zero, F: Fn(T::Native, T::Native) -> Option, { let left = left.downcast_dict::>().unwrap(); @@ -719,7 +713,6 @@ fn math_safe_divide_op( where LT: ArrowNumericType, RT: ArrowNumericType, - RT::Native: One + Zero, F: Fn(LT::Native, RT::Native) -> Option, { let array: PrimitiveArray = binary_opt::<_, _, _, LT>(left, right, op)?; @@ -1068,8 +1061,8 @@ pub fn subtract_scalar( scalar: T::Native, ) -> Result> where - T: datatypes::ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero, + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, { Ok(unary(array, |value| value.sub_wrapping(scalar))) } @@ -1085,7 +1078,7 @@ pub fn subtract_scalar_checked( ) -> Result> where T: ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero, + T::Native: ArrowNativeTypeOp, { try_unary(array, |value| value.sub_checked(scalar)) } @@ -1125,7 +1118,7 @@ where /// Perform `-` operation on an array. If value is null then the result is also null. pub fn negate(array: &PrimitiveArray) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: Neg, { Ok(unary(array, |x| -x)) @@ -1239,7 +1232,7 @@ pub fn multiply_scalar( ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero + One, + T::Native: ArrowNativeTypeOp, { Ok(unary(array, |value| value.mul_wrapping(scalar))) } @@ -1255,7 +1248,7 @@ pub fn multiply_scalar_checked( ) -> Result> where T: ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero + One, + T::Native: ArrowNativeTypeOp, { try_unary(array, |value| value.mul_checked(scalar)) } @@ -1295,26 +1288,22 @@ where /// Perform `left % right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. 
+/// +/// When the `simd` feature is not enabled, this detects overflow and returns an `Err`. pub fn modulus( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: ArrowNumericType, - T::Native: ArrowNativeTypeOp + One, + T::Native: ArrowNativeTypeOp, { #[cfg(feature = "simd")] return simd_checked_divide_op(&left, &right, simd_checked_modulus::, |a, b| { - a % b + a.mod_wrapping(b) }); #[cfg(not(feature = "simd"))] - return try_binary(left, right, |a, b| { - if b.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(a.mod_wrapping(b)) - } - }); + return try_binary(left, right, |a, b| a.mod_checked(b)); } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1328,11 +1317,13 @@ pub fn divide_checked( right: &PrimitiveArray, ) -> Result> where - T: datatypes::ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero + One, + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, { #[cfg(feature = "simd")] - return simd_checked_divide_op(&left, &right, simd_checked_divide::, |a, b| a / b); + return simd_checked_divide_op(&left, &right, simd_checked_divide::, |a, b| { + a.div_wrapping(b) + }); #[cfg(not(feature = "simd"))] return math_checked_divide_op(left, right, |a, b| a.div_checked(b)); } @@ -1343,16 +1334,21 @@ where /// If any right hand value is zero, the operation value will be replaced with null in the /// result. /// -/// Unlike `divide` or `divide_checked`, division by zero will get a null value instead -/// returning an `Err`, this also doesn't check overflowing, overflowing will just wrap -/// the result around. +/// Unlike [`divide`] or [`divide_checked`], division by zero will yield a null value in the +/// result instead of returning an `Err`. +/// +/// For floating point types overflow will saturate at INF or -INF +/// preserving the expected sign value. +/// +/// For integer types overflow will wrap around.
+/// pub fn divide_opt( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where T: ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero + One, + T::Native: ArrowNativeTypeOp, { binary_opt(left, right, |a, b| { if b.is_zero() { @@ -1480,12 +1476,16 @@ pub fn divide_dyn_opt(left: &dyn Array, right: &dyn Array) -> Result { } } -/// Perform `left / right` operation on two arrays without checking for division by zero. -/// For floating point types, the result of dividing by zero follows normal floating point -/// rules. For other numeric types, dividing by zero will panic, -/// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this +/// Perform `left / right` operation on two arrays without checking for +/// division by zero or overflow. +/// +/// For floating point types, overflow and division by zero follow normal floating point rules. +/// +/// For integer types overflow will wrap around. Division by zero will currently panic, although +/// this may be subject to change, see issue #2647. +/// +/// If either left or right value is null then the result is also null. /// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. /// For an overflow-checking variant, use `divide_checked` instead. pub fn divide( left: &PrimitiveArray, @@ -1495,6 +1495,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { + // TODO: This is incorrect as div_wrapping has side-effects for integer types + // and so may panic on null values (#2647) math_op(left, right, |a, b| a.div_wrapping(b)) } @@ -1525,12 +1527,12 @@ pub fn divide_scalar( ) -> Result> where T: ArrowNumericType, - T::Native: Div + Zero, + T::Native: ArrowNativeTypeOp, { if divisor.is_zero() { return Err(ArrowError::DivideByZero); } - Ok(unary(array, |a| a / divisor)) + Ok(unary(array, |a| a.div_wrapping(divisor))) } /// Divide every value in an array by a scalar.
If any value in the array is null then the @@ -1543,7 +1545,7 @@ where pub fn divide_scalar_dyn(array: &dyn Array, divisor: T::Native) -> Result where T: ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero, + T::Native: ArrowNativeTypeOp, { if divisor.is_zero() { return Err(ArrowError::DivideByZero); @@ -1564,7 +1566,7 @@ pub fn divide_scalar_checked_dyn( ) -> Result where T: ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero, + T::Native: ArrowNativeTypeOp, { if divisor.is_zero() { return Err(ArrowError::DivideByZero); @@ -1587,7 +1589,7 @@ where pub fn divide_scalar_opt_dyn(array: &dyn Array, divisor: T::Native) -> Result where T: ArrowNumericType, - T::Native: ArrowNativeTypeOp + Zero, + T::Native: ArrowNativeTypeOp, { if divisor.is_zero() { match array.data_type() { diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index 2ff9574c78a..444ba39e0b6 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -19,110 +19,72 @@ use crate::error::{ArrowError, Result}; pub use arrow_array::ArrowPrimitiveType; pub use arrow_buffer::{ArrowNativeType, ToByteSlice}; use half::f16; -use num::Zero; -use std::ops::{Add, Div, Mul, Rem, Sub}; -mod private { - pub trait Sealed {} -} - -/// Trait for ArrowNativeType to provide overflow-checking and non-overflow-checking -/// variants for arithmetic operations. For floating point types, this provides some -/// default implementations. Integer types that need to deal with overflow can implement -/// this trait. +/// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations, +/// and totally ordered comparison operations /// -/// The APIs with `_wrapping` suffix are the variant of non-overflow-checking. If overflow -/// occurred, they will supposedly wrap around the boundary of the type. +/// The APIs with `_wrapping` suffix do not perform overflow-checking. For integer +/// types they will wrap around the boundary of the type. 
For floating point types they +/// will overflow to INF or -INF preserving the expected sign value /// -/// The APIs with `_checked` suffix are the variant of overflow-checking which return `None` -/// if overflow occurred. -pub trait ArrowNativeTypeOp: - ArrowNativeType - + Add - + Sub - + Mul - + Div - + Rem - + Zero - + private::Sealed -{ - fn add_checked(self, rhs: Self) -> Result { - Ok(self + rhs) - } - - fn add_wrapping(self, rhs: Self) -> Self { - self + rhs - } - - fn sub_checked(self, rhs: Self) -> Result { - Ok(self - rhs) - } - - fn sub_wrapping(self, rhs: Self) -> Self { - self - rhs - } - - fn mul_checked(self, rhs: Self) -> Result { - Ok(self * rhs) - } - - fn mul_wrapping(self, rhs: Self) -> Self { - self * rhs - } - - fn div_checked(self, rhs: Self) -> Result { - if rhs.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(self / rhs) - } - } +/// Note `div_wrapping` and `mod_wrapping` will panic for integer types if `rhs` is zero +/// although this may be subject to change +/// +/// The APIs with `_checked` suffix perform overflow-checking. For integer types +/// these will return `Err` instead of wrapping. 
For floating point types they will +/// overflow to INF or -INF preserving the expected sign value +/// +/// Comparison of integer types is as per normal integer comparison rules, floating +/// point values are compared as per IEEE 754's totalOrder predicate see [`f32::total_cmp`] +/// +pub trait ArrowNativeTypeOp: ArrowNativeType { + /// The additive identity + const ZERO: Self; - fn div_wrapping(self, rhs: Self) -> Self { - self / rhs - } + /// The multiplicative identity + const ONE: Self; - fn mod_checked(self, rhs: Self) -> Result { - if rhs.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(self % rhs) - } - } + fn add_checked(self, rhs: Self) -> Result; + + fn add_wrapping(self, rhs: Self) -> Self; + + fn sub_checked(self, rhs: Self) -> Result; + + fn sub_wrapping(self, rhs: Self) -> Self; - fn mod_wrapping(self, rhs: Self) -> Self { - self % rhs - } + fn mul_checked(self, rhs: Self) -> Result; - fn is_eq(self, rhs: Self) -> bool { - self == rhs - } + fn mul_wrapping(self, rhs: Self) -> Self; - fn is_ne(self, rhs: Self) -> bool { - self != rhs - } + fn div_checked(self, rhs: Self) -> Result; - fn is_lt(self, rhs: Self) -> bool { - self < rhs - } + fn div_wrapping(self, rhs: Self) -> Self; - fn is_le(self, rhs: Self) -> bool { - self <= rhs - } + fn mod_checked(self, rhs: Self) -> Result; - fn is_gt(self, rhs: Self) -> bool { - self > rhs - } + fn mod_wrapping(self, rhs: Self) -> Self; - fn is_ge(self, rhs: Self) -> bool { - self >= rhs - } + fn is_zero(self) -> bool; + + fn is_eq(self, rhs: Self) -> bool; + + fn is_ne(self, rhs: Self) -> bool; + + fn is_lt(self, rhs: Self) -> bool; + + fn is_le(self, rhs: Self) -> bool; + + fn is_gt(self, rhs: Self) -> bool; + + fn is_ge(self, rhs: Self) -> bool; } macro_rules! 
native_type_op { ($t:tt) => { - impl private::Sealed for $t {} impl ArrowNativeTypeOp for $t { + const ZERO: Self = 0; + const ONE: Self = 1; + fn add_checked(self, rhs: Self) -> Result { self.checked_add(rhs).ok_or_else(|| { ArrowError::ComputeError(format!( @@ -195,6 +157,34 @@ macro_rules! native_type_op { fn mod_wrapping(self, rhs: Self) -> Self { self.wrapping_rem(rhs) } + + fn is_zero(self) -> bool { + self == 0 + } + + fn is_eq(self, rhs: Self) -> bool { + self == rhs + } + + fn is_ne(self, rhs: Self) -> bool { + self != rhs + } + + fn is_lt(self, rhs: Self) -> bool { + self < rhs + } + + fn is_le(self, rhs: Self) -> bool { + self <= rhs + } + + fn is_gt(self, rhs: Self) -> bool { + self > rhs + } + + fn is_ge(self, rhs: Self) -> bool { + self >= rhs + } } }; } @@ -210,9 +200,63 @@ native_type_op!(u32); native_type_op!(u64); macro_rules! native_type_float_op { - ($t:tt) => { - impl private::Sealed for $t {} + ($t:tt, $zero:expr, $one:expr) => { impl ArrowNativeTypeOp for $t { + const ZERO: Self = $zero; + const ONE: Self = $one; + + fn add_checked(self, rhs: Self) -> Result { + Ok(self + rhs) + } + + fn add_wrapping(self, rhs: Self) -> Self { + self + rhs + } + + fn sub_checked(self, rhs: Self) -> Result { + Ok(self - rhs) + } + + fn sub_wrapping(self, rhs: Self) -> Self { + self - rhs + } + + fn mul_checked(self, rhs: Self) -> Result { + Ok(self * rhs) + } + + fn mul_wrapping(self, rhs: Self) -> Self { + self * rhs + } + + fn div_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(self / rhs) + } + } + + fn div_wrapping(self, rhs: Self) -> Self { + self / rhs + } + + fn mod_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(self % rhs) + } + } + + fn mod_wrapping(self, rhs: Self) -> Self { + self % rhs + } + + fn is_zero(self) -> bool { + self == $zero + } + fn is_eq(self, rhs: Self) -> bool { self.total_cmp(&rhs).is_eq() } @@ -240,6 +284,6 @@ 
macro_rules! native_type_float_op { }; } -native_type_float_op!(f16); -native_type_float_op!(f32); -native_type_float_op!(f64); +native_type_float_op!(f16, f16::ZERO, f16::ONE); +native_type_float_op!(f32, 0., 1.); +native_type_float_op!(f64, 0., 1.);