Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gate dyn comparison of dictionary arrays behind dyn_cmp_dict #2597

Merged
merged 3 commits into from Aug 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/arrow.yml
Expand Up @@ -51,9 +51,9 @@ jobs:
- name: Test
run: |
cargo test -p arrow
- name: Test --features=force_validate,prettyprint,ipc_compression,ffi
- name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict
run: |
cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi
cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict
- name: Test --features=nan_ordering
run: |
cargo test -p arrow --features "nan_ordering"
Expand Down Expand Up @@ -175,4 +175,4 @@ jobs:
rustup component add clippy
- name: Run clippy
run: |
cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression --all-targets -- -D warnings
cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict --all-targets -- -D warnings
11 changes: 7 additions & 4 deletions arrow/Cargo.toml
Expand Up @@ -38,10 +38,10 @@ path = "src/lib.rs"
bench = false

[target.'cfg(target_arch = "wasm32")'.dependencies]
ahash = { version = "0.8", default-features = false, features=["compile-time-rng"] }
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ahash = { version = "0.8", default-features = false, features=["runtime-rng"] }
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }

[dependencies]
serde = { version = "1.0", default-features = false }
Expand Down Expand Up @@ -90,6 +90,9 @@ force_validate = []
ffi = []
# Enable NaN-ordering behavior on comparison kernels
nan_ordering = []
# Enable dyn-comparison of dictionary arrays with other arrays
# Note: this does not impact comparison against scalars
dyn_cmp_dict = []

[dev-dependencies]
rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] }
Expand All @@ -102,7 +105,7 @@ tempfile = { version = "3", default-features = false }
[[example]]
name = "dynamic_types"
required-features = ["prettyprint"]
path="./examples/dynamic_types.rs"
path = "./examples/dynamic_types.rs"

[[bench]]
name = "aggregate_kernels"
Expand Down Expand Up @@ -144,7 +147,7 @@ required-features = ["test_utils"]
[[bench]]
name = "comparison_kernels"
harness = false
required-features = ["test_utils"]
required-features = ["test_utils", "dyn_cmp_dict"]

[[bench]]
name = "filter_kernels"
Expand Down
40 changes: 40 additions & 0 deletions arrow/src/compute/kernels/comparison.rs
Expand Up @@ -1855,6 +1855,7 @@ where
compare_op(left_array, right_array, op)
}

#[cfg(feature = "dyn_cmp_dict")]
macro_rules! typed_dict_non_dict_cmp {
($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP_BOOL: expr, $OP: expr) => {{
match $LEFT_KEY_TYPE {
Expand Down Expand Up @@ -1898,6 +1899,7 @@ macro_rules! typed_dict_non_dict_cmp {
}};
}

#[cfg(feature = "dyn_cmp_dict")]
macro_rules! typed_cmp_dict_non_dict {
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
Expand Down Expand Up @@ -1948,6 +1950,16 @@ macro_rules! typed_cmp_dict_non_dict {
}};
}

#[cfg(not(feature = "dyn_cmp_dict"))]
macro_rules! typed_cmp_dict_non_dict {
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
Err(ArrowError::CastError(format!(
"Comparing dictionary array of type {} with array of type {} requires \"dyn_cmp_dict\" feature",
$LEFT.data_type(), $RIGHT.data_type()
)))
}}
}

macro_rules! typed_compares {
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
Expand Down Expand Up @@ -2064,6 +2076,7 @@ macro_rules! typed_compares {
}

/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT
#[cfg(feature = "dyn_cmp_dict")]
macro_rules! typed_dict_cmp {
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr, $KT: tt) => {{
match ($LEFT.value_type(), $RIGHT.value_type()) {
Expand Down Expand Up @@ -2196,6 +2209,7 @@ macro_rules! typed_dict_cmp {
}};
}

#[cfg(feature = "dyn_cmp_dict")]
macro_rules! typed_dict_compares {
// Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{
Expand Down Expand Up @@ -2260,8 +2274,19 @@ macro_rules! typed_dict_compares {
}};
}

#[cfg(not(feature = "dyn_cmp_dict"))]
macro_rules! typed_dict_compares {
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{
Err(ArrowError::CastError(format!(
"Comparing array of type {} with array of type {} requires \"dyn_cmp_dict\" feature",
$LEFT.data_type(), $RIGHT.data_type()
)))
}}
}

/// Perform given operation on `DictionaryArray` and `PrimitiveArray`. The value
/// type of `DictionaryArray` is same as `PrimitiveArray`'s type.
#[cfg(feature = "dyn_cmp_dict")]
fn cmp_dict_primitive<K, T, F>(
left: &DictionaryArray<K>,
right: &dyn Array,
Expand All @@ -2282,6 +2307,7 @@ where
/// Perform given operation on two `DictionaryArray`s which value type is
/// primitive type. Returns an error if the two arrays have different value
/// type
#[cfg(feature = "dyn_cmp_dict")]
pub fn cmp_dict<K, T, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand All @@ -2301,6 +2327,7 @@ where

/// Perform the given operation on two `DictionaryArray`s which value type is
/// `DataType::Boolean`.
#[cfg(feature = "dyn_cmp_dict")]
pub fn cmp_dict_bool<K, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand All @@ -2319,6 +2346,7 @@ where

/// Perform the given operation on two `DictionaryArray`s which value type is
/// `DataType::Utf8` or `DataType::LargeUtf8`.
#[cfg(feature = "dyn_cmp_dict")]
pub fn cmp_dict_utf8<K, OffsetSize: OffsetSizeTrait, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand All @@ -2340,6 +2368,7 @@ where

/// Perform the given operation on two `DictionaryArray`s which value type is
/// `DataType::Binary` or `DataType::LargeBinary`.
#[cfg(feature = "dyn_cmp_dict")]
pub fn cmp_dict_binary<K, OffsetSize: OffsetSizeTrait, F>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
Expand Down Expand Up @@ -5242,6 +5271,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_i8_array() {
// Construct a value array
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
Expand All @@ -5262,6 +5292,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_u64_array() {
let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]);

Expand All @@ -5283,6 +5314,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_utf8_array() {
let test1 = vec!["a", "a", "b", "c"];
let test2 = vec!["a", "b", "b", "c"];
Expand Down Expand Up @@ -5310,6 +5342,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_binary_array() {
let values: BinaryArray = ["hello", "", "parquet"]
.into_iter()
Expand All @@ -5334,6 +5367,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_interval_array() {
let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]);

Expand All @@ -5355,6 +5389,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_date_array() {
let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]);

Expand All @@ -5376,6 +5411,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_bool_array() {
let values = BooleanArray::from(vec![true, false]);

Expand All @@ -5397,6 +5433,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_lt_dyn_gt_dyn_dictionary_i8_array() {
// Construct a value array
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
Expand Down Expand Up @@ -5426,6 +5463,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_lt_dyn_gt_dyn_dictionary_bool_array() {
let values = BooleanArray::from(vec![true, false]);

Expand Down Expand Up @@ -5468,6 +5506,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_eq_dyn_neq_dyn_dictionary_i8_i8_array() {
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
Expand Down Expand Up @@ -5502,6 +5541,7 @@ mod tests {
}

#[test]
#[cfg(feature = "dyn_cmp_dict")]
fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_i8_i8_array() {
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
Expand Down