From c59ccec4f6e54a80d6dab37f14c28cf454a190a2 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Thu, 8 Dec 2022 15:37:41 +0000 Subject: [PATCH] Split out arrow-ord (#2594) --- .github/workflows/arrow.yml | 34 +- Cargo.toml | 1 + arrow-array/Cargo.toml | 4 + arrow-array/src/lib.rs | 3 + .../datatypes => arrow-array/src}/numeric.rs | 10 +- arrow-ord/Cargo.toml | 53 ++ .../kernels => arrow-ord/src}/comparison.rs | 554 +++++++++--------- arrow-ord/src/lib.rs | 23 + {arrow/src/array => arrow-ord/src}/ord.rs | 89 ++- .../kernels => arrow-ord/src}/partition.rs | 96 ++- .../compute/kernels => arrow-ord/src}/sort.rs | 82 +-- arrow-string/src/regexp.rs | 115 ++++ arrow/Cargo.toml | 5 +- arrow/src/array/mod.rs | 3 +- arrow/src/compute/kernels/mod.rs | 11 +- arrow/src/datatypes/mod.rs | 7 +- 16 files changed, 642 insertions(+), 448 deletions(-) rename {arrow/src/datatypes => arrow-array/src}/numeric.rs (98%) create mode 100644 arrow-ord/Cargo.toml rename {arrow/src/compute/kernels => arrow-ord/src}/comparison.rs (94%) create mode 100644 arrow-ord/src/lib.rs rename {arrow/src/array => arrow-ord/src}/ord.rs (90%) rename {arrow/src/compute/kernels => arrow-ord/src}/partition.rs (83%) rename {arrow/src/compute/kernels => arrow-ord/src}/sort.rs (98%) diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml index 0b47f02566c..c2f3b0c8435 100644 --- a/.github/workflows/arrow.yml +++ b/.github/workflows/arrow.yml @@ -25,18 +25,20 @@ on: - master pull_request: paths: - - arrow/** + - .github/** - arrow-array/** - arrow-buffer/** - arrow-cast/** + - arrow-csv/** - arrow-data/** - - arrow-schema/** - - arrow-select/** - arrow-integration-test/** - arrow-ipc/** - - arrow-csv/** - arrow-json/** - - .github/** + - arrow-ord/** + - arrow-schema/** + - arrow-select/** + - arrow-string/** + - arrow/** jobs: @@ -58,8 +60,8 @@ jobs: run: cargo test -p arrow-data --all-features - name: Test arrow-schema with all features run: cargo test -p arrow-schema --all-features - - name: Test arrow-array with all features - run: cargo test -p arrow-array --all-features + - name: Test arrow-array without SIMD + run: cargo test -p arrow-array - name: Test arrow-select with all features run: cargo test -p arrow-select --all-features - name: Test arrow-cast with all features @@ -72,6 +74,8 @@ jobs: run: cargo test -p arrow-json --all-features - name: Test arrow-string with all features run: cargo test -p arrow-string --all-features + - name: Test arrow-ord without SIMD + run: cargo test -p arrow-ord - name: Test arrow-integration-test with all features run: cargo test -p arrow-integration-test --all-features - name: Test arrow with default features @@ -129,10 +133,12 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: nightly - - name: Run tests --features "simd" - run: cargo test -p arrow --features "simd" - - name: Check compilation --features "simd" - run: cargo check -p arrow --features simd + - name: Test arrow-array with SIMD + run: cargo test -p arrow-array --features simd + - name: Test arrow-ord with SIMD + run: cargo test -p arrow-ord --features simd + - name: Test arrow with SIMD + run: cargo test -p arrow --features simd - name: Check compilation --features simd --all-targets run: cargo check -p arrow --features simd --all-targets @@ -174,8 +180,8 @@ jobs: run: cargo clippy -p arrow-data --all-targets --all-features -- -D warnings - name: Clippy arrow-schema with all features run: cargo clippy -p arrow-schema --all-targets --all-features -- -D warnings - - name: Clippy arrow-array with all features - run: cargo clippy -p arrow-array --all-targets --all-features -- -D warnings + - name: Clippy arrow-array without SIMD + run: cargo clippy -p arrow-array --all-targets -- -D warnings - name: Clippy arrow-select with all features run: cargo clippy -p arrow-select --all-targets --all-features -- -D warnings - name: Clippy arrow-cast with all features @@ -188,5 +194,7 @@ jobs: run: cargo clippy -p arrow-json --all-targets --all-features -- -D warnings - name: Clippy arrow-string with all features run: cargo clippy -p arrow-string --all-targets --all-features -- -D warnings + - name: Clippy arrow-ord without SIMD + run: cargo clippy -p arrow-ord --all-targets -- -D warnings - name: Clippy arrow run: cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict,dyn_arith_dict,chrono-tz --all-targets -- -D warnings diff --git a/Cargo.toml b/Cargo.toml index 556b86a008a..c123106c6f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ members = [ "arrow-integration-testing", "arrow-ipc", "arrow-json", + "arrow-ord", "arrow-schema", "arrow-select", "arrow-string", diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml index 37f73c6d1c4..67f59a6dcd6 100644 --- a/arrow-array/Cargo.toml +++ b/arrow-array/Cargo.toml @@ -53,6 +53,10 @@ chrono-tz = { version = "0.8", optional = true } num = { version = "0.4", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false, features = ["num-traits"] } hashbrown = { version = "0.13", default-features = false } +packed_simd = { version = "0.3", default-features = false, optional = true, package = "packed_simd_2" } + +[features] +simd = ["packed_simd"] [dev-dependencies] rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs index 5fcd1f33d48..d6a9ab30b85 100644 --- a/arrow-array/src/lib.rs +++ b/arrow-array/src/lib.rs @@ -170,6 +170,9 @@ pub use record_batch::{RecordBatch, RecordBatchOptions, RecordBatchReader}; mod arithmetic; pub use arithmetic::ArrowNativeTypeOp; +mod numeric; +pub use numeric::*; + pub mod builder; pub mod cast; mod delta; diff --git a/arrow/src/datatypes/numeric.rs b/arrow-array/src/numeric.rs similarity index 98% rename from arrow/src/datatypes/numeric.rs rename to arrow-array/src/numeric.rs index 61fd05d52f9..ea19a0575f2 100644 --- a/arrow/src/datatypes/numeric.rs +++ b/arrow-array/src/numeric.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use super::*; +use crate::types::*; +use crate::ArrowPrimitiveType; #[cfg(feature = "simd")] use packed_simd::*; #[cfg(feature = "simd")] @@ -106,9 +107,11 @@ where /// Writes a SIMD result back to a slice fn write(simd_result: Self::Simd, slice: &mut [Self::Native]); + /// Performs a SIMD unary operation fn unary_op Self::Simd>(a: Self::Simd, op: F) -> Self::Simd; } +/// A subtype of primitive type that represents numeric values. #[cfg(not(feature = "simd"))] pub trait ArrowNumericType: ArrowPrimitiveType {} @@ -468,7 +471,7 @@ impl ArrowNumericType for Decimal256Type {} #[cfg(feature = "simd")] impl ArrowNumericType for Decimal256Type { - type Simd = i256; + type Simd = arrow_buffer::i256; type SimdMask = bool; fn lanes() -> usize { @@ -555,11 +558,14 @@ impl ArrowNumericType for Decimal256Type { } } +/// A subtype of primitive type that represents numeric float values #[cfg(feature = "simd")] pub trait ArrowFloatNumericType: ArrowNumericType { + /// SIMD version of pow fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd; } +/// A subtype of primitive type that represents numeric float values #[cfg(not(feature = "simd"))] pub trait ArrowFloatNumericType: ArrowNumericType {} diff --git a/arrow-ord/Cargo.toml b/arrow-ord/Cargo.toml new file mode 100644 index 00000000000..7a9e7ba4a9c --- /dev/null +++ b/arrow-ord/Cargo.toml @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-ord" +version = "28.0.0" +description = "Ordering kernels for arrow arrays" +homepage = "https://github.com/apache/arrow-rs" +repository = "https://github.com/apache/arrow-rs" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = ["arrow"] +include = [ + "benches/*.rs", + "src/**/*.rs", + "Cargo.toml", +] +edition = "2021" +rust-version = "1.62" + +[lib] +name = "arrow_ord" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { version = "28.0.0", path = "../arrow-array" } +arrow-buffer = { version = "28.0.0", path = "../arrow-buffer" } +arrow-data = { version = "28.0.0", path = "../arrow-data" } +arrow-schema = { version = "28.0.0", path = "../arrow-schema" } +arrow-select = { version = "28.0.0", path = "../arrow-select" } +num = { version = "0.4", default-features = false, features = ["std"] } + +[dev-dependencies] +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } + +[features] +dyn_cmp_dict = [] +simd = ["arrow-array/simd"] diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow-ord/src/comparison.rs similarity index 94% rename from arrow/src/compute/kernels/comparison.rs rename to arrow-ord/src/comparison.rs index 6976a68d99a..19659000824 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ -23,19 +23,15 @@ //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. //! -pub use arrow_string::like::*; -pub use arrow_string::regexp::{regexp_is_match_utf8, regexp_is_match_utf8_scalar}; - -use crate::array::*; -use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer}; -use crate::datatypes::*; -#[allow(unused_imports)] -use crate::downcast_dictionary_array; -use crate::error::{ArrowError, Result}; -use crate::util::bit_util; +use arrow_array::cast::*; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::buffer::buffer_unary_not; +use arrow_buffer::{bit_util, Buffer, MutableBuffer}; use arrow_data::bit_mask::combine_option_bitmap; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; use arrow_select::take::take; -use num::ToPrimitive; /// Helper function to perform boolean lambda function on values from two array accessors, this /// version does not attempt to use SIMD. @@ -43,7 +39,7 @@ fn compare_op( left: T, right: S, op: F, -) -> Result +) -> Result where F: Fn(T::Item, S::Item) -> bool, { @@ -59,7 +55,10 @@ where /// Helper function to perform boolean lambda function on values from array accessor, this /// version does not attempt to use SIMD. -fn compare_op_scalar(left: T, op: F) -> Result +fn compare_op_scalar( + left: T, + op: F, +) -> Result where F: Fn(T::Item) -> bool, { @@ -72,9 +71,9 @@ pub fn no_simd_compare_op( left: &PrimitiveArray, right: &PrimitiveArray, op: F, -) -> Result +) -> Result where - T: ArrowNumericType, + T: ArrowPrimitiveType, F: Fn(T::Native, T::Native) -> bool, { compare_op(left, right, op) @@ -86,9 +85,9 @@ pub fn no_simd_compare_op_scalar( left: &PrimitiveArray, right: T::Native, op: F, -) -> Result +) -> Result where - T: ArrowNumericType, + T: ArrowPrimitiveType, F: Fn(T::Native, T::Native) -> bool, { compare_op_scalar(left, |l| op(l, right)) @@ -98,13 +97,13 @@ where pub fn eq_utf8( left: &GenericStringArray, right: &GenericStringArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a == b) } fn utf8_empty( left: &GenericStringArray, -) -> Result { +) -> Result { let null_bit_buffer = left .data() .null_buffer() @@ -140,7 +139,7 @@ fn utf8_empty( pub fn eq_utf8_scalar( left: &GenericStringArray, right: &str, -) -> Result { +) -> Result { if right.is_empty() { return utf8_empty::<_, true>(left); } @@ -148,37 +147,58 @@ pub fn eq_utf8_scalar( } /// Perform `left == right` operation on [`BooleanArray`] -pub fn eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn eq_bool( + left: &BooleanArray, + right: &BooleanArray, +) -> Result { compare_op(left, right, |a, b| !(a ^ b)) } /// Perform `left != right` operation on [`BooleanArray`] -pub fn neq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn neq_bool( + left: &BooleanArray, + right: &BooleanArray, +) -> Result { compare_op(left, right, |a, b| (a ^ b)) } /// Perform `left < right` operation on [`BooleanArray`] -pub fn lt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn lt_bool( + left: &BooleanArray, + right: &BooleanArray, +) -> Result { compare_op(left, right, |a, b| ((!a) & b)) } /// Perform `left <= right` operation on [`BooleanArray`] -pub fn lt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn lt_eq_bool( + left: &BooleanArray, + right: &BooleanArray, +) -> Result { compare_op(left, right, |a, b| !(a & (!b))) } /// Perform `left > right` operation on [`BooleanArray`] -pub fn gt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn gt_bool( + left: &BooleanArray, + right: &BooleanArray, +) -> Result { compare_op(left, right, |a, b| (a & (!b))) } /// Perform `left >= right` operation on [`BooleanArray`] -pub fn gt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn gt_eq_bool( + left: &BooleanArray, + right: &BooleanArray, +) -> Result { compare_op(left, right, |a, b| !((!a) & b)) } /// Perform `left == right` operation on [`BooleanArray`] and a scalar -pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { +pub fn eq_bool_scalar( + left: &BooleanArray, + right: bool, +) -> Result { let len = left.len(); let left_offset = left.offset(); @@ -207,27 +227,42 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result } /// Perform `left < right` operation on [`BooleanArray`] and a scalar -pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result { +pub fn lt_bool_scalar( + left: &BooleanArray, + right: bool, +) -> Result { compare_op_scalar(left, |a: bool| !a & right) } /// Perform `left <= right` operation on [`BooleanArray`] and a scalar -pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { +pub fn lt_eq_bool_scalar( + left: &BooleanArray, + right: bool, +) -> Result { compare_op_scalar(left, |a| a <= right) } /// Perform `left > right` operation on [`BooleanArray`] and a scalar -pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result { +pub fn gt_bool_scalar( + left: &BooleanArray, + right: bool, +) -> Result { compare_op_scalar(left, |a: bool| a & !right) } /// Perform `left >= right` operation on [`BooleanArray`] and a scalar -pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { +pub fn gt_eq_bool_scalar( + left: &BooleanArray, + right: bool, +) -> Result { compare_op_scalar(left, |a| a >= right) } /// Perform `left != right` operation on [`BooleanArray`] and a scalar -pub fn neq_bool_scalar(left: &BooleanArray, right: bool) -> Result { +pub fn neq_bool_scalar( + left: &BooleanArray, + right: bool, +) -> Result { eq_bool_scalar(left, !right) } @@ -235,7 +270,7 @@ pub fn neq_bool_scalar(left: &BooleanArray, right: bool) -> Result pub fn eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a == b) } @@ -243,7 +278,7 @@ pub fn eq_binary( pub fn eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], -) -> Result { +) -> Result { compare_op_scalar(left, |a| a == right) } @@ -251,7 +286,7 @@ pub fn eq_binary_scalar( pub fn neq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a != b) } @@ -259,7 +294,7 @@ pub fn neq_binary( pub fn neq_binary_scalar( left: &GenericBinaryArray, right: &[u8], -) -> Result { +) -> Result { compare_op_scalar(left, |a| a != right) } @@ -267,7 +302,7 @@ pub fn neq_binary_scalar( pub fn lt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a < b) } @@ -275,7 +310,7 @@ pub fn lt_binary( pub fn lt_binary_scalar( left: &GenericBinaryArray, right: &[u8], -) -> Result { +) -> Result { compare_op_scalar(left, |a| a < right) } @@ -283,7 +318,7 @@ pub fn lt_binary_scalar( pub fn lt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a <= b) } @@ -291,7 +326,7 @@ pub fn lt_eq_binary( pub fn lt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], -) -> Result { +) -> Result { compare_op_scalar(left, |a| a <= right) } @@ -299,7 +334,7 @@ pub fn lt_eq_binary_scalar( pub fn gt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a > b) } @@ -307,7 +342,7 @@ pub fn gt_binary( pub fn gt_binary_scalar( left: &GenericBinaryArray, right: &[u8], -) -> Result { +) -> Result { compare_op_scalar(left, |a| a > right) } @@ -315,7 +350,7 @@ pub fn gt_binary_scalar( pub fn gt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a >= b) } @@ -323,7 +358,7 @@ pub fn gt_eq_binary( pub fn gt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], -) -> Result { +) -> Result { compare_op_scalar(left, |a| a >= right) } @@ -331,7 +366,7 @@ pub fn gt_eq_binary_scalar( pub fn neq_utf8( left: &GenericStringArray, right: &GenericStringArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a != b) } @@ -339,7 +374,7 @@ pub fn neq_utf8( pub fn neq_utf8_scalar( left: &GenericStringArray, right: &str, -) -> Result { +) -> Result { if right.is_empty() { return utf8_empty::<_, false>(left); } @@ -350,7 +385,7 @@ pub fn neq_utf8_scalar( pub fn lt_utf8( left: &GenericStringArray, right: &GenericStringArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a < b) } @@ -358,7 +393,7 @@ pub fn lt_utf8( pub fn lt_utf8_scalar( left: &GenericStringArray, right: &str, -) -> Result { +) -> Result { compare_op_scalar(left, |a| a < right) } @@ -366,7 +401,7 @@ pub fn lt_utf8_scalar( pub fn lt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a <= b) } @@ -374,7 +409,7 @@ pub fn lt_eq_utf8( pub fn lt_eq_utf8_scalar( left: &GenericStringArray, right: &str, -) -> Result { +) -> Result { compare_op_scalar(left, |a| a <= right) } @@ -382,7 +417,7 @@ pub fn lt_eq_utf8_scalar( pub fn gt_utf8( left: &GenericStringArray, right: &GenericStringArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a > b) } @@ -390,7 +425,7 @@ pub fn gt_utf8( pub fn gt_utf8_scalar( left: &GenericStringArray, right: &str, -) -> Result { +) -> Result { compare_op_scalar(left, |a| a > right) } @@ -398,7 +433,7 @@ pub fn gt_utf8_scalar( pub fn gt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, -) -> Result { +) -> Result { compare_op(left, right, |a, b| a >= b) } @@ -406,12 +441,16 @@ pub fn gt_eq_utf8( pub fn gt_eq_utf8_scalar( left: &GenericStringArray, right: &str, -) -> Result { +) -> Result { compare_op_scalar(left, |a| a >= right) } // Avoids creating a closure for each combination of `$RIGHT` and `$TY` -fn try_to_type_result(value: Option, right: &str, ty: &str) -> Result { +fn try_to_type_result( + value: Option, + right: &str, + ty: &str, +) -> Result { value.ok_or_else(|| { ArrowError::ComputeError(format!("Could not convert {} with {}", right, ty,)) }) @@ -590,7 +629,7 @@ macro_rules! dyn_compare_utf8_scalar { /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn eq_dyn_scalar(left: &dyn Array, right: T) -> Result +pub fn eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { @@ -609,7 +648,7 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn lt_dyn_scalar(left: &dyn Array, right: T) -> Result +pub fn lt_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { @@ -628,7 +667,7 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn lt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result +pub fn lt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { @@ -647,7 +686,7 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn gt_dyn_scalar(left: &dyn Array, right: T) -> Result +pub fn gt_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { @@ -666,7 +705,7 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn gt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result +pub fn gt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { @@ -685,7 +724,7 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn neq_dyn_scalar(left: &dyn Array, right: T) -> Result +pub fn neq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { @@ -699,7 +738,10 @@ where /// Perform `left == right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray -pub fn eq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { +pub fn eq_dyn_binary_scalar( + left: &dyn Array, + right: &[u8], +) -> Result { match left.data_type() { DataType::Binary => { let left = as_generic_binary_array::(left); @@ -717,7 +759,10 @@ pub fn eq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result Result { +pub fn neq_dyn_binary_scalar( + left: &dyn Array, + right: &[u8], +) -> Result { match left.data_type() { DataType::Binary => { let left = as_generic_binary_array::(left); @@ -736,7 +781,10 @@ pub fn neq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result Result { +pub fn lt_dyn_binary_scalar( + left: &dyn Array, + right: &[u8], +) -> Result { match left.data_type() { DataType::Binary => { let left = as_generic_binary_array::(left); @@ -754,7 +802,10 @@ pub fn lt_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result Result { +pub fn lt_eq_dyn_binary_scalar( + left: &dyn Array, + right: &[u8], +) -> Result { match left.data_type() { DataType::Binary => { let left = as_generic_binary_array::(left); @@ -773,7 +824,10 @@ pub fn lt_eq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray -pub fn gt_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { +pub fn gt_dyn_binary_scalar( + left: &dyn Array, + right: &[u8], +) -> Result { match left.data_type() { DataType::Binary => { let left = as_generic_binary_array::(left); @@ -791,7 +845,10 @@ pub fn gt_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result= right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray -pub fn gt_eq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { +pub fn gt_eq_dyn_binary_scalar( + left: &dyn Array, + right: &[u8], +) -> Result { match left.data_type() { DataType::Binary => { let left = as_generic_binary_array::(left); @@ -810,7 +867,10 @@ pub fn gt_eq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result Result { +pub fn eq_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { let result = match left.data_type() { DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { DataType::Utf8 | DataType::LargeUtf8 => { @@ -837,7 +897,10 @@ pub fn eq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result /// Perform `left < right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values -pub fn lt_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { +pub fn lt_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { let result = match left.data_type() { DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { DataType::Utf8 | DataType::LargeUtf8 => { @@ -864,7 +927,10 @@ pub fn lt_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result /// Perform `left >= right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values -pub fn gt_eq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { +pub fn gt_eq_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { let result = match left.data_type() { DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { DataType::Utf8 | DataType::LargeUtf8 => { @@ -891,7 +957,10 @@ pub fn gt_eq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result Result { +pub fn lt_eq_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { let result = match left.data_type() { DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { DataType::Utf8 | DataType::LargeUtf8 => { @@ -918,7 +987,10 @@ pub fn lt_eq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values -pub fn gt_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { +pub fn gt_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { let result = match left.data_type() { DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { DataType::Utf8 | DataType::LargeUtf8 => { @@ -945,7 +1017,10 @@ pub fn gt_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result /// Perform `left != right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values -pub fn neq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { +pub fn neq_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { let result = match left.data_type() { DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { DataType::Utf8 | DataType::LargeUtf8 => { @@ -972,7 +1047,10 @@ pub fn neq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result Result { +pub fn eq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { let result = match left.data_type() { DataType::Boolean => { let left = as_boolean_array(left); @@ -987,7 +1065,10 @@ pub fn eq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result /// Perform `left < right` operation on an array and a numeric scalar /// value. Supports BooleanArrays. -pub fn lt_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { +pub fn lt_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { let result = match left.data_type() { DataType::Boolean => { let left = as_boolean_array(left); @@ -1002,7 +1083,10 @@ pub fn lt_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result /// Perform `left > right` operation on an array and a numeric scalar /// value. Supports BooleanArrays. -pub fn gt_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { +pub fn gt_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { let result = match left.data_type() { DataType::Boolean => { let left = as_boolean_array(left); @@ -1017,7 +1101,10 @@ pub fn gt_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result /// Perform `left <= right` operation on an array and a numeric scalar /// value. Supports BooleanArrays. -pub fn lt_eq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { +pub fn lt_eq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { let result = match left.data_type() { DataType::Boolean => { let left = as_boolean_array(left); @@ -1032,7 +1119,10 @@ pub fn lt_eq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result= right` operation on an array and a numeric scalar /// value. Supports BooleanArrays. -pub fn gt_eq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { +pub fn gt_eq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { let result = match left.data_type() { DataType::Boolean => { let left = as_boolean_array(left); @@ -1047,7 +1137,10 @@ pub fn gt_eq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result Result { +pub fn neq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { let result = match left.data_type() { DataType::Boolean => { let left = as_boolean_array(left); @@ -1067,10 +1160,10 @@ pub fn neq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result( dict: &DictionaryArray, dict_comparison: BooleanArray, -) -> Result +) -> Result where - K: ArrowNumericType, - K::Native: ToPrimitive, + K: ArrowPrimitiveType, + K::Native: num::ToPrimitive, { // TODO: Use take_boolean (#2967) let array = take(&dict_comparison, dict.keys(), None)?; @@ -1085,7 +1178,7 @@ fn simd_compare_op( right: &PrimitiveArray, simd_op: SI, scalar_op: SC, -) -> Result +) -> Result where T: ArrowNumericType, SI: Fn(T::Simd, T::Simd) -> T::SimdMask, @@ -1185,7 +1278,7 @@ fn simd_compare_op_scalar( right: T::Native, simd_op: SI, scalar_op: SC, -) -> Result +) -> Result where T: ArrowNumericType, SI: Fn(T::Simd, T::Simd) -> T::SimdMask, @@ -1271,11 +1364,11 @@ where Ok(BooleanArray::from(data)) } -fn cmp_primitive_array( +fn cmp_primitive_array( left: &dyn Array, right: &dyn Array, op: F, -) -> Result +) -> Result where F: Fn(T::Native, T::Native) -> bool, { @@ -1836,10 +1929,10 @@ fn cmp_dict_primitive( left: &DictionaryArray, right: &dyn Array, op: F, -) -> Result +) -> Result where - K: ArrowNumericType, - T: ArrowNumericType + Sync + Send, + K: ArrowPrimitiveType, + T: ArrowPrimitiveType + Sync + Send, F: Fn(T::Native, T::Native) -> bool, { compare_op( @@ -1856,9 +1949,9 @@ fn cmp_dict_string_array( left: &DictionaryArray, right: &dyn Array, op: F, -) -> Result +) -> Result where - K: ArrowNumericType, + K: ArrowPrimitiveType, F: Fn(&str, &str) -> bool, { compare_op( @@ -1879,9 +1972,9 @@ fn cmp_dict_boolean_array( left: &DictionaryArray, right: &dyn Array, op: F, -) -> Result +) -> Result where - K: ArrowNumericType, + K: ArrowPrimitiveType, F: Fn(bool, bool) -> bool, { compare_op( @@ -1898,9 +1991,9 @@ fn cmp_dict_binary_array( left: &DictionaryArray, right: &dyn Array, op: F, -) -> Result +) -> Result where - K: ArrowNumericType, + K: ArrowPrimitiveType, F: Fn(&[u8], &[u8]) -> bool, { compare_op( @@ -1922,10 +2015,10 @@ pub fn cmp_dict( left: &DictionaryArray, right: &DictionaryArray, op: F, -) -> Result +) -> Result where - K: ArrowNumericType, - T: ArrowNumericType + Sync + Send, + K: ArrowPrimitiveType, + T: ArrowPrimitiveType + Sync + Send, F: Fn(T::Native, T::Native) -> bool, { compare_op( @@ -1942,9 +2035,9 @@ pub fn cmp_dict_bool( left: &DictionaryArray, right: &DictionaryArray, op: F, -) -> Result +) -> Result where - K: ArrowNumericType, + K: ArrowPrimitiveType, F: Fn(bool, bool) -> bool, { compare_op( @@ -1961,9 +2054,9 @@ pub fn cmp_dict_utf8( left: &DictionaryArray, right: &DictionaryArray, op: F, -) -> Result +) -> Result where - K: ArrowNumericType, + K: ArrowPrimitiveType, F: Fn(&str, &str) -> bool, { compare_op( @@ -1983,9 +2076,9 @@ pub fn cmp_dict_binary( left: &DictionaryArray, right: &DictionaryArray, op: F, -) -> Result +) -> Result where - K: ArrowNumericType, + K: ArrowPrimitiveType, F: Fn(&[u8], &[u8]) -> bool, { compare_op( @@ -2009,14 +2102,14 @@ where /// /// # Example /// ``` -/// use arrow::array::{StringArray, BooleanArray}; -/// use arrow::compute::eq_dyn; +/// use arrow_array::{StringArray, BooleanArray}; +/// use arrow_ord::comparison::eq_dyn; /// let array1 = StringArray::from(vec![Some("foo"), None, Some("bar")]); /// let array2 = StringArray::from(vec![Some("foo"), None, Some("baz")]); /// let result = eq_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), None, Some(false)]), result); /// ``` -pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { +pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => @@ -2052,8 +2145,8 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// /// # Example /// ``` -/// use arrow::array::{BinaryArray, BooleanArray}; -/// use arrow::compute::neq_dyn; +/// use arrow_array::{BinaryArray, BooleanArray}; +/// use arrow_ord::comparison::neq_dyn; /// let values1: Vec> = vec![Some(&[0xfc, 0xa9]), None, Some(&[0x36])]; /// let values2: Vec> = vec![Some(&[0xfc, 0xa9]), None, Some(&[0x36, 0x00])]; /// let array1 = BinaryArray::from(values1); @@ -2061,7 +2154,7 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// let result = neq_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(false), None, Some(true)]), result); /// ``` -pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { +pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => @@ -2097,16 +2190,16 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// /// # Example /// ``` -/// use arrow::array::{PrimitiveArray, BooleanArray}; -/// use arrow::datatypes::Int32Type; -/// use arrow::compute::lt_dyn; +/// use arrow_array::{PrimitiveArray, BooleanArray}; +/// use arrow_array::types::Int32Type; +/// use arrow_ord::comparison::lt_dyn; /// let array1: PrimitiveArray = PrimitiveArray::from(vec![Some(0), Some(1), Some(2)]); /// let array2: PrimitiveArray = PrimitiveArray::from(vec![Some(1), Some(1), None]); /// let result = lt_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result); /// ``` #[allow(clippy::bool_comparison)] -pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { +pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => @@ -2142,15 +2235,18 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// /// # Example /// ``` -/// use arrow::array::{PrimitiveArray, BooleanArray}; -/// use arrow::datatypes::Date32Type; -/// use arrow::compute::lt_eq_dyn; +/// use arrow_array::{PrimitiveArray, BooleanArray}; +/// use arrow_array::types::Date32Type; +/// use arrow_ord::comparison::lt_eq_dyn; /// let array1: PrimitiveArray = vec![Some(12356), Some(13548), Some(-365), Some(365)].into(); /// let array2: PrimitiveArray = vec![Some(12355), Some(13548), Some(-364), None].into(); /// let result = lt_eq_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), Some(true), None]), result); /// ``` -pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { +pub fn lt_eq_dyn( + left: &dyn Array, + right: &dyn Array, +) -> Result { match left.data_type() { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => @@ -2186,15 +2282,15 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// /// # Example /// ``` -/// use arrow::array::BooleanArray; -/// use arrow::compute::gt_dyn; +/// use arrow_array::BooleanArray; +/// use arrow_ord::comparison::gt_dyn; /// let array1 = BooleanArray::from(vec![Some(true), Some(false), None]); /// let array2 = BooleanArray::from(vec![Some(false), Some(true), None]); /// let result = gt_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result); /// ``` #[allow(clippy::bool_comparison)] -pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { +pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => @@ -2230,14 +2326,17 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// /// # Example /// ``` -/// use arrow::array::{BooleanArray, StringArray}; -/// use arrow::compute::gt_eq_dyn; +/// use arrow_array::{BooleanArray, StringArray}; +/// use arrow_ord::comparison::gt_eq_dyn; /// let array1 = StringArray::from(vec![Some(""), Some("aaa"), None]); /// let array2 = StringArray::from(vec![Some(" "), Some("aa"), None]); /// let result = gt_eq_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), None]), result); /// ``` -pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { +pub fn gt_eq_dyn( + left: &dyn Array, + right: &dyn Array, +) -> Result { match left.data_type() { DataType::Dictionary(_, _) if matches!(right.data_type(), DataType::Dictionary(_, _)) => @@ -2268,7 +2367,10 @@ pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn eq(left: &PrimitiveArray, right: &PrimitiveArray) -> Result +pub fn eq( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2285,7 +2387,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result +pub fn eq_scalar( + left: &PrimitiveArray, + right: T::Native, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2297,7 +2402,10 @@ where } /// Applies an unary and infallible comparison function to a primitive array. -pub fn unary_cmp(left: &PrimitiveArray, op: F) -> Result +pub fn unary_cmp( + left: &PrimitiveArray, + op: F, +) -> Result where T: ArrowNumericType, F: Fn(T::Native) -> bool, @@ -2311,7 +2419,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn neq(left: &PrimitiveArray, right: &PrimitiveArray) -> Result +pub fn neq( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2328,7 +2439,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn neq_scalar(left: &PrimitiveArray, right: T::Native) -> Result +pub fn neq_scalar( + left: &PrimitiveArray, + right: T::Native, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2346,7 +2460,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn lt(left: &PrimitiveArray, right: &PrimitiveArray) -> Result +pub fn lt( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2364,7 +2481,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn lt_scalar(left: &PrimitiveArray, right: T::Native) -> Result +pub fn lt_scalar( + left: &PrimitiveArray, + right: T::Native, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2385,7 +2505,7 @@ where pub fn lt_eq( left: &PrimitiveArray, right: &PrimitiveArray, -) -> Result +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2403,7 +2523,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn lt_eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result +pub fn lt_eq_scalar( + left: &PrimitiveArray, + right: T::Native, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2421,7 +2544,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn gt(left: &PrimitiveArray, right: &PrimitiveArray) -> Result +pub fn gt( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2439,7 +2565,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn gt_scalar(left: &PrimitiveArray, right: T::Native) -> Result +pub fn gt_scalar( + left: &PrimitiveArray, + right: T::Native, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2460,7 +2589,7 @@ where pub fn gt_eq( left: &PrimitiveArray, right: &PrimitiveArray, -) -> Result +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2478,7 +2607,10 @@ where /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. -pub fn gt_eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result +pub fn gt_eq_scalar( + left: &PrimitiveArray, + right: T::Native, +) -> Result where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, @@ -2493,7 +2625,7 @@ where pub fn contains( left: &PrimitiveArray, right: &GenericListArray, -) -> Result +) -> Result where T: ArrowNumericType, OffsetSize: OffsetSizeTrait, @@ -2551,7 +2683,7 @@ where pub fn contains_utf8( left: &GenericStringArray, right: &ListArray, -) -> Result +) -> Result where OffsetSize: OffsetSizeTrait, { @@ -2620,12 +2752,12 @@ fn new_all_set_buffer(len: usize) -> Buffer { #[rustfmt::skip::macros(vec)] #[cfg(test)] mod tests { - use arrow_buffer::i256; - use std::sync::Arc; - use super::*; - use crate::datatypes::Int8Type; - use crate::{array::Int32Array, array::Int64Array, datatypes::Field}; + use arrow_array::builder::{ + ListBuilder, PrimitiveDictionaryBuilder, StringBuilder, StringDictionaryBuilder, + }; + use arrow_buffer::i256; + use arrow_schema::Field; /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. /// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>` where `T` is the native @@ -3639,82 +3771,6 @@ mod tests { }; } - macro_rules! test_flag_utf8 { - ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { - #[test] - fn $test_name() { - let left = StringArray::from($left); - let right = StringArray::from($right); - let res = $op(&left, &right, None).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!(v, expected[i]); - } - } - }; - ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => { - #[test] - fn $test_name() { - let left = StringArray::from($left); - let right = StringArray::from($right); - let flag = Some(StringArray::from($flag)); - let res = $op(&left, &right, flag.as_ref()).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!(v, expected[i]); - } - } - }; - } - - macro_rules! test_flag_utf8_scalar { - ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { - #[test] - fn $test_name() { - let left = StringArray::from($left); - let res = $op(&left, $right, None).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!( - v, - expected[i], - "unexpected result when comparing {} at position {} to {} ", - left.value(i), - i, - $right - ); - } - } - }; - ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => { - #[test] - fn $test_name() { - let left = StringArray::from($left); - let flag = Some($flag); - let res = $op(&left, $right, flag).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!( - v, - expected[i], - "unexpected result when comparing {} at position {} to {} ", - left.value(i), - i, - $right - ); - } - } - }; - } - test_utf8!( test_utf8_array_eq, vec!["arrow", "arrow", "arrow", "arrow"], @@ -3804,44 +3860,6 @@ mod tests { gt_eq_utf8_scalar, vec![false, false, true, true] ); - test_flag_utf8!( - test_utf8_array_regexp_is_match, - vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"], - vec!["^ar", "^AR", "ow$", "OW$", "foo", ""], - regexp_is_match_utf8, - vec![true, false, true, false, false, true] - ); - test_flag_utf8!( - test_utf8_array_regexp_is_match_insensitive, - vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"], - vec!["^ar", "^AR", "ow$", "OW$", "foo", ""], - vec!["i"; 6], - regexp_is_match_utf8, - vec![true, true, true, true, false, true] - ); - - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_scalar, - vec!["arrow", "ARROW", "parquet", "PARQUET"], - "^ar", - regexp_is_match_utf8_scalar, - vec![true, false, false, false] - ); - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_empty_scalar, - vec!["arrow", "ARROW", "parquet", "PARQUET"], - "", - regexp_is_match_utf8_scalar, - vec![true, true, true, true] - ); - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_insensitive_scalar, - vec!["arrow", "ARROW", "parquet", "PARQUET"], - "^ar", - "i", - regexp_is_match_utf8_scalar, - vec![true, true, false, false] - ); #[test] fn test_eq_dyn_scalar() { @@ -3881,8 +3899,7 @@ mod tests { ); assert_eq!(eq_dyn_scalar(&array, 8).unwrap(), expected); - let array: ArrayRef = Arc::new(array); - let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + let array = array.unary::<_, Float64Type>(|x| x as f64); assert_eq!(eq_dyn_scalar(&array, 8).unwrap(), expected); } @@ -3924,8 +3941,7 @@ mod tests { ); assert_eq!(lt_dyn_scalar(&array, 8).unwrap(), expected); - let array: ArrayRef = Arc::new(array); - let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + let array = array.unary::<_, Float64Type>(|x| x as f64); assert_eq!(lt_dyn_scalar(&array, 8).unwrap(), expected); } @@ -3967,8 +3983,7 @@ mod tests { ); assert_eq!(lt_eq_dyn_scalar(&array, 8).unwrap(), expected); - let array: ArrayRef = Arc::new(array); - let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + let array = array.unary::<_, Float64Type>(|x| x as f64); assert_eq!(lt_eq_dyn_scalar(&array, 8).unwrap(), expected); } @@ -4010,8 +4025,7 @@ mod tests { ); assert_eq!(gt_dyn_scalar(&array, 8).unwrap(), expected); - let array: ArrayRef = Arc::new(array); - let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + let array = array.unary::<_, Float64Type>(|x| x as f64); assert_eq!(gt_dyn_scalar(&array, 8).unwrap(), expected); } @@ -4053,8 +4067,7 @@ mod tests { ); assert_eq!(gt_eq_dyn_scalar(&array, 8).unwrap(), expected); - let array: ArrayRef = Arc::new(array); - let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + let array = array.unary::<_, Float64Type>(|x| x as f64); assert_eq!(gt_eq_dyn_scalar(&array, 8).unwrap(), expected); } @@ -4096,8 +4109,7 @@ mod tests { ); assert_eq!(neq_dyn_scalar(&array, 8).unwrap(), expected); - let array: ArrayRef = Arc::new(array); - let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + let array = array.unary::<_, Float64Type>(|x| x as f64); assert_eq!(neq_dyn_scalar(&array, 8).unwrap(), expected); } @@ -4433,8 +4445,6 @@ mod tests { #[test] fn test_eq_dyn_neq_dyn_fixed_size_binary() { - use crate::array::FixedSizeBinaryArray; - let values1: Vec> = vec![Some(&[0xfc, 0xa9]), None, Some(&[0x36, 0x01])]; let values2: Vec> = diff --git a/arrow-ord/src/lib.rs b/arrow-ord/src/lib.rs new file mode 100644 index 00000000000..c84db09fd32 --- /dev/null +++ b/arrow-ord/src/lib.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Arrow ordering kernels + +pub mod comparison; +pub mod ord; +pub mod partition; +pub mod sort; diff --git a/arrow/src/array/ord.rs b/arrow-ord/src/ord.rs similarity index 90% rename from arrow/src/array/ord.rs rename to arrow-ord/src/ord.rs index 305d41cc016..44eb3b18380 100644 --- a/arrow/src/array/ord.rs +++ b/arrow-ord/src/ord.rs @@ -17,14 +17,12 @@ //! Contains functions and function factories to compare arrays. -use std::cmp::Ordering; - -use crate::array::*; -use crate::datatypes::TimeUnit; -use crate::datatypes::*; -use crate::error::{ArrowError, Result}; - +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{ArrowError, DataType}; use num::Float; +use std::cmp::Ordering; /// Compare the values at two arbitrary indices in two arrays. pub type DynComparator = Box Ordering + Send + Sync>; @@ -130,7 +128,7 @@ fn cmp_dict_primitive( key_type: &DataType, left: &dyn Array, right: &dyn Array, -) -> Result +) -> Result where VT: ArrowPrimitiveType, VT::Native: Ord, @@ -160,25 +158,24 @@ where /// The arrays' types must be equal. /// # Example /// ``` -/// use arrow::array::{build_compare, Int32Array}; +/// use arrow_array::Int32Array; +/// use arrow_ord::ord::build_compare; /// -/// # fn main() -> arrow::error::Result<()> { /// let array1 = Int32Array::from(vec![1, 2]); /// let array2 = Int32Array::from(vec![3, 4]); /// -/// let cmp = build_compare(&array1, &array2)?; +/// let cmp = build_compare(&array1, &array2).unwrap(); /// /// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) /// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1)); -/// # Ok(()) -/// # } /// ``` // This is a factory of comparisons. // The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime. -pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { - use DataType::*; - use IntervalUnit::*; - use TimeUnit::*; +pub fn build_compare( + left: &dyn Array, + right: &dyn Array, +) -> Result { + use arrow_schema::{DataType::*, IntervalUnit::*, TimeUnit::*}; Ok(match (left.data_type(), right.data_type()) { (a, b) if a != b => { return Err(ArrowError::InvalidArgumentError( @@ -315,130 +312,119 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result Result<()> { + fn test_fixed_size_binary() { let items = vec![vec![1u8], vec![2u8]]; let array = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap(); - let cmp = build_compare(&array, &array)?; + let cmp = build_compare(&array, &array).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 1)); - Ok(()) } #[test] - fn test_fixed_size_binary_fixed_size_binary() -> Result<()> { + fn test_fixed_size_binary_fixed_size_binary() { let items = vec![vec![1u8]]; let array1 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap(); let items = vec![vec![2u8]]; let array2 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap(); - let cmp = build_compare(&array1, &array2)?; + let cmp = build_compare(&array1, &array2).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 0)); - Ok(()) } #[test] - fn test_i32() -> Result<()> { + fn test_i32() { let array = Int32Array::from(vec![1, 2]); - let cmp = build_compare(&array, &array)?; + let cmp = build_compare(&array, &array).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 1)); - Ok(()) } #[test] - fn test_i32_i32() -> Result<()> { + fn test_i32_i32() { let array1 = Int32Array::from(vec![1]); let array2 = Int32Array::from(vec![2]); - let cmp = build_compare(&array1, &array2)?; + let cmp = build_compare(&array1, &array2).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 0)); - Ok(()) } #[test] - fn test_f64() -> Result<()> { + fn test_f64() { let array = Float64Array::from(vec![1.0, 2.0]); - let cmp = build_compare(&array, &array)?; + let cmp = build_compare(&array, &array).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 1)); - Ok(()) } #[test] - fn test_f64_nan() -> Result<()> { + fn test_f64_nan() { let array = Float64Array::from(vec![1.0, f64::NAN]); - let cmp = build_compare(&array, &array)?; + let cmp = build_compare(&array, &array).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 1)); - Ok(()) } #[test] - fn test_f64_zeros() -> Result<()> { + fn test_f64_zeros() { let array = Float64Array::from(vec![-0.0, 0.0]); - let cmp = build_compare(&array, &array)?; + let cmp = build_compare(&array, &array).unwrap(); assert_eq!(Ordering::Equal, (cmp)(0, 1)); assert_eq!(Ordering::Equal, (cmp)(1, 0)); - Ok(()) } #[test] - fn test_decimal() -> Result<()> { + fn test_decimal() { let array = vec![Some(5_i128), Some(2_i128), Some(3_i128)] .into_iter() .collect::() .with_precision_and_scale(23, 6) .unwrap(); - let cmp = build_compare(&array, &array)?; + let cmp = build_compare(&array, &array).unwrap(); assert_eq!(Ordering::Less, (cmp)(1, 0)); assert_eq!(Ordering::Greater, (cmp)(0, 2)); - Ok(()) } #[test] - fn test_dict() -> Result<()> { + fn test_dict() { let data = vec!["a", "b", "c", "a", "a", "c", "c"]; let array = data.into_iter().collect::>(); - let cmp = build_compare(&array, &array)?; + let cmp = build_compare(&array, &array).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 1)); assert_eq!(Ordering::Equal, (cmp)(3, 4)); assert_eq!(Ordering::Greater, (cmp)(2, 3)); - Ok(()) } #[test] - fn test_multiple_dict() -> Result<()> { + fn test_multiple_dict() { let d1 = vec!["a", "b", "c", "d"]; let a1 = d1.into_iter().collect::>(); let d2 = vec!["e", "f", "g", "a"]; let a2 = d2.into_iter().collect::>(); - let cmp = build_compare(&a1, &a2)?; + let cmp = build_compare(&a1, &a2).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 0)); assert_eq!(Ordering::Equal, (cmp)(0, 3)); assert_eq!(Ordering::Greater, (cmp)(1, 3)); - Ok(()) } #[test] - fn test_primitive_dict() -> Result<()> { + fn test_primitive_dict() { let values = Int32Array::from(vec![1_i32, 0, 2, 5]); let keys = Int8Array::from_iter_values([0, 0, 1, 3]); let array1 = DictionaryArray::::try_new(&keys, &values).unwrap(); @@ -447,13 +433,12 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::::try_new(&keys, &values).unwrap(); - let cmp = build_compare(&array1, &array2)?; + let cmp = build_compare(&array1, &array2).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 0)); assert_eq!(Ordering::Less, (cmp)(0, 3)); assert_eq!(Ordering::Equal, (cmp)(3, 3)); assert_eq!(Ordering::Greater, (cmp)(3, 1)); assert_eq!(Ordering::Greater, (cmp)(3, 2)); - Ok(()) } } diff --git a/arrow/src/compute/kernels/partition.rs b/arrow-ord/src/partition.rs similarity index 83% rename from arrow/src/compute/kernels/partition.rs rename to arrow-ord/src/partition.rs index 0e48e627e65..e62e585165d 100644 --- a/arrow/src/compute/kernels/partition.rs +++ b/arrow-ord/src/partition.rs @@ -17,11 +17,9 @@ //! Defines partition kernel for `ArrayRef` -use crate::compute::kernels::sort::LexicographicalComparator; -use crate::compute::SortColumn; -use crate::error::{ArrowError, Result}; +use crate::sort::{LexicographicalComparator, SortColumn}; +use arrow_schema::ArrowError; use std::cmp::Ordering; -use std::iter::Iterator; use std::ops::Range; /// Given a list of already sorted columns, find partition ranges that would partition @@ -35,7 +33,7 @@ use std::ops::Range; /// range. pub fn lexicographical_partition_ranges( columns: &[SortColumn], -) -> Result> + '_> { +) -> Result> + '_, ArrowError> { LexicographicalPartitionIterator::try_new(columns) } @@ -47,7 +45,9 @@ struct LexicographicalPartitionIterator<'a> { } impl<'a> LexicographicalPartitionIterator<'a> { - fn try_new(columns: &'a [SortColumn]) -> Result { + fn try_new( + columns: &'a [SortColumn], + ) -> Result { if columns.is_empty() { return Err(ArrowError::InvalidArgumentError( "Sort requires at least one column".to_string(), @@ -162,9 +162,9 @@ impl<'a> Iterator for LexicographicalPartitionIterator<'a> { #[cfg(test)] mod tests { use super::*; - use crate::array::*; - use crate::compute::SortOptions; - use crate::datatypes::DataType; + use crate::sort::SortOptions; + use arrow_array::*; + use arrow_schema::DataType; use std::sync::Arc; #[test] @@ -233,7 +233,7 @@ mod tests { } #[test] - fn test_lexicographical_partition_single_column() -> Result<()> { + fn test_lexicographical_partition_single_column() { let input = vec![SortColumn { values: Arc::new(Int64Array::from(vec![1, 2, 2, 2, 2, 2, 2, 2, 9])) as ArrayRef, @@ -242,18 +242,15 @@ mod tests { nulls_first: true, }), }]; - { - let results = lexicographical_partition_ranges(&input)?; - assert_eq!( - vec![(0_usize..1_usize), (1_usize..8_usize), (8_usize..9_usize)], - results.collect::>() - ); - } - Ok(()) + let results = lexicographical_partition_ranges(&input).unwrap(); + assert_eq!( + vec![(0_usize..1_usize), (1_usize..8_usize), (8_usize..9_usize)], + results.collect::>() + ); } #[test] - fn test_lexicographical_partition_all_equal_values() -> Result<()> { + fn test_lexicographical_partition_all_equal_values() { let input = vec![SortColumn { values: Arc::new(Int64Array::from_value(1, 1000)) as ArrayRef, options: Some(SortOptions { @@ -262,15 +259,12 @@ mod tests { }), }]; - { - let results = lexicographical_partition_ranges(&input)?; - assert_eq!(vec![(0_usize..1000_usize)], results.collect::>()); - } - Ok(()) + let results = lexicographical_partition_ranges(&input).unwrap(); + assert_eq!(vec![(0_usize..1000_usize)], results.collect::>()); } #[test] - fn test_lexicographical_partition_all_null_values() -> Result<()> { + fn test_lexicographical_partition_all_null_values() { let input = vec![ SortColumn { values: new_null_array(&DataType::Int8, 1000), @@ -287,15 +281,12 @@ mod tests { }), }, ]; - { - let results = lexicographical_partition_ranges(&input)?; - assert_eq!(vec![(0_usize..1000_usize)], results.collect::>()); - } - Ok(()) + let results = lexicographical_partition_ranges(&input).unwrap(); + assert_eq!(vec![(0_usize..1000_usize)], results.collect::>()); } #[test] - fn test_lexicographical_partition_unique_column_1() -> Result<()> { + fn test_lexicographical_partition_unique_column_1() { let input = vec![ SortColumn { values: Arc::new(Int64Array::from(vec![None, Some(-1)])) as ArrayRef, @@ -313,18 +304,15 @@ mod tests { }), }, ]; - { - let results = lexicographical_partition_ranges(&input)?; - assert_eq!( - vec![(0_usize..1_usize), (1_usize..2_usize)], - results.collect::>() - ); - } - Ok(()) + let results = lexicographical_partition_ranges(&input).unwrap(); + assert_eq!( + vec![(0_usize..1_usize), (1_usize..2_usize)], + results.collect::>() + ); } #[test] - fn test_lexicographical_partition_unique_column_2() -> Result<()> { + fn test_lexicographical_partition_unique_column_2() { let input = vec![ SortColumn { values: Arc::new(Int64Array::from(vec![None, Some(-1), Some(-1)])) @@ -346,18 +334,15 @@ mod tests { }), }, ]; - { - let results = lexicographical_partition_ranges(&input)?; - assert_eq!( - vec![(0_usize..1_usize), (1_usize..2_usize), (2_usize..3_usize),], - results.collect::>() - ); - } - Ok(()) + let results = lexicographical_partition_ranges(&input).unwrap(); + assert_eq!( + vec![(0_usize..1_usize), (1_usize..2_usize), (2_usize..3_usize),], + results.collect::>() + ); } #[test] - fn test_lexicographical_partition_non_unique_column_1() -> Result<()> { + fn test_lexicographical_partition_non_unique_column_1() { let input = vec![ SortColumn { values: Arc::new(Int64Array::from(vec![ @@ -384,13 +369,10 @@ mod tests { }), }, ]; - { - let results = lexicographical_partition_ranges(&input)?; - assert_eq!( - vec![(0_usize..1_usize), (1_usize..3_usize), (3_usize..4_usize),], - results.collect::>() - ); - } - Ok(()) + let results = lexicographical_partition_ranges(&input).unwrap(); + assert_eq!( + vec![(0_usize..1_usize), (1_usize..3_usize), (3_usize..4_usize),], + results.collect::>() + ); } } diff --git a/arrow/src/compute/kernels/sort.rs b/arrow-ord/src/sort.rs similarity index 98% rename from arrow/src/compute/kernels/sort.rs rename to arrow-ord/src/sort.rs index 81895760e58..a3a8262bfeb 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow-ord/src/sort.rs @@ -17,14 +17,15 @@ //! Defines sort kernel for `ArrayRef` -use crate::array::*; -use crate::buffer::MutableBuffer; -use crate::compute::take; -use crate::datatypes::*; -use crate::downcast_dictionary_array; -use crate::error::{ArrowError, Result}; +use crate::ord::{build_compare, DynComparator}; +use arrow_array::cast::*; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::{ArrowNativeType, MutableBuffer}; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; +use arrow_select::take::take; use std::cmp::Ordering; -use TimeUnit::*; /// Sort the `ArrayRef` using `SortOptions`. /// @@ -41,18 +42,17 @@ use TimeUnit::*; /// # Example /// ```rust /// # use std::sync::Arc; -/// # use arrow::array::{Int32Array, ArrayRef}; -/// # use arrow::error::Result; -/// # use arrow::compute::kernels::sort::sort; -/// # fn main() -> Result<()> { +/// # use arrow_array::{Int32Array, ArrayRef}; +/// # use arrow_ord::sort::sort; /// let array: ArrayRef = Arc::new(Int32Array::from(vec![5, 4, 3, 2, 1])); /// let sorted_array = sort(&array, None).unwrap(); /// let sorted_array = sorted_array.as_any().downcast_ref::().unwrap(); /// assert_eq!(sorted_array, &Int32Array::from(vec![1, 2, 3, 4, 5])); -/// # Ok(()) -/// # } /// ``` -pub fn sort(values: &ArrayRef, options: Option) -> Result { +pub fn sort( + values: &ArrayRef, + options: Option, +) -> Result { let indices = sort_to_indices(values, options, None)?; take(values.as_ref(), &indices, None) } @@ -69,10 +69,8 @@ pub fn sort(values: &ArrayRef, options: Option) -> Result /// # Example /// ```rust /// # use std::sync::Arc; -/// # use arrow::array::{Int32Array, ArrayRef}; -/// # use arrow::error::Result; -/// # use arrow::compute::kernels::sort::{sort_limit, SortOptions}; -/// # fn main() -> Result<()> { +/// # use arrow_array::{Int32Array, ArrayRef}; +/// # use arrow_ord::sort::{sort_limit, SortOptions}; /// let array: ArrayRef = Arc::new(Int32Array::from(vec![5, 4, 3, 2, 1])); /// /// // Find the the top 2 items @@ -88,14 +86,12 @@ pub fn sort(values: &ArrayRef, options: Option) -> Result /// let sorted_array = sort_limit(&array, options, Some(2)).unwrap(); /// let sorted_array = sorted_array.as_any().downcast_ref::().unwrap(); /// assert_eq!(sorted_array, &Int32Array::from(vec![5, 4])); -/// # Ok(()) -/// # } /// ``` pub fn sort_limit( values: &ArrayRef, options: Option, limit: Option, -) -> Result { +) -> Result { let indices = sort_to_indices(values, options, limit)?; take(values.as_ref(), &indices, None) } @@ -139,7 +135,7 @@ pub fn sort_to_indices( values: &ArrayRef, options: Option, limit: Option, -) -> Result { +) -> Result { let options = options.unwrap_or_default(); let (v, n) = partition_validity(values); @@ -198,32 +194,32 @@ pub fn sort_to_indices( DataType::Date64 => { sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Time32(Second) => { + DataType::Time32(TimeUnit::Second) => { sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Time32(Millisecond) => { + DataType::Time32(TimeUnit::Millisecond) => { sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Time64(Microsecond) => { + DataType::Time64(TimeUnit::Microsecond) => { sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Time64(Nanosecond) => { + DataType::Time64(TimeUnit::Nanosecond) => { sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Timestamp(Second, _) => { + DataType::Timestamp(TimeUnit::Second, _) => { sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Timestamp(Millisecond, _) => { + DataType::Timestamp(TimeUnit::Millisecond, _) => { sort_primitive::( values, v, n, cmp, &options, limit, ) } - DataType::Timestamp(Microsecond, _) => { + DataType::Timestamp(TimeUnit::Microsecond, _) => { sort_primitive::( values, v, n, cmp, &options, limit, ) } - DataType::Timestamp(Nanosecond, _) => { + DataType::Timestamp(TimeUnit::Nanosecond, _) => { sort_primitive::( values, v, n, cmp, &options, limit, ) @@ -857,11 +853,12 @@ pub struct SortColumn { /// Example: /// /// ``` -/// use std::convert::From; -/// use std::sync::Arc; -/// use arrow::array::{ArrayRef, StringArray, PrimitiveArray, as_primitive_array}; -/// use arrow::compute::kernels::sort::{SortColumn, SortOptions, lexsort}; -/// use arrow::datatypes::Int64Type; +/// # use std::convert::From; +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, StringArray, PrimitiveArray}; +/// # use arrow_array::types::Int64Type; +/// # use arrow_array::cast::as_primitive_array; +/// # use arrow_ord::sort::{SortColumn, SortOptions, lexsort}; /// /// let sorted_columns = lexsort(&vec![ /// SortColumn { @@ -893,10 +890,13 @@ pub struct SortColumn { /// assert!(sorted_columns[0].is_null(0)); /// ``` /// -/// Note: for multi-column sorts without a limit, using the [row format][crate::row] +/// Note: for multi-column sorts without a limit, using the [row format](https://docs.rs/arrow/latest/arrow/row/) /// may be significantly faster /// -pub fn lexsort(columns: &[SortColumn], limit: Option) -> Result> { +pub fn lexsort( + columns: &[SortColumn], + limit: Option, +) -> Result, ArrowError> { let indices = lexsort_to_indices(columns, limit)?; columns .iter() @@ -907,12 +907,12 @@ pub fn lexsort(columns: &[SortColumn], limit: Option) -> Result, -) -> Result { +) -> Result { if columns.is_empty() { return Err(ArrowError::InvalidArgumentError( "Sort requires at least one column".to_string(), @@ -1018,7 +1018,7 @@ impl LexicographicalComparator<'_> { /// results with two indices. pub(crate) fn try_new( columns: &[SortColumn], - ) -> Result> { + ) -> Result, ArrowError> { let compare_items = columns .iter() .map(|column| { @@ -1032,7 +1032,7 @@ impl LexicographicalComparator<'_> { column.options.unwrap_or_default(), )) }) - .collect::>>()?; + .collect::, ArrowError>>()?; Ok(LexicographicalComparator { compare_items }) } } diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index bb4b2b0a826..ddb47969cf2 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -286,4 +286,119 @@ mod tests { let result = actual.as_any().downcast_ref::().unwrap(); assert_eq!(&expected, result); } + + macro_rules! test_flag_utf8 { + ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { + #[test] + fn $test_name() { + let left = StringArray::from($left); + let right = StringArray::from($right); + let res = $op(&left, &right, None).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + } + }; + ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => { + #[test] + fn $test_name() { + let left = StringArray::from($left); + let right = StringArray::from($right); + let flag = Some(StringArray::from($flag)); + let res = $op(&left, &right, flag.as_ref()).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + } + }; + } + + macro_rules! test_flag_utf8_scalar { + ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { + #[test] + fn $test_name() { + let left = StringArray::from($left); + let res = $op(&left, $right, None).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!( + v, + expected[i], + "unexpected result when comparing {} at position {} to {} ", + left.value(i), + i, + $right + ); + } + } + }; + ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => { + #[test] + fn $test_name() { + let left = StringArray::from($left); + let flag = Some($flag); + let res = $op(&left, $right, flag).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!( + v, + expected[i], + "unexpected result when comparing {} at position {} to {} ", + left.value(i), + i, + $right + ); + } + } + }; + } + + test_flag_utf8!( + test_utf8_array_regexp_is_match, + vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"], + vec!["^ar", "^AR", "ow$", "OW$", "foo", ""], + regexp_is_match_utf8, + vec![true, false, true, false, false, true] + ); + test_flag_utf8!( + test_utf8_array_regexp_is_match_insensitive, + vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"], + vec!["^ar", "^AR", "ow$", "OW$", "foo", ""], + vec!["i"; 6], + regexp_is_match_utf8, + vec![true, true, true, true, false, true] + ); + + test_flag_utf8_scalar!( + test_utf8_array_regexp_is_match_scalar, + vec!["arrow", "ARROW", "parquet", "PARQUET"], + "^ar", + regexp_is_match_utf8_scalar, + vec![true, false, false, false] + ); + test_flag_utf8_scalar!( + test_utf8_array_regexp_is_match_empty_scalar, + vec!["arrow", "ARROW", "parquet", "PARQUET"], + "", + regexp_is_match_utf8_scalar, + vec![true, true, true, true] + ); + test_flag_utf8_scalar!( + test_utf8_array_regexp_is_match_insensitive_scalar, + vec!["arrow", "ARROW", "parquet", "PARQUET"], + "^ar", + "i", + regexp_is_match_utf8_scalar, + vec![true, true, false, false] + ); } diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 17f88c084cb..246e63c8c7b 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -55,6 +55,7 @@ arrow-json = { version = "28.0.0", path = "../arrow-json", optional = true } arrow-schema = { version = "28.0.0", path = "../arrow-schema" } arrow-select = { version = "28.0.0", path = "../arrow-select" } arrow-string = { version = "28.0.0", path = "../arrow-string" } +arrow-ord = { version = "28.0.0", path = "../arrow-ord" } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } num = { version = "0.4", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false, features = ["num-traits"] } @@ -76,7 +77,7 @@ ipc_compression = ["ipc", "arrow-ipc/lz4", "arrow-ipc/zstd"] csv = ["arrow-csv"] ipc = ["arrow-ipc"] json = ["arrow-json"] -simd = ["packed_simd"] +simd = ["packed_simd", "arrow-array/simd", "arrow-ord/simd"] prettyprint = ["comfy-table"] # The test utils feature enables code used in benchmarks and tests but # not the core arrow code itself. Be aware that `rand` must be kept as @@ -92,7 +93,7 @@ force_validate = ["arrow-data/force_validate"] ffi = ["bitflags"] # Enable dyn-comparison of dictionary arrays with other arrays # Note: this does not impact comparison against scalars -dyn_cmp_dict = ["arrow-string/dyn_cmp_dict"] +dyn_cmp_dict = ["arrow-string/dyn_cmp_dict", "arrow-ord/dyn_cmp_dict"] # Enable dyn-arithmetic kernels for dictionary arrays # Note: this does not impact arithmetic with scalars dyn_arith_dict = [] diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index af774de0a26..1a10725df67 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -21,7 +21,6 @@ #[cfg(feature = "ffi")] mod ffi; -mod ord; // --------------------- Array & ArrayData --------------------- pub use arrow_array::array::*; @@ -39,4 +38,4 @@ pub use self::ffi::{export_array_into_raw, make_array_from_raw}; // --------------------- Array's values comparison --------------------- -pub use self::ord::{build_compare, DynComparator}; +pub use arrow_ord::ord::{build_compare, DynComparator}; diff --git a/arrow/src/compute/kernels/mod.rs b/arrow/src/compute/kernels/mod.rs index 29468861f82..837fb73d56d 100644 --- a/arrow/src/compute/kernels/mod.rs +++ b/arrow/src/compute/kernels/mod.rs @@ -22,13 +22,18 @@ pub mod arithmetic; pub mod arity; pub mod bitwise; pub mod boolean; -pub mod comparison; pub mod limit; -pub mod partition; -pub mod sort; pub mod temporal; pub use arrow_cast::cast; pub use arrow_cast::parse as cast_utils; +pub use arrow_ord::{partition, sort}; pub use arrow_select::{concat, filter, interleave, take, window, zip}; pub use arrow_string::{concat_elements, length, regexp, substring}; + +/// Comparison kernels for `Array`s. +pub mod comparison { + pub use arrow_ord::comparison::*; + pub use arrow_string::like::*; + pub use arrow_string::regexp::{regexp_is_match_utf8, regexp_is_match_utf8_scalar}; +} diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index 5d625a051fd..c2524009681 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -22,11 +22,10 @@ //! * [`Field`](crate::datatypes::Field) to describe one field within a schema. //! * [`DataType`](crate::datatypes::DataType) to describe the type of a field. -mod numeric; -pub use numeric::*; - pub use arrow_array::types::*; -pub use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; +pub use arrow_array::{ + ArrowFloatNumericType, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, +}; pub use arrow_buffer::{i256, ArrowNativeType, ToByteSlice}; pub use arrow_data::decimal::*; pub use arrow_schema::{