Skip to content

Commit

Permalink
Specialize filter kernel for binary arrays (#2969) (#2971)
Browse files (browse the repository at this point in the history)
* Generalize filter byte array (#2969)

* Fix doc

* Update comment
  • Loading branch information
tustvold committed Nov 1, 2022
1 parent c7f97c2 commit 62e878e
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions arrow-select/src/filter.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,12 +17,11 @@

//! Defines filter kernels

use std::ops::AddAssign;
use std::sync::Arc;

use num::Zero;

use arrow_array::builder::BooleanBufferBuilder;
use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array};
use arrow_array::types::ByteArrayType;
use arrow_array::*;
use arrow_buffer::bit_util;
use arrow_buffer::{buffer::buffer_bin_and, Buffer, MutableBuffer};
Expand Down Expand Up @@ -355,18 +354,16 @@ fn filter_array(
Ok(Arc::new(filter_boolean(values, predicate)))
}
DataType::Utf8 => {
let values = values
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.unwrap();
Ok(Arc::new(filter_string::<i32>(values, predicate)))
Ok(Arc::new(filter_bytes(as_string_array(values), predicate)))
}
DataType::LargeUtf8 => {
let values = values
.as_any()
.downcast_ref::<GenericStringArray<i64>>()
.unwrap();
Ok(Arc::new(filter_string::<i64>(values, predicate)))
Ok(Arc::new(filter_bytes(as_largestring_array(values), predicate)))
}
DataType::Binary => {
Ok(Arc::new(filter_bytes(as_generic_binary_array::<i32>(values), predicate)))
}
DataType::LargeBinary => {
Ok(Arc::new(filter_bytes(as_generic_binary_array::<i64>(values), predicate)))
}
DataType::Dictionary(_, _) => downcast_dictionary_array! {
values => Ok(Arc::new(filter_dict(values, predicate))),
Expand Down Expand Up @@ -545,27 +542,30 @@ where
PrimitiveArray::from(data)
}

/// [`FilterString`] is created from a source [`GenericStringArray`] and can be
/// used to build a new [`GenericStringArray`] by copying values from the source
/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
/// used to build a new [`GenericByteArray`] by copying values from the source
///
/// TODO(raphael): Could this be used for the take kernel as well?
struct FilterString<'a, OffsetSize> {
struct FilterBytes<'a, OffsetSize> {
src_offsets: &'a [OffsetSize],
src_values: &'a [u8],
dst_offsets: MutableBuffer,
dst_values: MutableBuffer,
cur_offset: OffsetSize,
}

impl<'a, OffsetSize> FilterString<'a, OffsetSize>
impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
where
OffsetSize: Zero + AddAssign + OffsetSizeTrait,
OffsetSize: OffsetSizeTrait,
{
fn new(capacity: usize, array: &'a GenericStringArray<OffsetSize>) -> Self {
fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
where
T: ByteArrayType<Offset = OffsetSize>,
{
let num_offsets_bytes = (capacity + 1) * std::mem::size_of::<OffsetSize>();
let mut dst_offsets = MutableBuffer::new(num_offsets_bytes);
let dst_values = MutableBuffer::new(0);
let cur_offset = OffsetSize::zero();
let cur_offset = OffsetSize::from_usize(0).unwrap();
dst_offsets.push(cur_offset);

Self {
Expand Down Expand Up @@ -622,21 +622,21 @@ where
}
}

/// `filter` implementation for string arrays
/// `filter` implementation for byte arrays
///
/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
/// data copied across. This allows handling the null mask separately from the data.
fn filter_string<OffsetSize>(
array: &GenericStringArray<OffsetSize>,
fn filter_bytes<T>(
array: &GenericByteArray<T>,
predicate: &FilterPredicate,
) -> GenericStringArray<OffsetSize>
) -> GenericByteArray<T>
where
OffsetSize: Zero + AddAssign + OffsetSizeTrait,
T: ByteArrayType,
{
let data = array.data();
assert_eq!(data.buffers().len(), 2);
assert_eq!(data.child_data().len(), 0);
let mut filter = FilterString::new(predicate.count, array);
let mut filter = FilterBytes::new(predicate.count, array);

match &predicate.strategy {
IterationStrategy::SlicesIterator => {
Expand All @@ -650,7 +650,7 @@ where
IterationStrategy::All | IterationStrategy::None => unreachable!(),
}

let mut builder = ArrayDataBuilder::new(data.data_type().clone())
let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
.len(predicate.count)
.add_buffer(filter.dst_offsets.into())
.add_buffer(filter.dst_values.into());
Expand All @@ -660,7 +660,7 @@ where
}

let data = unsafe { builder.build_unchecked() };
GenericStringArray::from(data)
GenericByteArray::from(data)
}

/// `filter` implementation for dictionaries
Expand Down

0 comments on commit 62e878e

Please sign in to comment.