Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Row decode cleanups #3180

Merged
merged 2 commits into from Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 24 additions & 41 deletions arrow/src/row/fixed.rs
Expand Up @@ -19,8 +19,9 @@ use crate::array::PrimitiveArray;
use crate::compute::SortOptions;
use crate::datatypes::ArrowPrimitiveType;
use crate::row::{null_sentinel, Rows};
use arrow_array::builder::BufferBuilder;
use arrow_array::BooleanArray;
use arrow_buffer::{bit_util, i256, MutableBuffer, ToByteSlice};
use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::DataType;
use half::f16;
Expand Down Expand Up @@ -266,61 +267,43 @@ pub fn decode_bool(rows: &mut [&[u8]], options: SortOptions) -> BooleanArray {
unsafe { BooleanArray::from(builder.build_unchecked()) }
}

fn decode_nulls(rows: &[&[u8]]) -> (usize, Buffer) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling the null mask separately has no discernible impact on performance, it is possible it might actually improve it marginally, it reduces code complexity significantly, and will help reuse this logic for the fixed length binary array (PR to follow)

let mut null_count = 0;
let buffer = MutableBuffer::collect_bool(rows.len(), |idx| {
let valid = rows[idx][0] == 1;
null_count += !valid as usize;
valid
})
.into();
(null_count, buffer)
}

/// Decodes a `ArrayData` from rows based on the provided `FixedLengthEncoding` `T`
///
/// # Safety
///
/// `data_type` must be appropriate native type for `T`
unsafe fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
unsafe fn decode_fixed<T: FixedLengthEncoding + ArrowNativeType>(
rows: &mut [&[u8]],
data_type: DataType,
options: SortOptions,
) -> ArrayData {
let len = rows.len();

let mut null_count = 0;
let mut nulls = MutableBuffer::new(bit_util::ceil(len, 64) * 8);
let mut values = MutableBuffer::new(std::mem::size_of::<T>() * len);
let mut values = BufferBuilder::<T>::new(len);
let (null_count, nulls) = decode_nulls(rows);

let chunks = len / 64;
let remainder = len % 64;
for chunk in 0..chunks {
let mut null_packed = 0;

for bit_idx in 0..64 {
let i = split_off(&mut rows[bit_idx + chunk * 64], T::ENCODED_LEN);
let null = i[0] == 1;
null_count += !null as usize;
null_packed |= (null as u64) << bit_idx;

let value = T::Encoded::from_slice(&i[1..], options.descending);
values.push(T::decode(value));
}

nulls.push(null_packed);
}

if remainder != 0 {
let mut null_packed = 0;

for bit_idx in 0..remainder {
let i = split_off(&mut rows[bit_idx + chunks * 64], T::ENCODED_LEN);
let null = i[0] == 1;
null_count += !null as usize;
null_packed |= (null as u64) << bit_idx;

let value = T::Encoded::from_slice(&i[1..], options.descending);
values.push(T::decode(value));
}

nulls.push(null_packed);
for row in rows {
let i = split_off(row, T::ENCODED_LEN);
let value = T::Encoded::from_slice(&i[1..], options.descending);
values.append(T::decode(value));
}

let builder = ArrayDataBuilder::new(data_type)
.len(rows.len())
.len(len)
.null_count(null_count)
.add_buffer(values.into())
.null_bit_buffer(Some(nulls.into()));
.add_buffer(values.finish())
.null_bit_buffer(Some(nulls));

// SAFETY: Buffers correct length
builder.build_unchecked()
Expand All @@ -333,7 +316,7 @@ pub fn decode_primitive<T: ArrowPrimitiveType>(
options: SortOptions,
) -> PrimitiveArray<T>
where
T::Native: FixedLengthEncoding + ToByteSlice,
T::Native: FixedLengthEncoding,
{
assert_eq!(
std::mem::discriminant(&T::DATA_TYPE),
Expand Down
71 changes: 15 additions & 56 deletions arrow/src/row/mod.rs
Expand Up @@ -908,11 +908,22 @@ fn encode_column(
}

macro_rules! decode_primitive_helper {
($t:ty, $rows: ident, $data_type:ident, $options:ident) => {
($t:ty, $rows:ident, $data_type:ident, $options:ident) => {
Arc::new(decode_primitive::<$t>($rows, $data_type, $options))
};
}

macro_rules! decode_dictionary_helper {
($t:ty, $interner:ident, $v:ident, $options:ident, $rows:ident) => {
Arc::new(decode_dictionary::<$t>(
$interner.unwrap(),
$v.as_ref(),
$options,
$rows,
)?)
};
}

/// Decodes a the provided `field` from `rows`
///
/// # Safety
Expand All @@ -934,61 +945,9 @@ unsafe fn decode_column(
DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options, validate_utf8)),
DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options, validate_utf8)),
DataType::Dictionary(k, v) => match k.as_ref() {
DataType::Int8 => Arc::new(decode_dictionary::<Int8Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::Int16 => Arc::new(decode_dictionary::<Int16Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::Int32 => Arc::new(decode_dictionary::<Int32Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::Int64 => Arc::new(decode_dictionary::<Int64Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::UInt8 => Arc::new(decode_dictionary::<UInt8Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::UInt16 => Arc::new(decode_dictionary::<UInt16Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::UInt32 => Arc::new(decode_dictionary::<UInt32Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
DataType::UInt64 => Arc::new(decode_dictionary::<UInt64Type>(
interner.unwrap(),
v.as_ref(),
options,
rows,
)?),
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"{} is not a valid dictionary key type",
field.data_type
)));
}
DataType::Dictionary(k, v) => downcast_integer! {
k.as_ref() => (decode_dictionary_helper, interner, v, options, rows),
_ => unreachable!()
},
_ => {
return Err(ArrowError::NotYetImplemented(format!(
Expand Down