diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index 04865e15bca..1f812b67e98 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -32,14 +32,20 @@ use crate::buffer::MutableBuffer;
 use crate::compute::kernels::arity::unary;
 use crate::compute::util::combine_option_bitmap;
 use crate::datatypes;
-use crate::datatypes::ArrowNumericType;
+use crate::datatypes::{ArrowNumericType, DataType};
+use crate::datatypes::{
+    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
+    UInt32Type, UInt64Type, UInt8Type,
+};
 use crate::error::{ArrowError, Result};
 use crate::{array::*, util::bit_util};
 use num::traits::Pow;
+use std::any::type_name;
 #[cfg(feature = "simd")]
 use std::borrow::BorrowMut;
 #[cfg(feature = "simd")]
 use std::slice::{ChunksExact, ChunksExactMut};
+use std::sync::Arc;
 
 /// Helper function to perform math lambda function on values from two arrays. If either
 /// left or right value is null then the output value is also null, so `1 + null` is
@@ -423,6 +429,249 @@ where
     Ok(PrimitiveArray::<T>::from(data))
 }
 
+/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT
+macro_rules! typed_dict_op {
+    ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{
+        match ($LEFT.value_type(), $RIGHT.value_type()) {
+            (DataType::Int8, DataType::Int8) => {
+                let array = math_op_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::Int16, DataType::Int16) => {
+                let array = math_op_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::Int32, DataType::Int32) => {
+                let array = math_op_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::Int64, DataType::Int64) => {
+                let array = math_op_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                let array = math_op_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                let array = math_op_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                let array = math_op_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                let array = math_op_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::Float32, DataType::Float32) => {
+                let array = math_op_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (DataType::Float64, DataType::Float64) => {
+                let array = math_op_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP)?;
+                Ok(Arc::new(array))
+            }
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot perform arithmetic operation on two dictionary arrays of different value types ({} and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+macro_rules! typed_dict_math_op {
+    // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
+    ($LEFT: expr, $RIGHT: expr, $OP: expr) => {{
+        match ($LEFT.data_type(), $RIGHT.data_type()) {
+            (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> {
+                match (left_key_type.as_ref(), right_key_type.as_ref()) {
+                    (DataType::Int8, DataType::Int8) => {
+                        let left = as_dictionary_array::<Int8Type>($LEFT);
+                        let right = as_dictionary_array::<Int8Type>($RIGHT);
+                        typed_dict_op!(left, right, $OP, Int8Type)
+                    }
+                    (DataType::Int16, DataType::Int16) => {
+                        let left = as_dictionary_array::<Int16Type>($LEFT);
+                        let right = as_dictionary_array::<Int16Type>($RIGHT);
+                        typed_dict_op!(left, right, $OP, Int16Type)
+                    }
+                    (DataType::Int32, DataType::Int32) => {
+                        let left = as_dictionary_array::<Int32Type>($LEFT);
+                        let right = as_dictionary_array::<Int32Type>($RIGHT);
+                        typed_dict_op!(left, right, $OP, Int32Type)
+                    }
+                    (DataType::Int64, DataType::Int64) => {
+                        let left = as_dictionary_array::<Int64Type>($LEFT);
+                        let right = as_dictionary_array::<Int64Type>($RIGHT);
+                        typed_dict_op!(left, right, $OP, Int64Type)
+                    }
+                    (DataType::UInt8, DataType::UInt8) => {
+                        let left = as_dictionary_array::<UInt8Type>($LEFT);
+                        let right = as_dictionary_array::<UInt8Type>($RIGHT);
+                        typed_dict_op!(left, right, $OP, UInt8Type)
+                    }
+                    (DataType::UInt16, DataType::UInt16) => {
+                        let left = as_dictionary_array::<UInt16Type>($LEFT);
+                        let right = as_dictionary_array::<UInt16Type>($RIGHT);
+                        typed_dict_op!(left, right, $OP, UInt16Type)
+                    }
+                    (DataType::UInt32, DataType::UInt32) => {
+                        let left = as_dictionary_array::<UInt32Type>($LEFT);
+                        let right = as_dictionary_array::<UInt32Type>($RIGHT);
+                        typed_dict_op!(left, right, $OP, UInt32Type)
+                    }
+                    (DataType::UInt64, DataType::UInt64) => {
+                        let left = as_dictionary_array::<UInt64Type>($LEFT);
+                        let right = as_dictionary_array::<UInt64Type>($RIGHT);
+                        typed_dict_op!(left, right, $OP, UInt64Type)
+                    }
+                    (t1, t2) => Err(ArrowError::CastError(format!(
+                        "Cannot perform arithmetic operation on two dictionary arrays of different key types ({} and {})",
+                        t1, t2
+                    ))),
+                }
+            }
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot perform arithmetic operation on dictionary array with non-dictionary array ({} and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Downcasts $LEFT and $RIGHT to `PrimitiveArray<$T>` and applies $OP to them via `math_op`
+macro_rules! typed_op {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: expr) => {{
+        let left = $LEFT
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$T>>()
+            .ok_or_else(|| {
+                ArrowError::CastError(format!(
+                    "Left array cannot be cast to {}",
+                    type_name::<$T>()
+                ))
+            })?;
+        let right = $RIGHT
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$T>>()
+            .ok_or_else(|| {
+                ArrowError::CastError(format!(
+                    "Right array cannot be cast to {}",
+                    type_name::<$T>(),
+                ))
+            })?;
+        let array = math_op(left, right, $OP)?;
+        Ok(Arc::new(array))
+    }};
+}
+
+/// Applies $OP to $LEFT and $RIGHT, selecting the primitive type from the data type of $LEFT
+macro_rules! typed_math_op {
+    ($LEFT: expr, $RIGHT: expr, $OP: expr) => {{
+        match $LEFT.data_type() {
+            DataType::Int8 => {
+                typed_op!($LEFT, $RIGHT, Int8Type, $OP)
+            }
+            DataType::Int16 => {
+                typed_op!($LEFT, $RIGHT, Int16Type, $OP)
+            }
+            DataType::Int32 => {
+                typed_op!($LEFT, $RIGHT, Int32Type, $OP)
+            }
+            DataType::Int64 => {
+                typed_op!($LEFT, $RIGHT, Int64Type, $OP)
+            }
+            DataType::UInt8 => {
+                typed_op!($LEFT, $RIGHT, UInt8Type, $OP)
+            }
+            DataType::UInt16 => {
+                typed_op!($LEFT, $RIGHT, UInt16Type, $OP)
+            }
+            DataType::UInt32 => {
+                typed_op!($LEFT, $RIGHT, UInt32Type, $OP)
+            }
+            DataType::UInt64 => {
+                typed_op!($LEFT, $RIGHT, UInt64Type, $OP)
+            }
+            DataType::Float32 => {
+                typed_op!($LEFT, $RIGHT, Float32Type, $OP)
+            }
+            DataType::Float64 => {
+                typed_op!($LEFT, $RIGHT, Float64Type, $OP)
+            }
+            t => Err(ArrowError::CastError(format!(
+                "Cannot perform arithmetic operation on arrays of type {}",
+                t
+            ))),
+        }
+    }};
+}
+
+/// Helper macro to perform math lambda function on values from two dictionary arrays, this
+/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize)
+macro_rules! math_dict_op {
+    ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{
+        if $left.len() != $right.len() {
+            return Err(ArrowError::ComputeError(format!(
+                "Cannot perform operation on arrays of different length ({}, {})",
+                $left.len(),
+                $right.len()
+            )));
+        }
+
+        // Safety justification: Since the inputs are valid Arrow arrays, all values are
+        // valid indexes into the dictionary (which is verified during construction)
+
+        let left_iter = unsafe {
+            $left
+                .values()
+                .as_any()
+                .downcast_ref::<$value_ty>()
+                .unwrap()
+                .take_iter_unchecked($left.keys_iter())
+        };
+
+        let right_iter = unsafe {
+            $right
+                .values()
+                .as_any()
+                .downcast_ref::<$value_ty>()
+                .unwrap()
+                .take_iter_unchecked($right.keys_iter())
+        };
+
+        let result = left_iter
+            .zip(right_iter)
+            .map(|(left_value, right_value)| {
+                if let (Some(left), Some(right)) = (left_value, right_value) {
+                    Some($op(left, right))
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        Ok(result)
+    }};
+}
+
+/// Perform given operation on two `DictionaryArray`s whose values are `PrimitiveArray<T>`.
+/// Errors if the arrays differ in length; panics if the values are not of that type.
+fn math_op_dict<K, T, F>(
+    left: &DictionaryArray<K>,
+    right: &DictionaryArray<K>,
+    op: F,
+) -> Result<PrimitiveArray<T>>
+where
+    K: ArrowNumericType,
+    T: ArrowNumericType,
+    F: Fn(T::Native, T::Native) -> T::Native,
+{
+    math_dict_op!(left, right, op, PrimitiveArray<T>)
+}
+
 /// Perform `left + right` operation on two arrays. If either left or right value is null
 /// then the result is also null.
 pub fn add<T>(
@@ -436,6 +685,17 @@ where
     math_op(left, right, |a, b| a + b)
 }
 
+/// Perform `left + right` operation on two arrays, either both primitive arrays or both
+/// dictionary arrays. If either left or right value is null then the result is also null.
+pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
+    match left.data_type() {
+        DataType::Dictionary(_, _) => {
+            typed_dict_math_op!(left, right, |a, b| a + b)
+        }
+        _ => typed_math_op!(left, right, |a, b| a + b),
+    }
+}
+
 /// Add every value in an array by a scalar. If any value in the array is null then the
 /// result is also null.
 pub fn add_scalar<T>(
@@ -634,6 +894,50 @@ mod tests {
         assert_eq!(17, c.value(4));
     }
 
+    #[test]
+    fn test_primitive_array_add_dyn() {
+        let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8), Some(9)]);
+        let b = Int32Array::from(vec![Some(6), Some(7), Some(8), None, Some(8)]);
+        let c = add_dyn(&a, &b).unwrap();
+        let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(11, c.value(0));
+        assert_eq!(13, c.value(1));
+        assert_eq!(15, c.value(2));
+        assert!(c.is_null(3));
+        assert_eq!(17, c.value(4));
+    }
+
+    #[test]
+    fn test_primitive_array_add_dyn_dict() {
+        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
+        let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
+        let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
+        builder.append(5).unwrap();
+        builder.append(6).unwrap();
+        builder.append(7).unwrap();
+        builder.append(8).unwrap();
+        builder.append(9).unwrap();
+        let a = builder.finish();
+
+        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
+        let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
+        let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
+        builder.append(6).unwrap();
+        builder.append(7).unwrap();
+        builder.append(8).unwrap();
+        builder.append_null().unwrap();
+        builder.append(10).unwrap();
+        let b = builder.finish();
+
+        let c = add_dyn(&a, &b).unwrap();
+        let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(11, c.value(0));
+        assert_eq!(13, c.value(1));
+        assert_eq!(15, c.value(2));
+        assert!(c.is_null(3));
+        assert_eq!(19, c.value(4));
+    }
+
     #[test]
     fn test_primitive_array_add_sliced() {
         let a = Int32Array::from(vec![0, 0, 0, 5, 6, 7, 8, 9, 0]);