Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify ArrowNativeType #2841

Merged
merged 4 commits into from Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 4 additions & 5 deletions arrow/src/compute/kernels/aggregate.rs
Expand Up @@ -391,9 +391,8 @@ where
mod simd {
use super::is_nan;
use crate::array::{Array, PrimitiveArray};
use crate::datatypes::ArrowNumericType;
use crate::datatypes::{ArrowNativeTypeOp, ArrowNumericType};
use std::marker::PhantomData;
use std::ops::Add;

pub(super) trait SimdAggregate<T: ArrowNumericType> {
type ScalarAccumulator;
Expand Down Expand Up @@ -434,7 +433,7 @@ mod simd {

impl<T: ArrowNumericType> SimdAggregate<T> for SumAggregate<T>
where
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
type ScalarAccumulator = T::Native;
type SimdAccumulator = T::Simd;
Expand Down Expand Up @@ -463,7 +462,7 @@ mod simd {
}

fn accumulate_scalar(accumulator: &mut T::Native, value: T::Native) {
*accumulator = *accumulator + value
*accumulator = accumulator.add_wrapping(value)
}

fn reduce(
Expand Down Expand Up @@ -738,7 +737,7 @@ mod simd {
#[cfg(feature = "simd")]
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
T::Native: Add<Output = T::Native>,
T::Native: ArrowNativeTypeOp,
{
use simd::*;

Expand Down
88 changes: 47 additions & 41 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -22,9 +22,7 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.

use std::ops::{Div, Neg};

use num::{One, Zero};
use std::ops::Neg;

use crate::array::*;
#[cfg(feature = "simd")]
Expand Down Expand Up @@ -107,7 +105,6 @@ fn math_checked_divide_op<LT, RT, F>(
where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
F: Fn(LT::Native, RT::Native) -> Result<LT::Native>,
{
try_binary(left, right, op)
Expand All @@ -131,7 +128,6 @@ fn math_checked_divide_op_on_iters<T, F>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
let buffer = if null_bit_buffer.is_some() {
Expand Down Expand Up @@ -182,10 +178,10 @@ fn simd_checked_modulus<T: ArrowNumericType>(
right: T::Simd,
) -> Result<T::Simd>
where
T::Native: ArrowNativeTypeOp + One,
T::Native: ArrowNativeTypeOp,
{
let zero = T::init(T::Native::zero());
let one = T::init(T::Native::one());
let zero = T::init(T::Native::ZERO);
let one = T::init(T::Native::ONE);

let right_no_invalid_zeros = match valid_mask {
Some(mask) => {
Expand Down Expand Up @@ -219,10 +215,10 @@ fn simd_checked_divide<T: ArrowNumericType>(
right: T::Simd,
) -> Result<T::Simd>
where
T::Native: One + Zero,
T::Native: ArrowNativeTypeOp,
{
let zero = T::init(T::Native::zero());
let one = T::init(T::Native::one());
let zero = T::init(T::Native::ZERO);
let one = T::init(T::Native::ONE);

let right_no_invalid_zeros = match valid_mask {
Some(mask) => {
Expand Down Expand Up @@ -260,7 +256,7 @@ fn simd_checked_divide_op_remainder<T, F>(
) -> Result<()>
where
T: ArrowNumericType,
T::Native: Zero,
T::Native: ArrowNativeTypeOp,
F: Fn(T::Native, T::Native) -> T::Native,
{
let result_remainder = result_chunks.into_remainder();
Expand All @@ -273,7 +269,7 @@ where
.enumerate()
.try_for_each(|(i, (result_scalar, (left_scalar, right_scalar)))| {
if valid_mask.map(|mask| mask & (1 << i) != 0).unwrap_or(true) {
if *right_scalar == T::Native::zero() {
if right_scalar.is_zero() {
return Err(ArrowError::DivideByZero);
}
*result_scalar = op(*left_scalar, *right_scalar);
Expand Down Expand Up @@ -648,7 +644,6 @@ fn math_divide_checked_op_dict<K, T, F>(
where
K: ArrowNumericType,
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
if left.len() != right.len() {
Expand Down Expand Up @@ -702,7 +697,6 @@ fn math_divide_safe_op_dict<K, T, F>(
where
K: ArrowNumericType,
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> Option<T::Native>,
{
let left = left.downcast_dict::<PrimitiveArray<T>>().unwrap();
Expand All @@ -719,7 +713,6 @@ fn math_safe_divide_op<LT, RT, F>(
where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
let array: PrimitiveArray<LT> = binary_opt::<_, _, _, LT>(left, right, op)?;
Expand Down Expand Up @@ -1068,8 +1061,8 @@ pub fn subtract_scalar<T>(
scalar: T::Native,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
Ok(unary(array, |value| value.sub_wrapping(scalar)))
}
Expand All @@ -1085,7 +1078,7 @@ pub fn subtract_scalar_checked<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T::Native: ArrowNativeTypeOp,
{
try_unary(array, |value| value.sub_checked(scalar))
}
Expand Down Expand Up @@ -1125,7 +1118,7 @@ where
/// Perform `-` operation on an array. If value is null then the result is also null.
pub fn negate<T>(array: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T: ArrowNumericType,
T::Native: Neg<Output = T::Native>,
{
Ok(unary(array, |x| -x))
Expand Down Expand Up @@ -1239,7 +1232,7 @@ pub fn multiply_scalar<T>(
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
T::Native: ArrowNativeTypeOp,
{
Ok(unary(array, |value| value.mul_wrapping(scalar)))
}
Expand All @@ -1255,7 +1248,7 @@ pub fn multiply_scalar_checked<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
T::Native: ArrowNativeTypeOp,
{
try_unary(array, |value| value.mul_checked(scalar))
}
Expand Down Expand Up @@ -1301,11 +1294,11 @@ pub fn modulus<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + One,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_checked_divide_op(&left, &right, simd_checked_modulus::<T>, |a, b| {
a % b
a.mod_wrapping(b)
});
#[cfg(not(feature = "simd"))]
return try_binary(left, right, |a, b| {
Expand All @@ -1328,11 +1321,13 @@ pub fn divide_checked<T>(
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, b| a / b);
return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, b| {
a.div_wrapping(b)
});
#[cfg(not(feature = "simd"))]
return math_checked_divide_op(left, right, |a, b| a.div_checked(b));
}
Expand All @@ -1343,16 +1338,21 @@ where
/// If any right hand value is zero, the operation value will be replaced with null in the
/// result.
///
/// Unlike `divide` or `divide_checked`, division by zero will get a null value instead
/// returning an `Err`, this also doesn't check overflowing, overflowing will just wrap
/// the result around.
/// Unlike [`divide`] or [`divide_checked`], division by zero will yield a null value in the
/// result instead of returning an `Err`.
///
/// For floating point types overflow will saturate at INF or -INF
/// preserving the expected sign value.
///
/// For integer types overflow will wrap around.
///
pub fn divide_opt<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
T::Native: ArrowNativeTypeOp,
{
binary_opt(left, right, |a, b| {
if b.is_zero() {
Expand Down Expand Up @@ -1480,12 +1480,16 @@ pub fn divide_dyn_opt(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
}
}

/// Perform `left / right` operation on two arrays without checking for division by zero.
/// For floating point types, the result of dividing by zero follows normal floating point
/// rules. For other numeric types, dividing by zero will panic,
/// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this
/// Perform `left / right` operation on two arrays without checking for
/// division by zero or overflow.
///
/// For floating point types, overflow and division by zero follows normal floating point rules
///
/// For integer types overflow will wrap around. Division by zero will currently panic, although
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to address this in a follow up, the current kernel is effectively useless for integer types

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Useless for integer types? How?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because nulls will have the value zero which then panic

/// this may be subject to change see <https://github.com/apache/arrow-rs/issues/2647>
///
/// If either left or right value is null then the result is also null.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `divide_checked` instead.
pub fn divide<T>(
left: &PrimitiveArray<T>,
Expand All @@ -1495,6 +1499,8 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
// TODO: This is incorrect as div_wrapping has side-effects for integer types
// and so may panic on null values (#2647)
math_op(left, right, |a, b| a.div_wrapping(b))
}

Expand Down Expand Up @@ -1525,12 +1531,12 @@ pub fn divide_scalar<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
T::Native: Div<Output = T::Native> + Zero,
T::Native: ArrowNativeTypeOp,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}
Ok(unary(array, |a| a / divisor))
Ok(unary(array, |a| a.div_wrapping(divisor)))
}

/// Divide every value in an array by a scalar. If any value in the array is null then the
Expand All @@ -1543,7 +1549,7 @@ where
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T::Native: ArrowNativeTypeOp,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
Expand All @@ -1564,7 +1570,7 @@ pub fn divide_scalar_checked_dyn<T>(
) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T::Native: ArrowNativeTypeOp,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
Expand All @@ -1587,7 +1593,7 @@ where
pub fn divide_scalar_opt_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
T::Native: ArrowNativeTypeOp,
{
if divisor.is_zero() {
match array.data_type() {
Expand Down