Skip to content

Commit

Permalink
Add dyn_arith_dict feature flag (#2760)
Browse files Browse the repository at this point in the history
* Add dyn_arith_dict feature flag

* Document feature flag
  • Loading branch information
tustvold committed Sep 20, 2022
1 parent 5b601b3 commit 74f639c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 13 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/arrow.yml
Expand Up @@ -51,9 +51,9 @@ jobs:
- name: Test
run: |
cargo test -p arrow
- name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict
- name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict,dyn_arith_dict
run: |
cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict
cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict,dyn_arith_dict
- name: Run examples
run: |
# Test arrow examples
Expand Down Expand Up @@ -177,4 +177,4 @@ jobs:
rustup component add clippy
- name: Run clippy
run: |
cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict --all-targets -- -D warnings
cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict,dyn_arith_dict --all-targets -- -D warnings
3 changes: 3 additions & 0 deletions arrow/Cargo.toml
Expand Up @@ -95,6 +95,9 @@ ffi = []
# Enable dyn-comparison of dictionary arrays with other arrays
# Note: this does not impact comparison against scalars
dyn_cmp_dict = []
# Enable dyn-arithmetic kernels for dictionary arrays
# Note: this does not impact arithmetic with scalars
dyn_arith_dict = []

[dev-dependencies]
rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] }
Expand Down
1 change: 1 addition & 0 deletions arrow/README.md
Expand Up @@ -54,6 +54,7 @@ The `arrow` crate provides the following features which may be enabled in your `
- `ffi` - bindings for the Arrow C [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html)
- `pyarrow` - bindings for pyo3 to call arrow-rs from python
- `dyn_cmp_dict` - enables comparison of dictionary arrays within dyn comparison kernels
- `dyn_arith_dict` - enables arithmetic on dictionary arrays within dyn arithmetic kernels

## Arrow Feature Status

Expand Down
51 changes: 41 additions & 10 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -27,18 +27,17 @@ use std::ops::{Div, Neg, Rem};
use num::{One, Zero};

use crate::array::*;
use crate::buffer::Buffer;
#[cfg(feature = "simd")]
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::util::combine_option_bitmap;
use crate::compute::{
binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
};
use crate::datatypes::{
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
};
#[cfg(feature = "dyn_arith_dict")]
use crate::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
Expand Down Expand Up @@ -122,12 +121,13 @@ where
/// This function errors if:
/// * the arrays have different lengths
/// * there is an element where both left and right values are valid and the right value is `0`
#[cfg(feature = "dyn_arith_dict")]
fn math_checked_divide_op_on_iters<T, F>(
left: impl Iterator<Item = Option<T::Native>>,
right: impl Iterator<Item = Option<T::Native>>,
op: F,
len: usize,
null_bit_buffer: Option<Buffer>,
null_bit_buffer: Option<crate::buffer::Buffer>,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
Expand All @@ -143,15 +143,15 @@ where
}
});
// Safety: Iterator comes from a PrimitiveArray which reports its size correctly
unsafe { Buffer::try_from_trusted_len_iter(values) }
unsafe { crate::buffer::Buffer::try_from_trusted_len_iter(values) }
} else {
// no value is null
let values = left
.map(|l| l.unwrap())
.zip(right.map(|r| r.unwrap()))
.map(|(left, right)| op(left, right));
// Safety: Iterator comes from a PrimitiveArray which reports its size correctly
unsafe { Buffer::try_from_trusted_len_iter(values) }
unsafe { crate::buffer::Buffer::try_from_trusted_len_iter(values) }
}?;

let data = unsafe {
Expand Down Expand Up @@ -316,8 +316,10 @@ where
}

// Create the combined `Bitmap`
let null_bit_buffer =
combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?;
let null_bit_buffer = crate::compute::util::combine_option_bitmap(
&[left.data_ref(), right.data_ref()],
left.len(),
)?;

let lanes = T::lanes();
let buffer_size = left.len() * std::mem::size_of::<T::Native>();
Expand Down Expand Up @@ -425,6 +427,7 @@ where
}

/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT
#[cfg(feature = "dyn_arith_dict")]
macro_rules! typed_dict_op {
($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt, $MATH_OP: ident) => {{
match ($LEFT.value_type(), $RIGHT.value_type()) {
Expand Down Expand Up @@ -476,6 +479,7 @@ macro_rules! typed_dict_op {
}};
}

#[cfg(feature = "dyn_arith_dict")]
macro_rules! typed_dict_math_op {
// Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
($LEFT: expr, $RIGHT: expr, $OP: expr, $MATH_OP: ident) => {{
Expand Down Expand Up @@ -536,8 +540,20 @@ macro_rules! typed_dict_math_op {
}};
}

#[cfg(not(feature = "dyn_arith_dict"))]
macro_rules! typed_dict_math_op {
// Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
($LEFT: expr, $RIGHT: expr, $OP: expr, $MATH_OP: ident) => {{
Err(ArrowError::CastError(format!(
"Arithmetic on arrays of type {} with array of type {} requires \"dyn_arith_dict\" feature",
$LEFT.data_type(), $RIGHT.data_type()
)))
}};
}

/// Perform given operation on two `DictionaryArray`s.
/// Returns an error if the two arrays have different value type
#[cfg(feature = "dyn_arith_dict")]
fn math_op_dict<K, T, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand Down Expand Up @@ -593,6 +609,7 @@ where

/// Perform given operation on two `DictionaryArray`s.
/// Returns an error if the two arrays have different value type
#[cfg(feature = "dyn_arith_dict")]
fn math_checked_op_dict<K, T, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand Down Expand Up @@ -626,6 +643,7 @@ where
/// This function errors if:
/// * the arrays have different lengths
/// * there is an element where both left and right values are valid and the right value is `0`
#[cfg(feature = "dyn_arith_dict")]
fn math_divide_checked_op_dict<K, T, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand All @@ -645,8 +663,10 @@ where
)));
}

let null_bit_buffer =
combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?;
let null_bit_buffer = crate::compute::util::combine_option_bitmap(
&[left.data_ref(), right.data_ref()],
left.len(),
)?;

// Safety justification: Since the inputs are valid Arrow arrays, all values are
// valid indexes into the dictionary (which is verified during construction)
Expand Down Expand Up @@ -1484,7 +1504,7 @@ where
mod tests {
use super::*;
use crate::array::Int32Array;
use crate::datatypes::Date64Type;
use crate::datatypes::{Date64Type, Int32Type, Int8Type};
use chrono::NaiveDate;

#[test]
Expand Down Expand Up @@ -1605,6 +1625,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_primitive_array_add_dyn_dict() {
let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
builder.append(5).unwrap();
Expand Down Expand Up @@ -1683,6 +1704,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_primitive_array_subtract_dyn_dict() {
let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
builder.append(15).unwrap();
Expand Down Expand Up @@ -1761,6 +1783,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_primitive_array_multiply_dyn_dict() {
let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
builder.append(5).unwrap();
Expand Down Expand Up @@ -1801,6 +1824,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_primitive_array_divide_dyn_dict() {
let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
builder.append(15).unwrap();
Expand Down Expand Up @@ -2322,6 +2346,7 @@ mod tests {

#[test]
#[should_panic(expected = "DivideByZero")]
#[cfg(feature = "dyn_arith_dict")]
fn test_int_array_divide_dyn_by_zero_dict() {
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(1, 1);
Expand All @@ -2338,7 +2363,9 @@ mod tests {

#[test]
#[should_panic(expected = "DivideByZero")]
#[cfg(feature = "dyn_arith_dict")]
fn test_f32_dict_array_divide_dyn_by_zero() {
use crate::datatypes::Float32Type;
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Float32Type>::with_capacity(1, 1);
builder.append(1.5).unwrap();
Expand Down Expand Up @@ -2601,6 +2628,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_dictionary_add_dyn_wrapping_overflow() {
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(2, 2);
Expand Down Expand Up @@ -2637,6 +2665,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_dictionary_subtract_dyn_wrapping_overflow() {
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(1, 1);
Expand Down Expand Up @@ -2670,6 +2699,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_dictionary_mul_dyn_wrapping_overflow() {
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(1, 1);
Expand Down Expand Up @@ -2703,6 +2733,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_dictionary_div_dyn_wrapping_overflow() {
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(1, 1);
Expand Down

0 comments on commit 74f639c

Please sign in to comment.