Skip to content

Commit

Permalink
Combine take_utf8 and take_binary (#2969) (#2970)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Oct 28, 2022
1 parent 7798043 commit dbe518c
Showing 1 changed file with 25 additions and 67 deletions.
92 changes: 25 additions & 67 deletions arrow-select/src/take.rs
Expand Up @@ -17,14 +17,15 @@

//! Defines take kernel for [Array]

use std::{ops::AddAssign, sync::Arc};
use std::sync::Arc;

use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Field};

use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array};
use num::{ToPrimitive, Zero};

/// Take elements by index from [Array], creating a new [Array] from those indexes.
Expand Down Expand Up @@ -140,18 +141,10 @@ where
Ok(Arc::new(array))
}
DataType::Utf8 => {
let values = values
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.unwrap();
Ok(Arc::new(take_string::<i32, _>(values, indices)?))
Ok(Arc::new(take_bytes(as_string_array(values), indices)?))
}
DataType::LargeUtf8 => {
let values = values
.as_any()
.downcast_ref::<GenericStringArray<i64>>()
.unwrap();
Ok(Arc::new(take_string::<i64, _>(values, indices)?))
Ok(Arc::new(take_bytes(as_largestring_array(values), indices)?))
}
DataType::List(_) => {
let values = values
Expand Down Expand Up @@ -209,18 +202,10 @@ where
t => unimplemented!("Take not supported for dictionary type {:?}", t)
}
DataType::Binary => {
let values = values
.as_any()
.downcast_ref::<GenericBinaryArray<i32>>()
.unwrap();
Ok(Arc::new(take_binary(values, indices)?))
Ok(Arc::new(take_bytes(as_generic_binary_array::<i32>(values), indices)?))
}
DataType::LargeBinary => {
let values = values
.as_any()
.downcast_ref::<GenericBinaryArray<i64>>()
.unwrap();
Ok(Arc::new(take_binary(values, indices)?))
Ok(Arc::new(take_bytes(as_generic_binary_array::<i64>(values), indices)?))
}
DataType::FixedSizeBinary(_) => {
let values = values
Expand Down Expand Up @@ -579,23 +564,23 @@ where
}

/// `take` implementation for variable-length byte arrays (string and binary)
fn take_string<OffsetSize, IndexType>(
array: &GenericStringArray<OffsetSize>,
fn take_bytes<T, IndexType>(
array: &GenericByteArray<T>,
indices: &PrimitiveArray<IndexType>,
) -> Result<GenericStringArray<OffsetSize>, ArrowError>
) -> Result<GenericByteArray<T>, ArrowError>
where
OffsetSize: Zero + AddAssign + OffsetSizeTrait,
T: ByteArrayType,
IndexType: ArrowPrimitiveType,
IndexType::Native: ToPrimitive,
{
let data_len = indices.len();

let bytes_offset = (data_len + 1) * std::mem::size_of::<OffsetSize>();
let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);

let offsets = offsets_buffer.typed_data_mut();
let mut values = MutableBuffer::new(0);
let mut length_so_far = OffsetSize::zero();
let mut length_so_far = T::Offset::from_usize(0).unwrap();
offsets[0] = length_so_far;

let nulls;
Expand All @@ -607,8 +592,8 @@ where

let s = array.value(index);

length_so_far += OffsetSize::from_usize(s.len()).unwrap();
values.extend_from_slice(s.as_bytes());
length_so_far += T::Offset::from_usize(s.as_ref().len()).unwrap();
values.extend_from_slice(s.as_ref());
*offset = length_so_far;
}
nulls = None
Expand All @@ -624,10 +609,10 @@ where
})?;

if array.is_valid(index) {
let s = array.value(index);
let s = array.value(index).as_ref();

length_so_far += OffsetSize::from_usize(s.len()).unwrap();
values.extend_from_slice(s.as_bytes());
length_so_far += T::Offset::from_usize(s.len()).unwrap();
values.extend_from_slice(s.as_ref());
} else {
bit_util::unset_bit(null_slice, i);
}
Expand All @@ -642,10 +627,10 @@ where
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;

let s = array.value(index);
let s = array.value(index).as_ref();

length_so_far += OffsetSize::from_usize(s.len()).unwrap();
values.extend_from_slice(s.as_bytes());
length_so_far += T::Offset::from_usize(s.len()).unwrap();
values.extend_from_slice(s);
}
*offset = length_so_far;
}
Expand All @@ -662,10 +647,10 @@ where
})?;

if array.is_valid(index) && indices.is_valid(i) {
let s = array.value(index);
let s = array.value(index).as_ref();

length_so_far += OffsetSize::from_usize(s.len()).unwrap();
values.extend_from_slice(s.as_bytes());
length_so_far += T::Offset::from_usize(s.len()).unwrap();
values.extend_from_slice(s);
} else {
// set null bit
bit_util::unset_bit(null_slice, i);
Expand All @@ -676,15 +661,15 @@ where
nulls = Some(null_buf.into())
}

let array_data = ArrayData::builder(GenericStringArray::<OffsetSize>::DATA_TYPE)
let array_data = ArrayData::builder(T::DATA_TYPE)
.len(data_len)
.add_buffer(offsets_buffer.into())
.add_buffer(values.into())
.null_bit_buffer(nulls);

let array_data = unsafe { array_data.build_unchecked() };

Ok(GenericStringArray::<OffsetSize>::from(array_data))
Ok(GenericByteArray::from(array_data))
}

/// `take` implementation for list arrays
Expand Down Expand Up @@ -781,33 +766,6 @@ where
Ok(FixedSizeListArray::from(list_data))
}

/// `take` implementation for variable-length binary arrays.
///
/// Produces a new [`GenericBinaryArray`] with one slot per entry of `indices`.
/// Slot `i` is null when `indices[i]` is null or when the value it refers to is
/// null; otherwise it holds a copy of `values[indices[i]]`.
///
/// # Errors
///
/// Returns the [`ArrowError`] produced by `maybe_usize` for the first non-null
/// index it rejects (presumably one that cannot be represented as `usize`).
fn take_binary<IndexType, OffsetType>(
    values: &GenericBinaryArray<OffsetType>,
    indices: &PrimitiveArray<IndexType>,
) -> Result<GenericBinaryArray<OffsetType>, ArrowError>
where
    OffsetType: OffsetSizeTrait,
    IndexType: ArrowPrimitiveType,
    IndexType::Native: ToPrimitive,
{
    let data_ref = values.data_ref();

    // Iterate with `indices.iter()` rather than `indices.values()` so the
    // validity bitmap of `indices` is honoured: the backing value of a null
    // index is arbitrary and must never be used to look up `values`.
    let taken = indices
        .iter()
        .map(|maybe_idx| match maybe_idx {
            // Null index -> null output slot.
            None => Ok(None),
            Some(raw_idx) => {
                let idx = maybe_usize::<IndexType::Native>(raw_idx)?;
                if data_ref.is_valid(idx) {
                    Ok(Some(values.value(idx)))
                } else {
                    // Valid index pointing at a null value -> null output slot.
                    Ok(None)
                }
            }
        })
        .collect::<Result<Vec<_>, ArrowError>>()?;

    Ok(taken
        .into_iter()
        .collect::<GenericBinaryArray<OffsetType>>())
}

fn take_fixed_size_binary<IndexType>(
values: &FixedSizeBinaryArray,
indices: &PrimitiveArray<IndexType>,
Expand Down

0 comments on commit dbe518c

Please sign in to comment.