From 1417f8897c069a674aa52c046da3e661a20ca1f4 Mon Sep 17 00:00:00 2001
From: Ritchie Vink
Date: Sat, 10 Dec 2022 11:30:59 +0100
Subject: [PATCH] Improved performance of check_indexes (#1313)

---
 src/array/dictionary/mod.rs | 53 ++++++++++++++++++++++++++++++-------
 src/array/specification.rs  | 25 +++++++++++++++++
 2 files changed, 68 insertions(+), 10 deletions(-)

diff --git a/src/array/dictionary/mod.rs b/src/array/dictionary/mod.rs
index 63cdfc7fa7d..8dc1bbd7c9a 100644
--- a/src/array/dictionary/mod.rs
+++ b/src/array/dictionary/mod.rs
@@ -16,6 +16,7 @@ mod ffi;
 pub(super) mod fmt;
 mod iterator;
 mod mutable;
+use crate::array::specification::check_indexes_unchecked;
 pub use iterator::*;
 pub use mutable::*;
 
@@ -23,7 +24,11 @@ use super::{new_empty_array, primitive::PrimitiveArray, Array};
 use super::{new_null_array, specification::check_indexes};
 
 /// Trait denoting [`NativeType`]s that can be used as keys of a dictionary.
-pub trait DictionaryKey: NativeType + TryInto<usize> + TryFrom<usize> {
+/// # Safety
+///
+/// Any implementation of this trait must ensure that `always_fits_usize` only
+/// returns `true` if all values succeed on `value::try_into::<usize>().unwrap()`.
+pub unsafe trait DictionaryKey: NativeType + TryInto<usize> + TryFrom<usize> {
     /// The corresponding [`IntegerType`] of this key
     const KEY_TYPE: IntegerType;
 
@@ -37,31 +42,53 @@ pub trait DictionaryKey: NativeType + TryInto<usize> + TryFrom<usize> {
             Err(_) => unreachable_unchecked(),
         }
     }
+
+    /// If the key type always can be converted to `usize`.
+    fn always_fits_usize() -> bool {
+        false // conservative default; the signed key impls below do not override it
+    }
 }
 
-impl DictionaryKey for i8 {
+unsafe impl DictionaryKey for i8 {
     const KEY_TYPE: IntegerType = IntegerType::Int8;
 }
-impl DictionaryKey for i16 {
+unsafe impl DictionaryKey for i16 {
     const KEY_TYPE: IntegerType = IntegerType::Int16;
 }
-impl DictionaryKey for i32 {
+unsafe impl DictionaryKey for i32 {
     const KEY_TYPE: IntegerType = IntegerType::Int32;
 }
-impl DictionaryKey for i64 {
+unsafe impl DictionaryKey for i64 {
     const KEY_TYPE: IntegerType = IntegerType::Int64;
 }
-impl DictionaryKey for u8 {
+unsafe impl DictionaryKey for u8 {
     const KEY_TYPE: IntegerType = IntegerType::UInt8;
+
+    fn always_fits_usize() -> bool {
+        true
+    }
 }
-impl DictionaryKey for u16 {
+unsafe impl DictionaryKey for u16 {
     const KEY_TYPE: IntegerType = IntegerType::UInt16;
+
+    fn always_fits_usize() -> bool {
+        true
+    }
 }
-impl DictionaryKey for u32 {
+unsafe impl DictionaryKey for u32 {
     const KEY_TYPE: IntegerType = IntegerType::UInt32;
+
+    fn always_fits_usize() -> bool {
+        true
+    }
 }
-impl DictionaryKey for u64 {
+unsafe impl DictionaryKey for u64 {
     const KEY_TYPE: IntegerType = IntegerType::UInt64;
+
+    #[cfg(target_pointer_width = "64")] // on other pointer widths the default (`false`) applies
+    fn always_fits_usize() -> bool {
+        true
+    }
 }
 
 /// An [`Array`] whose values are stored as indices.
This [`Array`] is useful when the cardinality of
@@ -120,7 +147,13 @@ impl DictionaryArray {
         check_data_type(K::KEY_TYPE, &data_type, values.data_type())?;
 
         if keys.null_count() != keys.len() {
-            check_indexes(keys.values(), values.len())?;
+            if K::always_fits_usize() {
+                // SAFETY: `K::always_fits_usize()` returned `true`, so conversion
+                // of any key to `usize` always succeeds
+                unsafe { check_indexes_unchecked(keys.values(), values.len()) }?;
+            } else {
+                check_indexes(keys.values(), values.len())?;
+            }
         }
 
         Ok(Self {
diff --git a/src/array/specification.rs b/src/array/specification.rs
index 021cbd5c80c..274321be93f 100644
--- a/src/array/specification.rs
+++ b/src/array/specification.rs
@@ -1,3 +1,4 @@
+use crate::array::DictionaryKey;
 use crate::error::{Error, Result};
 use crate::offset::{Offset, Offsets, OffsetsBuffer};
 
@@ -107,6 +108,30 @@ pub(crate) fn try_check_utf8>(
     }
 }
 
+/// Checks that every dictionary key is `< len`, without checking the `usize` conversion of each key.
+/// # Safety
+/// The caller must ensure that `K::as_usize` always succeeds.
+pub(crate) unsafe fn check_indexes_unchecked<K: DictionaryKey>(
+    keys: &[K],
+    len: usize,
+) -> Result<()> {
+    let mut invalid = false;
+
+    // no early exit, so that this loop is auto-vectorized
+    keys.iter().for_each(|k| {
+        if k.as_usize() >= len { // valid indices are `0..len`; a key equal to `len` is out of bounds
+            invalid = true;
+        }
+    });
+
+    if invalid {
+        let key = keys.iter().map(|k| k.as_usize()).max().unwrap(); // non-empty: some key set `invalid`
+        Err(Error::oos(format!("One of the dictionary keys is {} but it must be < than the length of the dictionary values, which is {}", key, len)))
+    } else {
+        Ok(())
+    }
+}
+
 pub fn check_indexes<K>(keys: &[K], len: usize) -> Result<()>
 where
     K: std::fmt::Debug + Copy + TryInto<usize>,