diff --git a/arrow/src/row/fixed.rs b/arrow/src/row/fixed.rs index 76bf358e7e0..0bad033d9bd 100644 --- a/arrow/src/row/fixed.rs +++ b/arrow/src/row/fixed.rs @@ -19,8 +19,9 @@ use crate::array::PrimitiveArray; use crate::compute::SortOptions; use crate::datatypes::ArrowPrimitiveType; use crate::row::{null_sentinel, Rows}; +use arrow_array::builder::BufferBuilder; use arrow_array::BooleanArray; -use arrow_buffer::{bit_util, i256, MutableBuffer, ToByteSlice}; +use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::DataType; use half::f16; @@ -266,61 +267,43 @@ pub fn decode_bool(rows: &mut [&[u8]], options: SortOptions) -> BooleanArray { unsafe { BooleanArray::from(builder.build_unchecked()) } } +fn decode_nulls(rows: &[&[u8]]) -> (usize, Buffer) { + let mut null_count = 0; + let buffer = MutableBuffer::collect_bool(rows.len(), |idx| { + let valid = rows[idx][0] == 1; + null_count += !valid as usize; + valid + }) + .into(); + (null_count, buffer) +} + /// Decodes a `ArrayData` from rows based on the provided `FixedLengthEncoding` `T` /// /// # Safety /// /// `data_type` must be appropriate native type for `T` -unsafe fn decode_fixed( +unsafe fn decode_fixed( rows: &mut [&[u8]], data_type: DataType, options: SortOptions, ) -> ArrayData { let len = rows.len(); - let mut null_count = 0; - let mut nulls = MutableBuffer::new(bit_util::ceil(len, 64) * 8); - let mut values = MutableBuffer::new(std::mem::size_of::() * len); + let mut values = BufferBuilder::::new(len); + let (null_count, nulls) = decode_nulls(rows); - let chunks = len / 64; - let remainder = len % 64; - for chunk in 0..chunks { - let mut null_packed = 0; - - for bit_idx in 0..64 { - let i = split_off(&mut rows[bit_idx + chunk * 64], T::ENCODED_LEN); - let null = i[0] == 1; - null_count += !null as usize; - null_packed |= (null as u64) << bit_idx; - - let value = T::Encoded::from_slice(&i[1..], options.descending); - values.push(T::decode(value)); - } - - nulls.push(null_packed); - } - - if remainder != 0 { - let mut null_packed = 0; - - for bit_idx in 0..remainder { - let i = split_off(&mut rows[bit_idx + chunks * 64], T::ENCODED_LEN); - let null = i[0] == 1; - null_count += !null as usize; - null_packed |= (null as u64) << bit_idx; - - let value = T::Encoded::from_slice(&i[1..], options.descending); - values.push(T::decode(value)); - } - - nulls.push(null_packed); + for row in rows { + let i = split_off(row, T::ENCODED_LEN); + let value = T::Encoded::from_slice(&i[1..], options.descending); + values.append(T::decode(value)); } let builder = ArrayDataBuilder::new(data_type) - .len(rows.len()) + .len(len) .null_count(null_count) - .add_buffer(values.into()) - .null_bit_buffer(Some(nulls.into())); + .add_buffer(values.finish()) + .null_bit_buffer(Some(nulls)); // SAFETY: Buffers correct length builder.build_unchecked() @@ -333,7 +316,7 @@ pub fn decode_primitive( options: SortOptions, ) -> PrimitiveArray where - T::Native: FixedLengthEncoding + ToByteSlice, + T::Native: FixedLengthEncoding, { assert_eq!( std::mem::discriminant(&T::DATA_TYPE), diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs index 6ce9f2b12c2..4f48b46cb2a 100644 --- a/arrow/src/row/mod.rs +++ b/arrow/src/row/mod.rs @@ -908,11 +908,22 @@ fn encode_column( } macro_rules! decode_primitive_helper { - ($t:ty, $rows: ident, $data_type:ident, $options:ident) => { + ($t:ty, $rows:ident, $data_type:ident, $options:ident) => { Arc::new(decode_primitive::<$t>($rows, $data_type, $options)) }; } +macro_rules! decode_dictionary_helper { + ($t:ty, $interner:ident, $v:ident, $options:ident, $rows:ident) => { + Arc::new(decode_dictionary::<$t>( + $interner.unwrap(), + $v.as_ref(), + $options, + $rows, + )?) + }; +} + /// Decodes a the provided `field` from `rows` /// /// # Safety @@ -934,61 +945,9 @@ unsafe fn decode_column( DataType::LargeBinary => Arc::new(decode_binary::(rows, options)), DataType::Utf8 => Arc::new(decode_string::(rows, options, validate_utf8)), DataType::LargeUtf8 => Arc::new(decode_string::(rows, options, validate_utf8)), - DataType::Dictionary(k, v) => match k.as_ref() { - DataType::Int8 => Arc::new(decode_dictionary::( - interner.unwrap(), - v.as_ref(), - options, - rows, - )?), - DataType::Int16 => Arc::new(decode_dictionary::( - interner.unwrap(), - v.as_ref(), - options, - rows, - )?), - DataType::Int32 => Arc::new(decode_dictionary::( - interner.unwrap(), - v.as_ref(), - options, - rows, - )?), - DataType::Int64 => Arc::new(decode_dictionary::( - interner.unwrap(), - v.as_ref(), - options, - rows, - )?), - DataType::UInt8 => Arc::new(decode_dictionary::( - interner.unwrap(), - v.as_ref(), - options, - rows, - )?), - DataType::UInt16 => Arc::new(decode_dictionary::( - interner.unwrap(), - v.as_ref(), - options, - rows, - )?), - DataType::UInt32 => Arc::new(decode_dictionary::( - interner.unwrap(), - v.as_ref(), - options, - rows, - )?), - DataType::UInt64 => Arc::new(decode_dictionary::( - interner.unwrap(), - v.as_ref(), - options, - rows, - )?), - _ => { - return Err(ArrowError::InvalidArgumentError(format!( - "{} is not a valid dictionary key type", - field.data_type - ))); - } + DataType::Dictionary(k, v) => downcast_integer! { + k.as_ref() => (decode_dictionary_helper, interner, v, options, rows), + _ => unreachable!() }, _ => { return Err(ArrowError::NotYetImplemented(format!(