Skip to content

Commit

Permalink
Simplify ArrowNativeType
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Oct 7, 2022
1 parent 8dd94a9 commit 098bd14
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 138 deletions.
9 changes: 4 additions & 5 deletions arrow/src/compute/kernels/aggregate.rs
Expand Up @@ -391,9 +391,8 @@ where
mod simd {
use super::is_nan;
use crate::array::{Array, PrimitiveArray};
use crate::datatypes::ArrowNumericType;
use crate::datatypes::{ArrowNativeTypeOp, ArrowNumericType};
use std::marker::PhantomData;
use std::ops::Add;

pub(super) trait SimdAggregate<T: ArrowNumericType> {
type ScalarAccumulator;
Expand Down Expand Up @@ -434,7 +433,7 @@ mod simd {

impl<T: ArrowNumericType> SimdAggregate<T> for SumAggregate<T>
where
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
type ScalarAccumulator = T::Native;
type SimdAccumulator = T::Simd;
Expand Down Expand Up @@ -463,7 +462,7 @@ mod simd {
}

fn accumulate_scalar(accumulator: &mut T::Native, value: T::Native) {
*accumulator = *accumulator + value
*accumulator = accumulator.add_wrapping(value)
}

fn reduce(
Expand Down Expand Up @@ -738,7 +737,7 @@ mod simd {
#[cfg(feature = "simd")]
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
use simd::*;

Expand Down
88 changes: 47 additions & 41 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -22,9 +22,7 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.

use std::ops::{Div, Neg};

use num::{One, Zero};
use std::ops::Neg;

use crate::array::*;
#[cfg(feature = "simd")]
Expand Down Expand Up @@ -107,7 +105,6 @@ fn math_checked_divide_op<LT, RT, F>(
where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
F: Fn(LT::Native, RT::Native) -> Result<LT::Native>,
{
try_binary(left, right, op)
Expand All @@ -131,7 +128,6 @@ fn math_checked_divide_op_on_iters<T, F>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
let buffer = if null_bit_buffer.is_some() {
Expand Down Expand Up @@ -182,10 +178,10 @@ fn simd_checked_modulus<T: ArrowNumericType>(
right: T::Simd,
) -> Result<T::Simd>
where
T::Native: ArrowNativeTypeOp + One,
T::Native: ArrowNativeTypeOp,
{
let zero = T::init(T::Native::zero());
let one = T::init(T::Native::one());
let zero = T::init(T::Native::ZERO);
let one = T::init(T::Native::ONE);

let right_no_invalid_zeros = match valid_mask {
Some(mask) => {
Expand Down Expand Up @@ -219,10 +215,10 @@ fn simd_checked_divide<T: ArrowNumericType>(
right: T::Simd,
) -> Result<T::Simd>
where
T::Native: One + Zero,
T::Native: ArrowNativeTypeOp,
{
let zero = T::init(T::Native::zero());
let one = T::init(T::Native::one());
let zero = T::init(T::Native::ZERO);
let one = T::init(T::Native::ONE);

let right_no_invalid_zeros = match valid_mask {
Some(mask) => {
Expand Down Expand Up @@ -260,7 +256,7 @@ fn simd_checked_divide_op_remainder<T, F>(
) -> Result<()>
where
T: ArrowNumericType,
T::Native: Zero,
T::Native: ArrowNativeTypeOp,
F: Fn(T::Native, T::Native) -> T::Native,
{
let result_remainder = result_chunks.into_remainder();
Expand All @@ -273,7 +269,7 @@ where
.enumerate()
.try_for_each(|(i, (result_scalar, (left_scalar, right_scalar)))| {
if valid_mask.map(|mask| mask & (1 << i) != 0).unwrap_or(true) {
if *right_scalar == T::Native::zero() {
if right_scalar.is_zero() {
return Err(ArrowError::DivideByZero);
}
*result_scalar = op(*left_scalar, *right_scalar);
Expand Down Expand Up @@ -648,7 +644,6 @@ fn math_divide_checked_op_dict<K, T, F>(
where
K: ArrowNumericType,
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
if left.len() != right.len() {
Expand Down Expand Up @@ -702,7 +697,6 @@ fn math_divide_safe_op_dict<K, T, F>(
where
K: ArrowNumericType,
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> Option<T::Native>,
{
let left = left.downcast_dict::<PrimitiveArray<T>>().unwrap();
Expand All @@ -719,7 +713,6 @@ fn math_safe_divide_op<LT, RT, F>(
where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
let array: PrimitiveArray<LT> = binary_opt::<_, _, _, LT>(left, right, op)?;
Expand Down Expand Up @@ -1068,8 +1061,8 @@ pub fn subtract_scalar<T>(
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
Ok(unary(array, |value| value.sub_wrapping(scalar)))
}
Expand All @@ -1085,7 +1078,7 @@ pub fn subtract_scalar_checked<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T::Native: ArrowNativeTypeOp,
{
try_unary(array, |value| value.sub_checked(scalar))
}
Expand Down Expand Up @@ -1125,7 +1118,7 @@ where
/// Perform `-` operation on an array. If value is null then the result is also null.
pub fn negate<T>(array: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T: ArrowNumericType,
T::Native: Neg<Output = T::Native>,
{
Ok(unary(array, |x| -x))
Expand Down Expand Up @@ -1239,7 +1232,7 @@ pub fn multiply_scalar<T>(
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
T::Native: ArrowNativeTypeOp,
{
Ok(unary(array, |value| value.mul_wrapping(scalar)))
}
Expand All @@ -1255,7 +1248,7 @@ pub fn multiply_scalar_checked<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
T::Native: ArrowNativeTypeOp,
{
try_unary(array, |value| value.mul_checked(scalar))
}
Expand Down Expand Up @@ -1301,11 +1294,11 @@ pub fn modulus<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + One,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_checked_divide_op(&left, &right, simd_checked_modulus::<T>, |a, b| {
a % b
a.mod_wrapping(b)
});
#[cfg(not(feature = "simd"))]
return try_binary(left, right, |a, b| {
Expand All @@ -1328,11 +1321,13 @@ pub fn divide_checked<T>(
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, b| a / b);
return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, b| {
a.div_wrapping(b)
});
#[cfg(not(feature = "simd"))]
return math_checked_divide_op(left, right, |a, b| a.div_checked(b));
}
Expand All @@ -1343,16 +1338,21 @@ where
/// If any right hand value is zero, the operation value will be replaced with null in the
/// result.
///
/// Unlike `divide` or `divide_checked`, division by zero will get a null value instead
/// returning an `Err`, this also doesn't check overflowing, overflowing will just wrap
/// the result around.
/// Unlike [`divide`] or [`divide_checked`], division by zero will yield a null value in the
/// result instead of returning an `Err`.
///
/// For floating point types overflow will saturate at INF or -INF
/// preserving the expected sign value.
///
/// For integer types overflow will wrap around.
///
pub fn divide_opt<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
T::Native: ArrowNativeTypeOp,
{
binary_opt(left, right, |a, b| {
if b.is_zero() {
Expand Down Expand Up @@ -1480,12 +1480,16 @@ pub fn divide_dyn_opt(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
}
}

/// Perform `left / right` operation on two arrays without checking for division by zero.
/// For floating point types, the result of dividing by zero follows normal floating point
/// rules. For other numeric types, dividing by zero will panic,
/// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this
/// Perform `left / right` operation on two arrays without checking for
/// division by zero or overflow.
///
/// For floating point types, overflow and division by zero follow normal floating point rules.
///
/// For integer types, overflow will wrap around. Division by zero will currently panic, although
/// this may be subject to change; see <https://github.com/apache/arrow-rs/issues/2647>.
///
/// If either left or right value is null then the result is also null.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `divide_checked` instead.
pub fn divide<T>(
left: &PrimitiveArray<T>,
Expand All @@ -1495,6 +1499,8 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
// TODO: This is incorrect as div_wrapping has side-effects for integer types
// and so may panic on null values (#2647)
math_op(left, right, |a, b| a.div_wrapping(b))
}

Expand Down Expand Up @@ -1525,12 +1531,12 @@ pub fn divide_scalar<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: Div<Output = T::Native> + Zero,
T::Native: ArrowNativeTypeOp,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}
Ok(unary(array, |a| a / divisor))
Ok(unary(array, |a| a.div_wrapping(divisor)))
}

/// Divide every value in an array by a scalar. If any value in the array is null then the
Expand All @@ -1543,7 +1549,7 @@ where
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T::Native: ArrowNativeTypeOp,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
Expand All @@ -1564,7 +1570,7 @@ pub fn divide_scalar_checked_dyn<T>(
) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T::Native: ArrowNativeTypeOp,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
Expand All @@ -1587,7 +1593,7 @@ where
pub fn divide_scalar_opt_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T::Native: ArrowNativeTypeOp,
{
if divisor.is_zero() {
match array.data_type() {
Expand Down

0 comments on commit 098bd14

Please sign in to comment.