From 74f639ca8661c868a1aaa2aa6fe23e01f46f97d8 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 20 Sep 2022 22:30:27 +0100 Subject: [PATCH] Add dyn_arith_dict feature flag (#2760) * Add dyn_arith_dict feature flag * Document feature flag --- .github/workflows/arrow.yml | 6 +-- arrow/Cargo.toml | 3 ++ arrow/README.md | 1 + arrow/src/compute/kernels/arithmetic.rs | 51 ++++++++++++++++++++----- 4 files changed, 48 insertions(+), 13 deletions(-) diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml index d81a551a3b4..cdd87ca1639 100644 --- a/.github/workflows/arrow.yml +++ b/.github/workflows/arrow.yml @@ -51,9 +51,9 @@ jobs: - name: Test run: | cargo test -p arrow - - name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict + - name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict,dyn_arith_dict run: | - cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict + cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict,dyn_arith_dict - name: Run examples run: | # Test arrow examples @@ -177,4 +177,4 @@ jobs: rustup component add clippy - name: Run clippy run: | - cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict --all-targets -- -D warnings + cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict,dyn_arith_dict --all-targets -- -D warnings diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 7391ffcf827..f8dbf1481b5 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -95,6 +95,9 @@ ffi = [] # Enable dyn-comparison of dictionary arrays with other arrays # Note: this does not impact comparison against scalars dyn_cmp_dict = [] +# Enable dyn-arithmetic kernels for dictionary arrays +# Note: this does not impact arithmetic with scalars +dyn_arith_dict = [] [dev-dependencies] rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } diff --git a/arrow/README.md b/arrow/README.md index a1c0e6279a5..e168d4a09ee 100644 --- a/arrow/README.md +++ b/arrow/README.md @@ -54,6 +54,7 @@ The `arrow` crate provides the following features which may be enabled in your ` - `ffi` - bindings for the Arrow C [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) - `pyarrow` - bindings for pyo3 to call arrow-rs from python - `dyn_cmp_dict` - enables comparison of dictionary arrays within dyn comparison kernels +- `dyn_arith_dict` - enables arithmetic on dictionary arrays within dyn arithmetic kernels ## Arrow Feature Status diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index aa6c8cd6694..b44cb8b947e 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -27,11 +27,9 @@ use std::ops::{Div, Neg, Rem}; use num::{One, Zero}; use crate::array::*; -use crate::buffer::Buffer; #[cfg(feature = "simd")] use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; -use crate::compute::util::combine_option_bitmap; use crate::compute::{ binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn, }; @@ -39,6 +37,7 @@ use crate::datatypes::{ native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, }; +#[cfg(feature = "dyn_arith_dict")] use crate::datatypes::{ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, @@ -122,12 +121,13 @@ where /// This function errors if: /// * the arrays have different lengths /// * there is an element where both left and right values are valid and the right value is `0` +#[cfg(feature = "dyn_arith_dict")] fn math_checked_divide_op_on_iters( left: impl Iterator>, right: impl Iterator>, op: F, len: usize, - null_bit_buffer: Option, + null_bit_buffer: Option, ) -> Result> where T: ArrowNumericType, @@ -143,7 +143,7 @@ where } }); // Safety: Iterator comes from a PrimitiveArray which reports its size correctly - unsafe { Buffer::try_from_trusted_len_iter(values) } + unsafe { crate::buffer::Buffer::try_from_trusted_len_iter(values) } } else { // no value is null let values = left @@ -151,7 +151,7 @@ where .zip(right.map(|r| r.unwrap())) .map(|(left, right)| op(left, right)); // Safety: Iterator comes from a PrimitiveArray which reports its size correctly - unsafe { Buffer::try_from_trusted_len_iter(values) } + unsafe { crate::buffer::Buffer::try_from_trusted_len_iter(values) } }?; let data = unsafe { @@ -316,8 +316,10 @@ where } // Create the combined `Bitmap` - let null_bit_buffer = - combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; + let null_bit_buffer = crate::compute::util::combine_option_bitmap( + &[left.data_ref(), right.data_ref()], + left.len(), + )?; let lanes = T::lanes(); let buffer_size = left.len() * std::mem::size_of::(); @@ -425,6 +427,7 @@ where } /// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT +#[cfg(feature = "dyn_arith_dict")] macro_rules! typed_dict_op { ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt, $MATH_OP: ident) => {{ match ($LEFT.value_type(), $RIGHT.value_type()) { @@ -476,6 +479,7 @@ macro_rules! typed_dict_op { }}; } +#[cfg(feature = "dyn_arith_dict")] macro_rules! typed_dict_math_op { // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` ($LEFT: expr, $RIGHT: expr, $OP: expr, $MATH_OP: ident) => {{ @@ -536,8 +540,20 @@ macro_rules! typed_dict_math_op { }}; } +#[cfg(not(feature = "dyn_arith_dict"))] +macro_rules! typed_dict_math_op { + // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` + ($LEFT: expr, $RIGHT: expr, $OP: expr, $MATH_OP: ident) => {{ + Err(ArrowError::CastError(format!( + "Arithmetic on arrays of type {} with array of type {} requires \"dyn_arith_dict\" feature", + $LEFT.data_type(), $RIGHT.data_type() + ))) + }}; +} + /// Perform given operation on two `DictionaryArray`s. /// Returns an error if the two arrays have different value type +#[cfg(feature = "dyn_arith_dict")] fn math_op_dict( left: &DictionaryArray, right: &DictionaryArray, @@ -593,6 +609,7 @@ where /// Perform given operation on two `DictionaryArray`s. /// Returns an error if the two arrays have different value type +#[cfg(feature = "dyn_arith_dict")] fn math_checked_op_dict( left: &DictionaryArray, right: &DictionaryArray, @@ -626,6 +643,7 @@ where /// This function errors if: /// * the arrays have different lengths /// * there is an element where both left and right values are valid and the right value is `0` +#[cfg(feature = "dyn_arith_dict")] fn math_divide_checked_op_dict( left: &DictionaryArray, right: &DictionaryArray, @@ -645,8 +663,10 @@ where ))); } - let null_bit_buffer = - combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; + let null_bit_buffer = crate::compute::util::combine_option_bitmap( + &[left.data_ref(), right.data_ref()], + left.len(), + )?; // Safety justification: Since the inputs are valid Arrow arrays, all values are // valid indexes into the dictionary (which is verified during construction) @@ -1484,7 +1504,7 @@ where mod tests { use super::*; use crate::array::Int32Array; - use crate::datatypes::Date64Type; + use crate::datatypes::{Date64Type, Int32Type, Int8Type}; use chrono::NaiveDate; #[test] @@ -1605,6 +1625,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_arith_dict")] fn test_primitive_array_add_dyn_dict() { let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(5).unwrap(); @@ -1683,6 +1704,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_arith_dict")] fn test_primitive_array_subtract_dyn_dict() { let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(15).unwrap(); @@ -1761,6 +1783,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_arith_dict")] fn test_primitive_array_multiply_dyn_dict() { let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(5).unwrap(); @@ -1801,6 +1824,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_arith_dict")] fn test_primitive_array_divide_dyn_dict() { let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(15).unwrap(); @@ -2322,6 +2346,7 @@ mod tests { #[test] #[should_panic(expected = "DivideByZero")] + #[cfg(feature = "dyn_arith_dict")] fn test_int_array_divide_dyn_by_zero_dict() { let mut builder = PrimitiveDictionaryBuilder::::with_capacity(1, 1); @@ -2338,7 +2363,9 @@ mod tests { #[test] #[should_panic(expected = "DivideByZero")] + #[cfg(feature = "dyn_arith_dict")] fn test_f32_dict_array_divide_dyn_by_zero() { + use crate::datatypes::Float32Type; let mut builder = PrimitiveDictionaryBuilder::::with_capacity(1, 1); builder.append(1.5).unwrap(); @@ -2601,6 +2628,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_arith_dict")] fn test_dictionary_add_dyn_wrapping_overflow() { let mut builder = PrimitiveDictionaryBuilder::::with_capacity(2, 2); @@ -2637,6 +2665,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_arith_dict")] fn test_dictionary_subtract_dyn_wrapping_overflow() { let mut builder = PrimitiveDictionaryBuilder::::with_capacity(1, 1); @@ -2670,6 +2699,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_arith_dict")] fn test_dictionary_mul_dyn_wrapping_overflow() { let mut builder = PrimitiveDictionaryBuilder::::with_capacity(1, 1); @@ -2703,6 +2733,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_arith_dict")] fn test_dictionary_div_dyn_wrapping_overflow() { let mut builder = PrimitiveDictionaryBuilder::::with_capacity(1, 1);