diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 9b860feeec4..3fed10adb1f 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -26,6 +26,7 @@ use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; use num::{One, Zero}; +use crate::array::*; use crate::buffer::Buffer; #[cfg(feature = "simd")] use crate::buffer::MutableBuffer; @@ -39,7 +40,6 @@ use crate::datatypes::{ UInt32Type, UInt64Type, UInt8Type, }; use crate::error::{ArrowError, Result}; -use crate::{array::*, util::bit_util}; use num::traits::Pow; use std::any::type_name; #[cfg(feature = "simd")] @@ -126,36 +126,60 @@ where let null_bit_buffer = combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; - let buffer = if let Some(b) = &null_bit_buffer { - let values = left.values().iter().zip(right.values()).enumerate().map( - |(i, (left, right))| { - let is_valid = unsafe { bit_util::get_bit_raw(b.as_ptr(), i) }; - if is_valid { - if right.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(op(*left, *right)) - } + math_checked_divide_op_on_iters( + left.into_iter(), + right.into_iter(), + op, + left.len(), + null_bit_buffer, + ) +} + +/// Helper function for operations where a valid `0` on the right array should +/// result in an [ArrowError::DivideByZero], namely the division and modulo operations +/// +/// # Errors +/// +/// This function errors if: +/// * the arrays have different lengths +/// * there is an element where both left and right values are valid and the right value is `0` +fn math_checked_divide_op_on_iters( + left: impl Iterator>, + right: impl Iterator>, + op: F, + len: usize, + null_bit_buffer: Option, +) -> Result> +where + T: ArrowNumericType, + T::Native: One + Zero, + F: Fn(T::Native, T::Native) -> T::Native, +{ + let buffer = if null_bit_buffer.is_some() { + let values = left.zip(right).map(|(left, right)| { + if let (Some(l), Some(r)) = (left, right) { + if r.is_zero() { + Err(ArrowError::DivideByZero) } else { - Ok(T::default_value()) + Ok(op(l, r)) } - }, - ); + } else { + Ok(T::default_value()) + } + }); // Safety: Iterator comes from a PrimitiveArray which reports its size correctly unsafe { Buffer::try_from_trusted_len_iter(values) } } else { // no value is null - let values = left - .values() - .iter() - .zip(right.values()) - .map(|(left, right)| { + let values = left.map(|l| l.unwrap()).zip(right.map(|r| r.unwrap())).map( + |(left, right)| { if right.is_zero() { Err(ArrowError::DivideByZero) } else { - Ok(op(*left, *right)) + Ok(op(left, right)) } - }); + }, + ); // Safety: Iterator comes from a PrimitiveArray which reports its size correctly unsafe { Buffer::try_from_trusted_len_iter(values) } }?; @@ -163,7 +187,7 @@ where let data = unsafe { ArrayData::new_unchecked( T::DATA_TYPE, - left.len(), + len, None, null_bit_buffer, 0, @@ -432,46 +456,46 @@ where /// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT macro_rules! typed_dict_op { - ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt, $MATH_OP: ident) => {{ match ($LEFT.value_type(), $RIGHT.value_type()) { (DataType::Int8, DataType::Int8) => { - let array = math_op_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::Int16, DataType::Int16) => { - let array = math_op_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::Int32, DataType::Int32) => { - let array = math_op_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::Int64, DataType::Int64) => { - let array = math_op_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::UInt8, DataType::UInt8) => { - let array = math_op_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::UInt16, DataType::UInt16) => { - let array = math_op_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::UInt32, DataType::UInt32) => { - let array = math_op_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::UInt64, DataType::UInt64) => { - let array = math_op_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::Float32, DataType::Float32) => { - let array = math_op_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (DataType::Float64, DataType::Float64) => { - let array = math_op_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP)?; + let array = $MATH_OP::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP)?; Ok(Arc::new(array)) } (t1, t2) => Err(ArrowError::CastError(format!( @@ -484,49 +508,49 @@ macro_rules! typed_dict_op { macro_rules! typed_dict_math_op { // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` - ($LEFT: expr, $RIGHT: expr, $OP: expr) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr, $MATH_OP: ident) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { match (left_key_type.as_ref(), right_key_type.as_ref()) { (DataType::Int8, DataType::Int8) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_op!(left, right, $OP, Int8Type) + typed_dict_op!(left, right, $OP, Int8Type, $MATH_OP) } (DataType::Int16, DataType::Int16) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_op!(left, right, $OP, Int16Type) + typed_dict_op!(left, right, $OP, Int16Type, $MATH_OP) } (DataType::Int32, DataType::Int32) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_op!(left, right, $OP, Int32Type) + typed_dict_op!(left, right, $OP, Int32Type, $MATH_OP) } (DataType::Int64, DataType::Int64) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_op!(left, right, $OP, Int64Type) + typed_dict_op!(left, right, $OP, Int64Type, $MATH_OP) } (DataType::UInt8, DataType::UInt8) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_op!(left, right, $OP, UInt8Type) + typed_dict_op!(left, right, $OP, UInt8Type, $MATH_OP) } (DataType::UInt16, DataType::UInt16) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_op!(left, right, $OP, UInt16Type) + typed_dict_op!(left, right, $OP, UInt16Type, $MATH_OP) } (DataType::UInt32, DataType::UInt32) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_op!(left, right, $OP, UInt32Type) + typed_dict_op!(left, right, $OP, UInt32Type, $MATH_OP) } (DataType::UInt64, DataType::UInt64) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_op!(left, right, $OP, UInt64Type) + typed_dict_op!(left, right, $OP, UInt64Type, $MATH_OP) } (t1, t2) => Err(ArrowError::CastError(format!( "Cannot perform arithmetic operation on two dictionary arrays of different key types ({} and {})", @@ -543,7 +567,7 @@ macro_rules! typed_dict_math_op { } macro_rules! typed_op { - ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: expr) => {{ + ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: expr, $MATH_OP: ident) => {{ let left = $LEFT .as_any() .downcast_ref::>() @@ -562,43 +586,43 @@ macro_rules! typed_op { type_name::<$T>(), )) })?; - let array = math_op(left, right, $OP)?; + let array = $MATH_OP(left, right, $OP)?; Ok(Arc::new(array)) }}; } macro_rules! typed_math_op { - ($LEFT: expr, $RIGHT: expr, $OP: expr) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr, $MATH_OP: ident) => {{ match $LEFT.data_type() { DataType::Int8 => { - typed_op!($LEFT, $RIGHT, Int8Type, $OP) + typed_op!($LEFT, $RIGHT, Int8Type, $OP, $MATH_OP) } DataType::Int16 => { - typed_op!($LEFT, $RIGHT, Int16Type, $OP) + typed_op!($LEFT, $RIGHT, Int16Type, $OP, $MATH_OP) } DataType::Int32 => { - typed_op!($LEFT, $RIGHT, Int32Type, $OP) + typed_op!($LEFT, $RIGHT, Int32Type, $OP, $MATH_OP) } DataType::Int64 => { - typed_op!($LEFT, $RIGHT, Int64Type, $OP) + typed_op!($LEFT, $RIGHT, Int64Type, $OP, $MATH_OP) } DataType::UInt8 => { - typed_op!($LEFT, $RIGHT, UInt8Type, $OP) + typed_op!($LEFT, $RIGHT, UInt8Type, $OP, $MATH_OP) } DataType::UInt16 => { - typed_op!($LEFT, $RIGHT, UInt16Type, $OP) + typed_op!($LEFT, $RIGHT, UInt16Type, $OP, $MATH_OP) } DataType::UInt32 => { - typed_op!($LEFT, $RIGHT, UInt32Type, $OP) + typed_op!($LEFT, $RIGHT, UInt32Type, $OP, $MATH_OP) } DataType::UInt64 => { - typed_op!($LEFT, $RIGHT, UInt64Type, $OP) + typed_op!($LEFT, $RIGHT, UInt64Type, $OP, $MATH_OP) } DataType::Float32 => { - typed_op!($LEFT, $RIGHT, Float32Type, $OP) + typed_op!($LEFT, $RIGHT, Float32Type, $OP, $MATH_OP) } DataType::Float64 => { - typed_op!($LEFT, $RIGHT, Float64Type, $OP) + typed_op!($LEFT, $RIGHT, Float64Type, $OP, $MATH_OP) } t => Err(ArrowError::CastError(format!( "Cannot perform arithmetic operation on arrays of type {}", @@ -608,7 +632,7 @@ macro_rules! typed_math_op { }}; } -/// Helper function to perform boolean lambda function on values from two dictionary arrays, this +/// Helper function to perform math lambda function on values from two dictionary arrays, this /// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) macro_rules! math_dict_op { ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ @@ -671,6 +695,65 @@ where math_dict_op!(left, right, op, PrimitiveArray) } +/// Helper function for operations where a valid `0` on the right array should +/// result in an [ArrowError::DivideByZero], namely the division and modulo operations +/// +/// # Errors +/// +/// This function errors if: +/// * the arrays have different lengths +/// * there is an element where both left and right values are valid and the right value is `0` +fn math_divide_checked_op_dict( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result> +where + K: ArrowNumericType, + T: ArrowNumericType, + T::Native: One + Zero, + F: Fn(T::Native, T::Native) -> T::Native, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError(format!( + "Cannot perform operation on arrays of different length ({}, {})", + left.len(), + right.len() + ))); + } + + let null_bit_buffer = + combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; + + // Safety justification: Since the inputs are valid Arrow arrays, all values are + // valid indexes into the dictionary (which is verified during construction) + + let left_iter = unsafe { + left.values() + .as_any() + .downcast_ref::>() + .unwrap() + .take_iter_unchecked(left.keys_iter()) + }; + + let right_iter = unsafe { + right + .values() + .as_any() + .downcast_ref::>() + .unwrap() + .take_iter_unchecked(right.keys_iter()) + }; + + math_checked_divide_op_on_iters( + left_iter, + right_iter, + op, + left.len(), + null_bit_buffer, + ) +} + /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. pub fn add( @@ -689,9 +772,9 @@ where pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a + b) + typed_dict_math_op!(left, right, |a, b| a + b, math_op_dict) } - _ => typed_math_op!(left, right, |a, b| a + b), + _ => typed_math_op!(left, right, |a, b| a + b, math_op), } } @@ -737,9 +820,9 @@ where pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a - b) + typed_dict_math_op!(left, right, |a, b| a - b, math_op_dict) } - _ => typed_math_op!(left, right, |a, b| a - b), + _ => typed_math_op!(left, right, |a, b| a - b, math_op), } } @@ -814,9 +897,9 @@ where pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a * b) + typed_dict_math_op!(left, right, |a, b| a * b, math_op_dict) } - _ => typed_math_op!(left, right, |a, b| a * b), + _ => typed_math_op!(left, right, |a, b| a * b, math_op), } } @@ -892,6 +975,18 @@ where return math_checked_divide_op(left, right, |a, b| a / b); } +/// Perform `left / right` operation on two arrays. If either left or right value is null +/// then the result is also null. If any right hand value is zero then the result of this +/// operation will be `Err(ArrowError::DivideByZero)`. +pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!(left, right, |a, b| a / b, math_divide_checked_op_dict) + } + _ => typed_math_op!(left, right, |a, b| a / b, math_checked_divide_op), + } +} + /// Perform `left / right` operation on two arrays without checking for division by zero. /// The result of dividing by zero follows normal floating point rules. /// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this @@ -1185,6 +1280,50 @@ mod tests { assert_eq!(90, c.value(4)); } + #[test] + fn test_primitive_array_divide_dyn() { + let a = Int32Array::from(vec![Some(15), Some(6), Some(1), Some(8), Some(9)]); + let b = Int32Array::from(vec![Some(5), Some(3), Some(1), None, Some(3)]); + let c = divide_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(3, c.value(0)); + assert_eq!(2, c.value(1)); + assert_eq!(1, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(3, c.value(4)); + } + + #[test] + fn test_primitive_array_divide_dyn_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(15).unwrap(); + builder.append(6).unwrap(); + builder.append(1).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + let a = builder.finish(); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(5).unwrap(); + builder.append(3).unwrap(); + builder.append(1).unwrap(); + builder.append_null().unwrap(); + builder.append(3).unwrap(); + let b = builder.finish(); + + let c = divide_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(3, c.value(0)); + assert_eq!(2, c.value(1)); + assert_eq!(1, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(3, c.value(4)); + } + #[test] fn test_primitive_array_multiply_scalar_dyn() { let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); @@ -1640,6 +1779,32 @@ mod tests { divide(&a, &b).unwrap(); } + #[test] + #[should_panic(expected = "DivideByZero")] + fn test_primitive_array_divide_dyn_by_zero() { + let a = Int32Array::from(vec![15]); + let b = Int32Array::from(vec![0]); + divide_dyn(&a, &b).unwrap(); + } + + #[test] + #[should_panic(expected = "DivideByZero")] + fn test_primitive_array_divide_dyn_by_zero_dict() { + let key_builder = PrimitiveBuilder::::new(1); + let value_builder = PrimitiveBuilder::::new(1); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(15).unwrap(); + let a = builder.finish(); + + let key_builder = PrimitiveBuilder::::new(1); + let value_builder = PrimitiveBuilder::::new(1); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(0).unwrap(); + let b = builder.finish(); + + divide_dyn(&a, &b).unwrap(); + } + #[test] #[should_panic(expected = "DivideByZero")] fn test_primitive_array_modulus_by_zero() {