diff --git a/arrow/benches/row_format.rs b/arrow/benches/row_format.rs
index ec872c12706..ff505781a0a 100644
--- a/arrow/benches/row_format.rs
+++ b/arrow/benches/row_format.rs
@@ -20,161 +20,95 @@ extern crate criterion;
 extern crate core;
 
 use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Int64Type, UInt64Type};
+use arrow::datatypes::{Int64Type, UInt64Type};
 use arrow::row::{RowConverter, SortField};
 use arrow::util::bench_util::{
     create_primitive_array, create_string_array_with_len, create_string_dict_array,
 };
 use arrow_array::types::Int32Type;
+use arrow_array::Array;
 use criterion::{black_box, Criterion};
 use std::sync::Arc;
 
-fn row_bench(c: &mut Criterion) {
-    let cols = vec![Arc::new(create_primitive_array::<UInt64Type>(4096, 0.)) as ArrayRef];
+fn do_bench(c: &mut Criterion, name: &str, cols: Vec<ArrayRef>) {
+    let fields: Vec<_> = cols
+        .iter()
+        .map(|x| SortField::new(x.data_type().clone()))
+        .collect();
 
-    c.bench_function("row_batch 4096 u64(0)", |b| {
+    c.bench_function(&format!("convert_columns {}", name), |b| {
         b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::UInt64)]);
-            black_box(converter.convert_columns(&cols))
+            let mut converter = RowConverter::new(fields.clone());
+            black_box(converter.convert_columns(&cols).unwrap())
         });
     });
 
-    let cols = vec![Arc::new(create_primitive_array::<Int64Type>(4096, 0.)) as ArrayRef];
+    let mut converter = RowConverter::new(fields);
+    let rows = converter.convert_columns(&cols).unwrap();
 
-    c.bench_function("row_batch 4096 i64(0)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Int64)]);
-            black_box(converter.convert_columns(&cols))
-        });
+    c.bench_function(&format!("convert_rows {}", name), |b| {
+        b.iter(|| black_box(converter.convert_rows(&rows).unwrap()));
     });
+}
+
+fn row_bench(c: &mut Criterion) {
+    let cols = vec![Arc::new(create_primitive_array::<UInt64Type>(4096, 0.)) as ArrayRef];
+    do_bench(c, "4096 u64(0)", cols);
+
+    let cols = vec![Arc::new(create_primitive_array::<Int64Type>(4096, 0.)) as ArrayRef];
+    do_bench(c, "4096 i64(0)", cols);
 
     let cols =
         vec![Arc::new(create_string_array_with_len::<i32>(4096, 0., 10)) as ArrayRef];
-
-    c.bench_function("row_batch 4096 string(10, 0)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-            black_box(converter.convert_columns(&cols))
-        });
-    });
+    do_bench(c, "4096 string(10, 0)", cols);
 
     let cols =
         vec![Arc::new(create_string_array_with_len::<i32>(4096, 0., 30)) as ArrayRef];
-
-    c.bench_function("row_batch 4096 string(30, 0)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-            black_box(converter.convert_columns(&cols))
-        });
-    });
+    do_bench(c, "4096 string(30, 0)", cols);
 
     let cols =
         vec![Arc::new(create_string_array_with_len::<i32>(4096, 0., 100)) as ArrayRef];
-
-    c.bench_function("row_batch 4096 string(100, 0)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-            black_box(converter.convert_columns(&cols))
-        });
-    });
+    do_bench(c, "4096 string(100, 0)", cols);
 
     let cols =
         vec![Arc::new(create_string_array_with_len::<i32>(4096, 0.5, 100)) as ArrayRef];
-
-    c.bench_function("row_batch 4096 string(100, 0.5)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-            black_box(converter.convert_columns(&cols))
-        });
-    });
+    do_bench(c, "4096 string(100, 0.5)", cols);
 
     let cols =
         vec![Arc::new(create_string_dict_array::<Int32Type>(4096, 0., 10)) as ArrayRef];
-
-    c.bench_function("row_batch 4096 string_dictionary(10, 0)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-            black_box(converter.convert_columns(&cols))
-        });
-    });
+    do_bench(c, "4096 string_dictionary(10, 0)", cols);
 
     let cols =
         vec![Arc::new(create_string_dict_array::<Int32Type>(4096, 0., 30)) as ArrayRef];
-
-    c.bench_function("row_batch 4096 string_dictionary(30, 0)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-            black_box(converter.convert_columns(&cols))
-        });
-    });
+    do_bench(c, "4096 string_dictionary(30, 0)", cols);
 
     let cols =
         vec![Arc::new(create_string_dict_array::<Int32Type>(4096, 0., 100)) as ArrayRef];
-
-    c.bench_function("row_batch 4096 string_dictionary(100, 0)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-            black_box(converter.convert_columns(&cols))
-        });
-    });
+    do_bench(c, "4096 string_dictionary(100, 0)", cols);
 
     let cols =
         vec![Arc::new(create_string_dict_array::<Int32Type>(4096, 0.5, 100)) as ArrayRef];
+    do_bench(c, "4096 string_dictionary(100, 0.5)", cols);
 
-    c.bench_function("row_batch 4096 string_dictionary(100, 0.5)", |b| {
-        b.iter(|| {
-            let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-            black_box(converter.convert_columns(&cols))
-        });
-    });
-
-    let cols = [
+    let cols = vec![
         Arc::new(create_string_array_with_len::<i32>(4096, 0.5, 20)) as ArrayRef,
         Arc::new(create_string_array_with_len::<i32>(4096, 0., 30)) as ArrayRef,
         Arc::new(create_string_array_with_len::<i32>(4096, 0., 100)) as ArrayRef,
         Arc::new(create_primitive_array::<Int64Type>(4096, 0.)) as ArrayRef,
     ];
-
-    let fields = [
-        SortField::new(DataType::Utf8),
-        SortField::new(DataType::Utf8),
-        SortField::new(DataType::Utf8),
-        SortField::new(DataType::Int64),
-    ];
-
-    c.bench_function(
-        "row_batch 4096 string(20, 0.5), string(30, 0), string(100, 0), i64(0)",
-        |b| {
-            b.iter(|| {
-                let mut converter = RowConverter::new(fields.to_vec());
-                black_box(converter.convert_columns(&cols))
-            });
-        },
+    do_bench(
+        c,
+        "4096 string(20, 0.5), string(30, 0), string(100, 0), i64(0)",
+        cols,
     );
 
-    let cols = [
+    let cols = vec![
        Arc::new(create_string_dict_array::<Int32Type>(4096, 0.5, 20)) as ArrayRef,
        Arc::new(create_string_dict_array::<Int32Type>(4096, 0., 30)) as ArrayRef,
        Arc::new(create_string_dict_array::<Int32Type>(4096, 0., 100)) as ArrayRef,
        Arc::new(create_primitive_array::<Int64Type>(4096, 0.)) as ArrayRef,
     ];
-
-    let fields = [
-        SortField::new(DataType::Utf8),
-        SortField::new(DataType::Utf8),
-        SortField::new(DataType::Utf8),
-        SortField::new(DataType::Int64),
-    ];
-
-    c.bench_function(
-        "row_batch 4096 string_dictionary(20, 0.5), string_dictionary(30, 0), string_dictionary(100, 0), i64(0)",
-        |b| {
-            b.iter(|| {
-                let mut converter = RowConverter::new(fields.to_vec());
-                black_box(converter.convert_columns(&cols))
-            });
-        },
-    );
+    do_bench(c, "4096 string_dictionary(20, 0.5), string_dictionary(30, 0), string_dictionary(100, 0), i64(0)", cols);
 }
 
 criterion_group!(benches, row_bench);
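Editor's note: a minimal round-trip sketch of the API this benchmark now exercises. It is an illustration, not part of the patch, and assumes an `arrow` version containing this change; the column contents are arbitrary.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, StringArray};
use arrow::datatypes::DataType;
use arrow::row::{RowConverter, SortField};

fn main() {
    let cols = vec![
        Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])) as ArrayRef,
        Arc::new(Int32Array::from(vec![3, 1, 2])) as ArrayRef,
    ];

    let mut converter = RowConverter::new(vec![
        SortField::new(DataType::Utf8),
        SortField::new(DataType::Int32),
    ]);

    // Encode the columns into the comparable row format
    let rows = converter.convert_columns(&cols).unwrap();

    // Decode the rows back into columns; this is the new code path
    // benchmarked by `convert_rows` above
    let back = converter.convert_rows(&rows).unwrap();
    assert_eq!(&back[0], &cols[0]);
    assert_eq!(&back[1], &cols[1]);
}
```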
diff --git a/arrow/src/row/dictionary.rs b/arrow/src/row/dictionary.rs
new file mode 100644
index 00000000000..4a048fbce86
--- /dev/null
+++ b/arrow/src/row/dictionary.rs
@@ -0,0 +1,337 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::compute::SortOptions;
+use crate::row::fixed::{FixedLengthEncoding, FromSlice, RawDecimal};
+use crate::row::interner::{Interned, OrderPreservingInterner};
+use crate::row::{null_sentinel, Rows};
+use arrow_array::builder::*;
+use arrow_array::cast::*;
+use arrow_array::types::*;
+use arrow_array::*;
+use arrow_buffer::{ArrowNativeType, MutableBuffer, ToByteSlice};
+use arrow_data::{ArrayData, ArrayDataBuilder};
+use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit};
+use std::collections::hash_map::Entry;
+use std::collections::HashMap;
+
+/// Computes the dictionary mapping for the given dictionary values
+pub fn compute_dictionary_mapping(
+    interner: &mut OrderPreservingInterner,
+    values: &ArrayRef,
+) -> Result<Vec<Option<Interned>>, ArrowError> {
+    Ok(downcast_primitive_array! {
+        values => interner
+            .intern(values.iter().map(|x| x.map(|x| x.encode()))),
+        DataType::Binary => {
+            let iter = as_generic_binary_array::<i32>(values).iter();
+            interner.intern(iter)
+        }
+        DataType::LargeBinary => {
+            let iter = as_generic_binary_array::<i64>(values).iter();
+            interner.intern(iter)
+        }
+        DataType::Utf8 => {
+            let iter = as_string_array(values).iter().map(|x| x.map(|x| x.as_bytes()));
+            interner.intern(iter)
+        }
+        DataType::LargeUtf8 => {
+            let iter = as_largestring_array(values).iter().map(|x| x.map(|x| x.as_bytes()));
+            interner.intern(iter)
+        }
+        t => return Err(ArrowError::NotYetImplemented(format!("dictionary value {} is not supported", t))),
+    })
+}
+
+/// Dictionary types are encoded as
+///
+/// - single `0_u8` if null
+/// - the bytes of the corresponding normalized key including the null terminator
+pub fn encode_dictionary<K: ArrowDictionaryKeyType>(
+    out: &mut Rows,
+    column: &DictionaryArray<K>,
+    normalized_keys: &[Option<&[u8]>],
+    opts: SortOptions,
+) {
+    for (offset, k) in out.offsets.iter_mut().skip(1).zip(column.keys()) {
+        match k.and_then(|k| normalized_keys[k.as_usize()]) {
+            Some(normalized_key) => {
+                let end_offset = *offset + 1 + normalized_key.len();
+                out.buffer[*offset] = 1;
+                out.buffer[*offset + 1..end_offset].copy_from_slice(normalized_key);
+                // Negate if descending
+                if opts.descending {
+                    out.buffer[*offset..end_offset]
+                        .iter_mut()
+                        .for_each(|v| *v = !*v)
+                }
+                *offset = end_offset;
+            }
+            None => {
+                out.buffer[*offset] = null_sentinel(opts);
+                *offset += 1;
+            }
+        }
+    }
+}
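Editor's note: because dictionary values are interned to order-preserving normalized keys, rows produced from different batches by the same converter remain directly comparable. The sketch below (an illustration, not part of the patch) mirrors the behaviour that `test_dictionary` later asserts.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, DictionaryArray};
use arrow::datatypes::{DataType, Int32Type};
use arrow::row::{RowConverter, SortField};

fn main() {
    let mut converter = RowConverter::new(vec![SortField::new(DataType::Dictionary(
        Box::new(DataType::Int32),
        Box::new(DataType::Utf8),
    ))]);

    // Two batches with different dictionaries but an overlapping value
    let a: DictionaryArray<Int32Type> = vec!["foo", "bar"].into_iter().collect();
    let b: DictionaryArray<Int32Type> = vec!["bar", "baz"].into_iter().collect();

    let rows_a = converter.convert_columns(&[Arc::new(a) as ArrayRef]).unwrap();
    let rows_b = converter.convert_columns(&[Arc::new(b) as ArrayRef]).unwrap();

    // "bar" maps to the same normalized key in both batches, and ordering
    // across batches follows the logical string order
    assert_eq!(rows_a.row(1), rows_b.row(0)); // "bar" == "bar"
    assert!(rows_b.row(1) < rows_a.row(0)); // "baz" < "foo"
}
```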
+
+/// Decodes a dictionary array from `rows` with the provided `options`
+///
+/// # Safety
+///
+/// `interner` must contain valid data for the provided `value_type`
+pub unsafe fn decode_dictionary<K: ArrowDictionaryKeyType>(
+    interner: &OrderPreservingInterner,
+    value_type: &DataType,
+    options: SortOptions,
+    rows: &mut [&[u8]],
+) -> Result<DictionaryArray<K>, ArrowError> {
+    let len = rows.len();
+    let mut dictionary: HashMap<Interned, K::Native> = HashMap::with_capacity(len);
+
+    let null_sentinel = null_sentinel(options);
+
+    // If descending, the null terminator will have been negated
+    let null_terminator = match options.descending {
+        true => 0xFF,
+        false => 0_u8,
+    };
+
+    let mut null_builder = BooleanBufferBuilder::new(len);
+    let mut keys = BufferBuilder::<K::Native>::new(len);
+    let mut values = Vec::with_capacity(len);
+    let mut null_count = 0;
+    let mut key_scratch = Vec::new();
+
+    for row in rows {
+        if row[0] == null_sentinel {
+            null_builder.append(false);
+            null_count += 1;
+            *row = &row[1..];
+            keys.append(K::Native::default());
+            continue;
+        }
+
+        let key_offset = row
+            .iter()
+            .skip(1)
+            .position(|x| *x == null_terminator)
+            .unwrap();
+
+        // Extract the normalized key including the null terminator
+        let key = &row[1..key_offset + 2];
+        *row = &row[key_offset + 2..];
+
+        let interned = match options.descending {
+            true => {
+                // If `options.descending`, the normalized key will have been
+                // negated, so we must first reverse this
+                key_scratch.clear();
+                key_scratch.extend_from_slice(key);
+                key_scratch.iter_mut().for_each(|o| *o = !*o);
+                interner.lookup(&key_scratch).unwrap()
+            }
+            false => interner.lookup(key).unwrap(),
+        };
+
+        let k = match dictionary.entry(interned) {
+            Entry::Vacant(v) => {
+                let k = values.len();
+                values.push(interner.value(interned));
+                let key = K::Native::from_usize(k)
+                    .ok_or(ArrowError::DictionaryKeyOverflowError)?;
+                *v.insert(key)
+            }
+            Entry::Occupied(o) => *o.get(),
+        };
+
+        keys.append(k);
+        null_builder.append(true);
+    }
+
+    let child = match &value_type {
+        DataType::Null => NullArray::new(values.len()).into_data(),
+        DataType::Boolean => decode_bool(&values),
+        DataType::Int8 => decode_primitive::<Int8Type>(&values),
+        DataType::Int16 => decode_primitive::<Int16Type>(&values),
+        DataType::Int32 => decode_primitive::<Int32Type>(&values),
+        DataType::Int64 => decode_primitive::<Int64Type>(&values),
+        DataType::UInt8 => decode_primitive::<UInt8Type>(&values),
+        DataType::UInt16 => decode_primitive::<UInt16Type>(&values),
+        DataType::UInt32 => decode_primitive::<UInt32Type>(&values),
+        DataType::UInt64 => decode_primitive::<UInt64Type>(&values),
+        DataType::Float16 => decode_primitive::<Float16Type>(&values),
+        DataType::Float32 => decode_primitive::<Float32Type>(&values),
+        DataType::Float64 => decode_primitive::<Float64Type>(&values),
+        DataType::Timestamp(TimeUnit::Second, _) => {
+            decode_primitive::<TimestampSecondType>(&values)
+        }
+        DataType::Timestamp(TimeUnit::Millisecond, _) => {
+            decode_primitive::<TimestampMillisecondType>(&values)
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, _) => {
+            decode_primitive::<TimestampMicrosecondType>(&values)
+        }
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+            decode_primitive::<TimestampNanosecondType>(&values)
+        }
+        DataType::Date32 => decode_primitive::<Date32Type>(&values),
+        DataType::Date64 => decode_primitive::<Date64Type>(&values),
+        DataType::Time32(t) => match t {
+            TimeUnit::Second => decode_primitive::<Time32SecondType>(&values),
+            TimeUnit::Millisecond => decode_primitive::<Time32MillisecondType>(&values),
+            _ => unreachable!(),
+        },
+        DataType::Time64(t) => match t {
+            TimeUnit::Microsecond => decode_primitive::<Time64MicrosecondType>(&values),
+            TimeUnit::Nanosecond => decode_primitive::<Time64NanosecondType>(&values),
+            _ => unreachable!(),
+        },
+        DataType::Duration(TimeUnit::Second) => {
+            decode_primitive::<DurationSecondType>(&values)
+        }
+        DataType::Duration(TimeUnit::Millisecond) => {
+            decode_primitive::<DurationMillisecondType>(&values)
+        }
+        DataType::Duration(TimeUnit::Microsecond) => {
+            decode_primitive::<DurationMicrosecondType>(&values)
+        }
+        DataType::Duration(TimeUnit::Nanosecond) => {
+            decode_primitive::<DurationNanosecondType>(&values)
+        }
+        DataType::Interval(IntervalUnit::DayTime) => {
+            decode_primitive::<IntervalDayTimeType>(&values)
+        }
+        DataType::Interval(IntervalUnit::MonthDayNano) => {
+            decode_primitive::<IntervalMonthDayNanoType>(&values)
+        }
+        DataType::Interval(IntervalUnit::YearMonth) => {
+            decode_primitive::<IntervalYearMonthType>(&values)
+        }
+        DataType::Decimal128(p, s) => {
+            decode_decimal::<16, Decimal128Type>(&values, *p, *s)
+        }
+        DataType::Decimal256(p, s) => {
+            decode_decimal::<32, Decimal256Type>(&values, *p, *s)
+        }
+        DataType::Utf8 => decode_string::<i32>(&values),
+        DataType::LargeUtf8 => decode_string::<i64>(&values),
+        DataType::Binary => decode_binary::<i32>(&values),
+        DataType::LargeBinary => decode_binary::<i64>(&values),
+        _ => {
+            return Err(ArrowError::NotYetImplemented(format!(
+                "decoding dictionary values of {}",
+                value_type
+            )))
+        }
+    };
+
+    let data_type =
+        DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(value_type.clone()));
+
+    let builder = ArrayDataBuilder::new(data_type)
+        .len(len)
+        .null_bit_buffer(Some(null_builder.finish()))
+        .null_count(null_count)
+        .add_buffer(keys.finish())
+        .add_child_data(child);
+
+    Ok(DictionaryArray::from(builder.build_unchecked()))
+}
+
+/// Decodes a binary array from dictionary values
+fn decode_binary<O: OffsetSizeTrait>(values: &[&[u8]]) -> ArrayData {
+    let capacity = values.iter().map(|x| x.len()).sum();
+    let mut builder = GenericBinaryBuilder::<O>::with_capacity(values.len(), capacity);
+    for v in values {
+        builder.append_value(v)
+    }
+    builder.finish().into_data()
+}
+
+/// Decodes a string array from dictionary values
+///
+/// # Safety
+///
+/// Values must be valid UTF-8
+unsafe fn decode_string<O: OffsetSizeTrait>(values: &[&[u8]]) -> ArrayData {
+    let d = match O::IS_LARGE {
+        true => DataType::LargeUtf8,
+        false => DataType::Utf8,
+    };
+
+    decode_binary::<O>(values)
+        .into_builder()
+        .data_type(d)
+        .build_unchecked()
+}
+
+/// Decodes a boolean array from dictionary values
+fn decode_bool(values: &[&[u8]]) -> ArrayData {
+    let mut builder = BooleanBufferBuilder::new(values.len());
+    for value in values {
+        builder.append(bool::decode([value[0]]))
+    }
+
+    let builder = ArrayDataBuilder::new(DataType::Boolean)
+        .len(values.len())
+        .add_buffer(builder.finish());
+
+    // SAFETY: Buffers correct length
+    unsafe { builder.build_unchecked() }
+}
+
+/// Decodes a fixed length type array from dictionary values
+fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
+    values: &[&[u8]],
+    data_type: DataType,
+) -> ArrayData {
+    let mut buffer = MutableBuffer::new(std::mem::size_of::<T>() * values.len());
+
+    for value in values {
+        let value = T::Encoded::from_slice(value, false);
+        buffer.push(T::decode(value))
+    }
+
+    let builder = ArrayDataBuilder::new(data_type)
+        .len(values.len())
+        .add_buffer(buffer.into());
+
+    // SAFETY: Buffers correct length
+    unsafe { builder.build_unchecked() }
+}
+
+/// Decodes a `PrimitiveArray` from dictionary values
+fn decode_primitive<T: ArrowPrimitiveType>(values: &[&[u8]]) -> ArrayData
+where
+    T::Native: FixedLengthEncoding,
+{
+    decode_fixed::<T::Native>(values, T::DATA_TYPE)
+}
+
+/// Decodes a `DecimalArray` from dictionary values
+fn decode_decimal<const N: usize, T: DecimalType>(
+    values: &[&[u8]],
+    precision: u8,
+    scale: u8,
+) -> ArrayData {
+    decode_fixed::<RawDecimal<N>>(values, T::TYPE_CONSTRUCTOR(precision, scale))
+}
diff --git a/arrow/src/row/fixed.rs b/arrow/src/row/fixed.rs
index 04b9a30ecad..9952ee094bf 100644
--- a/arrow/src/row/fixed.rs
+++ b/arrow/src/row/fixed.rs
@@ -18,18 +18,39 @@
 use crate::array::PrimitiveArray;
 use crate::compute::SortOptions;
 use crate::datatypes::ArrowPrimitiveType;
-use crate::row::Rows;
-use crate::util::decimal::{Decimal128, Decimal256};
+use crate::row::{null_sentinel, Rows};
+use arrow_array::types::DecimalType;
+use arrow_array::{BooleanArray, DecimalArray};
+use arrow_buffer::{bit_util, MutableBuffer, ToByteSlice};
+use arrow_data::{ArrayData, ArrayDataBuilder};
+use arrow_schema::DataType;
 use half::f16;
 
+pub trait FromSlice {
+    fn from_slice(slice: &[u8], invert: bool) -> Self;
+}
+
+impl<const N: usize> FromSlice for [u8; N] {
+    #[inline]
+    fn from_slice(slice: &[u8], invert: bool) -> Self {
+        let mut t: Self = slice.try_into().unwrap();
+        if invert {
+            t.iter_mut().for_each(|o| *o = !*o);
+        }
+        t
+    }
+}
+
 /// Encodes a value of a particular fixed width type into bytes according to the rules
 /// described on [`super::RowConverter`]
 pub trait FixedLengthEncoding: Copy {
     const ENCODED_LEN: usize = 1 + std::mem::size_of::<Self::Encoded>();
 
-    type Encoded: Sized + Copy + AsRef<[u8]> + AsMut<[u8]>;
+    type Encoded: Sized + Copy + FromSlice + AsRef<[u8]> + AsMut<[u8]>;
 
     fn encode(self) -> Self::Encoded;
+
+    fn decode(encoded: Self::Encoded) -> Self;
 }
 
 impl FixedLengthEncoding for bool {
@@ -38,6 +59,10 @@ impl FixedLengthEncoding for bool {
     fn encode(self) -> [u8; 1] {
         [self as u8]
     }
+
+    fn decode(encoded: Self::Encoded) -> Self {
+        encoded[0] != 0
+    }
 }
 
 macro_rules! encode_signed {
@@ -51,6 +76,12 @@ macro_rules! encode_signed {
                 b[0] ^= 0x80;
                 b
             }
+
+            fn decode(mut encoded: Self::Encoded) -> Self {
+                // Toggle top "sign" bit
+                encoded[0] ^= 0x80;
+                Self::from_be_bytes(encoded)
+            }
         }
     };
 }
@@ -69,6 +100,10 @@ macro_rules! encode_unsigned {
             fn encode(self) -> [u8; $n] {
                 self.to_be_bytes()
             }
+
+            fn decode(encoded: Self::Encoded) -> Self {
+                Self::from_be_bytes(encoded)
+            }
         }
     };
 }
@@ -87,6 +122,12 @@ impl FixedLengthEncoding for f16 {
         let val = s ^ (((s >> 15) as u16) >> 1) as i16;
         val.encode()
     }
+
+    fn decode(encoded: Self::Encoded) -> Self {
+        let bits = i16::decode(encoded);
+        let val = bits ^ (((bits >> 15) as u16) >> 1) as i16;
+        Self::from_bits(val as u16)
+    }
 }
 
 impl FixedLengthEncoding for f32 {
@@ -98,6 +139,12 @@ impl FixedLengthEncoding for f32 {
         let val = s ^ (((s >> 31) as u32) >> 1) as i32;
         val.encode()
     }
+
+    fn decode(encoded: Self::Encoded) -> Self {
+        let bits = i32::decode(encoded);
+        let val = bits ^ (((bits >> 31) as u32) >> 1) as i32;
+        Self::from_bits(val as u32)
+    }
 }
 
 impl FixedLengthEncoding for f64 {
@@ -109,32 +156,44 @@ impl FixedLengthEncoding for f64 {
         let val = s ^ (((s >> 63) as u64) >> 1) as i64;
         val.encode()
     }
+
+    fn decode(encoded: Self::Encoded) -> Self {
+        let bits = i64::decode(encoded);
+        let val = bits ^ (((bits >> 63) as u64) >> 1) as i64;
+        Self::from_bits(val as u64)
+    }
 }
 
-impl FixedLengthEncoding for Decimal128 {
-    type Encoded = [u8; 16];
+pub type RawDecimal128 = RawDecimal<16>;
+pub type RawDecimal256 = RawDecimal<32>;
 
-    fn encode(self) -> [u8; 16] {
-        let mut val = *self.raw_value();
-        // Convert to big endian representation
-        val.reverse();
-        // Toggle top "sign" bit to ensure consistent sort order
-        val[0] ^= 0x80;
-        val
+/// The raw bytes of a decimal
+#[derive(Copy, Clone)]
+pub struct RawDecimal<const N: usize>(pub [u8; N]);
+
+impl<const N: usize> ToByteSlice for RawDecimal<N> {
+    fn to_byte_slice(&self) -> &[u8] {
+        &self.0
     }
 }
 
-impl FixedLengthEncoding for Decimal256 {
-    type Encoded = [u8; 32];
+impl<const N: usize> FixedLengthEncoding for RawDecimal<N> {
+    type Encoded = [u8; N];
 
-    fn encode(self) -> [u8; 32] {
-        let mut val = *self.raw_value();
+    fn encode(self) -> [u8; N] {
+        let mut val = self.0;
         // Convert to big endian representation
         val.reverse();
         // Toggle top "sign" bit to ensure consistent sort order
         val[0] ^= 0x80;
         val
     }
+
+    fn decode(mut encoded: Self::Encoded) -> Self {
+        encoded[0] ^= 0x80;
+        encoded.reverse();
+        Self(encoded)
+    }
 }
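Editor's note: the encodings above are involutions designed so that a byte-wise comparison of the encoded form matches the logical comparison of the value. A self-contained sketch of the signed-integer case (an illustration, not part of the patch):

```rust
/// Order-preserving encode/decode for i32, mirroring `encode_signed!` above
fn encode_i32(v: i32) -> [u8; 4] {
    let mut b = v.to_be_bytes();
    b[0] ^= 0x80; // flip the sign bit so negatives sort before positives
    b
}

fn decode_i32(mut b: [u8; 4]) -> i32 {
    b[0] ^= 0x80; // the transform is its own inverse
    i32::from_be_bytes(b)
}

fn main() {
    let vals = [i32::MIN, -1, 0, 1, i32::MAX];
    for w in vals.windows(2) {
        // byte comparison of the encoded form matches integer comparison
        assert!(encode_i32(w[0]) < encode_i32(w[1]));
        assert_eq!(decode_i32(encode_i32(w[0])), w[0]);
    }
}
```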
 
 /// Returns the total encoded length (including null byte) for a value of type `T::Native`
@@ -166,9 +225,153 @@ pub fn encode<T: FixedLengthEncoding, I: IntoIterator<Item = Option<T>>>(
                 encoded.as_mut().iter_mut().for_each(|v| *v = !*v)
             }
             to_write[1..].copy_from_slice(encoded.as_ref())
-        } else if !opts.nulls_first {
-            out.buffer[*offset] = 0xFF;
+        } else {
+            out.buffer[*offset] = null_sentinel(opts);
         }
         *offset = end_offset;
     }
 }
+
+/// Splits `len` bytes from `src`
+#[inline]
+fn split_off<'a>(src: &mut &'a [u8], len: usize) -> &'a [u8] {
+    let v = &src[..len];
+    *src = &src[len..];
+    v
+}
+
+/// Decodes a `BooleanArray` from rows
+pub fn decode_bool(rows: &mut [&[u8]], options: SortOptions) -> BooleanArray {
+    let true_val = match options.descending {
+        true => !1,
+        false => 1,
+    };
+
+    let len = rows.len();
+
+    let mut null_count = 0;
+    let mut nulls = MutableBuffer::new(bit_util::ceil(len, 64) * 8);
+    let mut values = MutableBuffer::new(bit_util::ceil(len, 64) * 8);
+
+    let chunks = len / 64;
+    let remainder = len % 64;
+    for chunk in 0..chunks {
+        let mut null_packed = 0;
+        let mut values_packed = 0;
+
+        for bit_idx in 0..64 {
+            let i = split_off(&mut rows[bit_idx + chunk * 64], 2);
+            let (null, value) = (i[0] == 1, i[1] == true_val);
+            null_count += !null as usize;
+            null_packed |= (null as u64) << bit_idx;
+            values_packed |= (value as u64) << bit_idx;
+        }
+
+        nulls.push(null_packed);
+        values.push(values_packed);
+    }
+
+    if remainder != 0 {
+        let mut null_packed = 0;
+        let mut values_packed = 0;
+
+        for bit_idx in 0..remainder {
+            let i = split_off(&mut rows[bit_idx + chunks * 64], 2);
+            let (null, value) = (i[0] == 1, i[1] == true_val);
+            null_count += !null as usize;
+            null_packed |= (null as u64) << bit_idx;
+            values_packed |= (value as u64) << bit_idx;
+        }
+
+        nulls.push(null_packed);
+        values.push(values_packed);
+    }
+
+    let builder = ArrayDataBuilder::new(DataType::Boolean)
+        .len(rows.len())
+        .null_count(null_count)
+        .add_buffer(values.into())
+        .null_bit_buffer(Some(nulls.into()));
+
+    // SAFETY:
+    // Buffers are the correct length
+    unsafe { BooleanArray::from(builder.build_unchecked()) }
+}
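Editor's note: `decode_bool` (and `decode_fixed` below) accumulate validity and value bits into `u64` words, 64 rows at a time, before pushing whole words into the buffers. A dependency-free sketch of that packing pattern (an illustration, not part of the patch):

```rust
/// Pack a slice of bools into little-endian u64 words, 64 bits at a time,
/// mirroring how the decoders above build their bitmaps
fn pack_bits(bits: &[bool]) -> Vec<u64> {
    let mut out = Vec::with_capacity((bits.len() + 63) / 64);
    for chunk in bits.chunks(64) {
        let mut packed = 0u64;
        for (i, b) in chunk.iter().enumerate() {
            // bit `i` of the word corresponds to row `i` within the chunk
            packed |= (*b as u64) << i;
        }
        out.push(packed);
    }
    out
}

fn main() {
    assert_eq!(pack_bits(&[true, false, true]), vec![0b101]);
}
```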
+
+/// Decodes an `ArrayData` from rows based on the provided `FixedLengthEncoding` `T`
+fn decode_fixed<T: FixedLengthEncoding + ToByteSlice>(
+    rows: &mut [&[u8]],
+    data_type: DataType,
+    options: SortOptions,
+) -> ArrayData {
+    let len = rows.len();
+
+    let mut null_count = 0;
+    let mut nulls = MutableBuffer::new(bit_util::ceil(len, 64) * 8);
+    let mut values = MutableBuffer::new(std::mem::size_of::<T>() * len);
+
+    let chunks = len / 64;
+    let remainder = len % 64;
+    for chunk in 0..chunks {
+        let mut null_packed = 0;
+
+        for bit_idx in 0..64 {
+            let i = split_off(&mut rows[bit_idx + chunk * 64], T::ENCODED_LEN);
+            let null = i[0] == 1;
+            null_count += !null as usize;
+            null_packed |= (null as u64) << bit_idx;
+
+            let value = T::Encoded::from_slice(&i[1..], options.descending);
+            values.push(T::decode(value));
+        }
+
+        nulls.push(null_packed);
+    }
+
+    if remainder != 0 {
+        let mut null_packed = 0;
+
+        for bit_idx in 0..remainder {
+            let i = split_off(&mut rows[bit_idx + chunks * 64], T::ENCODED_LEN);
+            let null = i[0] == 1;
+            null_count += !null as usize;
+            null_packed |= (null as u64) << bit_idx;
+
+            let value = T::Encoded::from_slice(&i[1..], options.descending);
+            values.push(T::decode(value));
+        }
+
+        nulls.push(null_packed);
+    }
+
+    let builder = ArrayDataBuilder::new(data_type)
+        .len(rows.len())
+        .null_count(null_count)
+        .add_buffer(values.into())
+        .null_bit_buffer(Some(nulls.into()));
+
+    // SAFETY: Buffers correct length
+    unsafe { builder.build_unchecked() }
+}
+
+/// Decodes a `DecimalArray` from rows
+pub fn decode_decimal<const N: usize, T: DecimalType>(
+    rows: &mut [&[u8]],
+    options: SortOptions,
+    precision: u8,
+    scale: u8,
+) -> DecimalArray<T> {
+    decode_fixed::<RawDecimal<N>>(rows, T::TYPE_CONSTRUCTOR(precision, scale), options)
+        .into()
+}
+
+/// Decodes a `PrimitiveArray` from rows
+pub fn decode_primitive<T: ArrowPrimitiveType>(
+    rows: &mut [&[u8]],
+    options: SortOptions,
+) -> PrimitiveArray<T>
+where
+    T::Native: FixedLengthEncoding + ToByteSlice,
+{
+    decode_fixed::<T::Native>(rows, T::DATA_TYPE, options).into()
+}
diff --git a/arrow/src/row/interner.rs b/arrow/src/row/interner.rs
index e48670984c9..156d23465bf 100644
--- a/arrow/src/row/interner.rs
+++ b/arrow/src/row/interner.rs
@@ -22,7 +22,7 @@ use std::num::NonZeroU32;
 use std::ops::Index;
 
 /// An interned value
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
 pub struct Interned(NonZeroU32); // We use NonZeroU32 so that `Option<Interned>` is 32 bits
 
 /// A byte array interner that generates normalized keys that are sorted with respect
@@ -132,7 +132,6 @@ impl OrderPreservingInterner {
 
     /// Converts a normalized key returned by [`Self::normalized_key`] to [`Interned`]
     /// returning `None` if it cannot be found
-    #[allow(dead_code)]
    pub fn lookup(&self, normalized_key: &[u8]) -> Option<Interned> {
        let len = normalized_key.len();
 
@@ -159,7 +158,6 @@ impl OrderPreservingInterner {
     }
 
     /// Returns the interned value for a given [`Interned`]
-    #[allow(dead_code)]
     pub fn value(&self, key: Interned) -> &[u8] {
         self.values.index(key)
     }
diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs
index 88c8a916663..f5ac570320b 100644
--- a/arrow/src/row/mod.rs
+++ b/arrow/src/row/mod.rs
@@ -17,17 +17,28 @@
 
 //! A comparable row-oriented representation of a collection of [`Array`]
 
-use crate::array::{
-    as_boolean_array, as_generic_binary_array, as_largestring_array, as_string_array,
-    Array, ArrayRef, Decimal128Array, Decimal256Array,
-};
+use std::cmp::Ordering;
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+use arrow_array::cast::*;
+use arrow_array::*;
+
 use crate::compute::SortOptions;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
-use crate::row::interner::{Interned, OrderPreservingInterner};
-use crate::util::decimal::{Decimal128, Decimal256};
+use crate::row::dictionary::{
+    compute_dictionary_mapping, decode_dictionary, encode_dictionary,
+};
+use crate::row::fixed::{
+    decode_bool, decode_decimal, decode_primitive, RawDecimal, RawDecimal128,
+    RawDecimal256,
+};
+use crate::row::interner::OrderPreservingInterner;
+use crate::row::variable::{decode_binary, decode_string};
 use crate::{downcast_dictionary_array, downcast_primitive_array};
 
+mod dictionary;
 mod fixed;
 mod interner;
 mod variable;
@@ -134,13 +145,13 @@ mod variable;
 /// [byte stuffing]:[https://en.wikipedia.org/wiki/High-Level_Data_Link_Control#Asynchronous_framing]
 #[derive(Debug)]
 pub struct RowConverter {
-    fields: Vec<SortField>,
+    fields: Arc<[SortField]>,
     /// interning state for column `i`, if column `i` is a dictionary
     interners: Vec<Option<Box<OrderPreservingInterner>>>,
 }
 
 /// Configure the data type and sort order for a given column
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub struct SortField {
     /// Sort options
     options: SortOptions,
@@ -164,7 +175,10 @@ impl RowConverter {
     /// Create a new [`RowConverter`] with the provided schema
     pub fn new(fields: Vec<SortField>) -> Self {
         let interners = (0..fields.len()).map(|_| None).collect();
-        Self { fields, interners }
+        Self {
+            fields: fields.into(),
+            interners,
+        }
     }
 
     /// Convert [`ArrayRef`] columns into [`Rows`]
@@ -186,7 +200,7 @@ impl RowConverter {
         let dictionaries = columns
             .iter()
             .zip(&mut self.interners)
-            .zip(&self.fields)
+            .zip(self.fields.iter())
             .map(|((column, interner), field)| {
                 if !column.data_type().equals_datatype(&field.data_type) {
                     return Err(ArrowError::InvalidArgumentError(format!(
@@ -214,10 +228,10 @@ impl RowConverter {
             })
             .collect::<Result<Vec<_>>>()?;
 
-        let mut rows = new_empty_rows(columns, &dictionaries)?;
+        let mut rows = new_empty_rows(columns, &dictionaries, Arc::clone(&self.fields))?;
 
         for ((column, field), dictionary) in
-            columns.iter().zip(&self.fields).zip(dictionaries)
+            columns.iter().zip(self.fields.iter()).zip(dictionaries)
         {
             // We encode a column at a time to minimise dispatch overheads
            encode_column(&mut rows, column, field.options, dictionary.as_deref())
@@ -227,11 +241,44 @@ impl RowConverter {
             assert_eq!(*rows.offsets.last().unwrap(), rows.buffer.len());
             rows.offsets
                 .windows(2)
-                .for_each(|w| assert!(w[0] < w[1], "offsets should be monotonic"));
+                .for_each(|w| assert!(w[0] <= w[1], "offsets should be monotonic"));
         }
 
         Ok(rows)
     }
+
+    /// Convert [`Rows`] back into [`ArrayRef`] columns
+    ///
+    /// # Panics
+    ///
+    /// Panics if the rows were not produced by this [`RowConverter`]
+    pub fn convert_rows<'a, I>(&self, rows: I) -> Result<Vec<ArrayRef>>
+    where
+        I: IntoIterator<Item = Row<'a>>,
+    {
+        let mut rows: Vec<_> = rows
+            .into_iter()
+            .map(|row| {
+                assert!(
+                    Arc::ptr_eq(row.fields, &self.fields),
+                    "rows were not produced by this RowConverter"
+                );
+
+                row.data
+            })
+            .collect();
+
+        self.fields
+            .iter()
+            .zip(&self.interners)
+            .map(|(field, interner)| {
+                // SAFETY
+                // We have validated that the rows came from this [`RowConverter`]
+                // and therefore must be valid
+                unsafe { decode_column(field, &mut rows, interner.as_deref()) }
+            })
+            .collect()
+    }
 }
 
 /// A row-oriented representation of arrow data, that is normalized for comparison
@@ -243,13 +290,18 @@ pub struct Rows {
     buffer: Box<[u8]>,
     /// Row `i` has data `&buffer[offsets[i]..offsets[i+1]]`
     offsets: Box<[usize]>,
+    /// The schema for these rows
+    fields: Arc<[SortField]>,
 }
 
 impl Rows {
     pub fn row(&self, row: usize) -> Row<'_> {
         let end = self.offsets[row + 1];
         let start = self.offsets[row];
-        Row(&self.buffer[start..end])
+        Row {
+            data: &self.buffer[start..end],
+            fields: &self.fields,
+        }
     }
 
     pub fn num_rows(&self) -> usize {
@@ -257,54 +309,127 @@ impl Rows {
    }
 }
 
+impl<'a> IntoIterator for &'a Rows {
+    type Item = Row<'a>;
+    type IntoIter = RowsIter<'a>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        RowsIter {
+            rows: self,
+            start: 0,
+            end: self.num_rows(),
+        }
+    }
+}
+
+/// An iterator over [`Rows`]
+#[derive(Debug)]
+pub struct RowsIter<'a> {
+    rows: &'a Rows,
+    start: usize,
+    end: usize,
+}
+
+impl<'a> Iterator for RowsIter<'a> {
+    type Item = Row<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.end == self.start {
+            return None;
+        }
+        let row = self.rows.row(self.start);
+        self.start += 1;
+        Some(row)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.len();
+        (len, Some(len))
+    }
+}
+
+impl<'a> ExactSizeIterator for RowsIter<'a> {
+    fn len(&self) -> usize {
+        self.end - self.start
+    }
+}
+
+impl<'a> DoubleEndedIterator for RowsIter<'a> {
+    fn next_back(&mut self) -> Option<Self::Item> {
+        if self.end == self.start {
+            return None;
+        }
+        // Decrement before indexing, as `end` is an exclusive bound
+        self.end -= 1;
+        Some(self.rows.row(self.end))
+    }
+}
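Editor's note: a short sketch of how the new iterator and byte-wise `Ord` on `Row` combine to sort data without touching the decoded values (an illustration, not part of the patch; it assumes an `arrow` version containing this change).

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::DataType;
use arrow::row::{RowConverter, SortField};

fn main() {
    let col = Arc::new(Int32Array::from(vec![3, 1, 2])) as ArrayRef;
    let mut converter = RowConverter::new(vec![SortField::new(DataType::Int32)]);
    let rows = converter.convert_columns(&[col]).unwrap();

    // `&Rows` iterates borrowed `Row` values with a known length
    let mut sorted: Vec<_> = rows.into_iter().collect();
    sorted.sort_unstable(); // byte-wise comparison matches logical order

    assert!(sorted[0] < sorted[1] && sorted[1] < sorted[2]);
}
```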
 
 /// A comparable representation of a row
 ///
 /// Two [`Row`] can be compared if they both belong to [`Rows`] returned by calls to
 /// [`RowConverter::convert_columns`] on the same [`RowConverter`]
 ///
 /// Otherwise any ordering established by comparing the [`Row`] is arbitrary
-#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
-pub struct Row<'a>(&'a [u8]);
+#[derive(Debug, Copy, Clone)]
+pub struct Row<'a> {
+    data: &'a [u8],
+    fields: &'a Arc<[SortField]>,
+}
+
+// Manually derive these as we don't wish to include `fields`
+
+impl<'a> PartialEq for Row<'a> {
+    #[inline]
+    fn eq(&self, other: &Self) -> bool {
+        self.data.eq(other.data)
+    }
+}
+
+impl<'a> Eq for Row<'a> {}
+
+impl<'a> PartialOrd for Row<'a> {
+    #[inline]
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        self.data.partial_cmp(other.data)
+    }
+}
+
+impl<'a> Ord for Row<'a> {
+    #[inline]
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.data.cmp(other.data)
+    }
+}
+
+impl<'a> Hash for Row<'a> {
+    #[inline]
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.data.hash(state)
+    }
+}
 
 impl<'a> AsRef<[u8]> for Row<'a> {
+    #[inline]
     fn as_ref(&self) -> &[u8] {
-        self.0
+        self.data
     }
 }
 
-/// Computes the dictionary mapping for the given dictionary values
-fn compute_dictionary_mapping(
-    interner: &mut OrderPreservingInterner,
-    values: &ArrayRef,
-) -> Result<Vec<Option<Interned>>> {
-    use fixed::FixedLengthEncoding;
-    Ok(downcast_primitive_array! {
-        values => interner
-            .intern(values.iter().map(|x| x.map(|x| x.encode()))),
-        DataType::Binary => {
-            let iter = as_generic_binary_array::<i32>(values).iter();
-            interner.intern(iter)
-        }
-        DataType::LargeBinary => {
-            let iter = as_generic_binary_array::<i64>(values).iter();
-            interner.intern(iter)
-        }
-        DataType::Utf8 => {
-            let iter = as_string_array(values).iter().map(|x| x.map(|x| x.as_bytes()));
-            interner.intern(iter)
-        }
-        DataType::LargeUtf8 => {
-            let iter = as_largestring_array(values).iter().map(|x| x.map(|x| x.as_bytes()));
-            interner.intern(iter)
-        }
-        t => return Err(ArrowError::NotYetImplemented(format!("dictionary value {} is not supported", t))),
-    })
+/// Returns the null sentinel for the provided `options`
+#[inline]
+fn null_sentinel(options: SortOptions) -> u8 {
+    match options.nulls_first {
+        true => 0,
+        false => 0xFF,
+    }
 }
 
 /// Computes the length of each encoded [`Rows`] and returns an empty [`Rows`]
 fn new_empty_rows(
     cols: &[ArrayRef],
     dictionaries: &[Option<Vec<Option<&[u8]>>>],
+    fields: Arc<[SortField]>,
 ) -> Result<Rows> {
     use fixed::FixedLengthEncoding;
 
@@ -314,10 +439,10 @@ fn new_empty_rows(
     for (array, dict) in cols.iter().zip(dictionaries) {
         downcast_primitive_array! {
             array => lengths.iter_mut().for_each(|x| *x += fixed::encoded_len(array)),
-            DataType::Null => lengths.iter_mut().for_each(|x| *x += 1),
+            DataType::Null => {},
             DataType::Boolean => lengths.iter_mut().for_each(|x| *x += bool::ENCODED_LEN),
-            DataType::Decimal128(_, _) => lengths.iter_mut().for_each(|x| *x += Decimal128::ENCODED_LEN),
-            DataType::Decimal256(_, _) => lengths.iter_mut().for_each(|x| *x += Decimal256::ENCODED_LEN),
+            DataType::Decimal128(_, _) => lengths.iter_mut().for_each(|x| *x += RawDecimal128::ENCODED_LEN),
+            DataType::Decimal256(_, _) => lengths.iter_mut().for_each(|x| *x += RawDecimal256::ENCODED_LEN),
             DataType::Binary => as_generic_binary_array::<i32>(array)
                 .iter()
                 .zip(lengths.iter_mut())
@@ -383,6 +508,7 @@ fn new_empty_rows(
     Ok(Rows {
         buffer: buffer.into(),
         offsets: offsets.into(),
+        fields,
     })
 }
 
@@ -395,20 +521,28 @@ fn encode_column(
 ) {
     downcast_primitive_array! {
         column => fixed::encode(out, column, opts),
-        DataType::Null => {
-            fixed::encode(out, std::iter::repeat(None::<bool>).take(column.len()), opts)
-        }
+        DataType::Null => {}
         DataType::Boolean => fixed::encode(out, as_boolean_array(column), opts),
-        DataType::Decimal128(_, _) => fixed::encode(
-            out,
-            column.as_any().downcast_ref::<Decimal128Array>().unwrap(),
-            opts,
-        ),
-        DataType::Decimal256(_, _) => fixed::encode(
-            out,
-            column.as_any().downcast_ref::<Decimal256Array>().unwrap(),
-            opts,
-        ),
+        DataType::Decimal128(_, _) => {
+            let iter = column
+                .as_any()
+                .downcast_ref::<Decimal128Array>()
+                .unwrap()
+                .into_iter()
+                .map(|x| x.map(|x| RawDecimal(*x.raw_value())));
+
+            fixed::encode(out, iter, opts)
+        },
+        DataType::Decimal256(_, _) => {
+            let iter = column
+                .as_any()
+                .downcast_ref::<Decimal256Array>()
+                .unwrap()
+                .into_iter()
+                .map(|x| x.map(|x| RawDecimal(*x.raw_value())));
+
+            fixed::encode(out, iter, opts)
+        },
         DataType::Binary => {
             variable::encode(out, as_generic_binary_array::<i32>(column).iter(), opts)
         }
@@ -428,37 +562,183 @@ fn encode_column(
             opts,
         ),
         DataType::Dictionary(_, _) => downcast_dictionary_array! {
-            column => {
-                let dict = dictionary.unwrap();
-                for (offset, k) in out.offsets.iter_mut().skip(1).zip(column.keys()) {
-                    match k.and_then(|k| dict[k as usize]) {
-                        Some(v) => {
-                            let end_offset = *offset + 1 + v.len();
-                            out.buffer[*offset] = 1;
-                            out.buffer[*offset+1..end_offset].copy_from_slice(v);
-                            if opts.descending {
-                                out.buffer[*offset..end_offset].iter_mut().for_each(|v| *v = !*v)
-                            }
-                            *offset = end_offset;
-                        }
-                        None => {
-                            if !opts.nulls_first {
-                                out.buffer[*offset] = 0xFF;
-                            }
-                            *offset += 1;
-                        }
-                    }
-                }
-            },
+            column => encode_dictionary(out, column, dictionary.unwrap(), opts),
             _ => unreachable!()
         }
         t => unimplemented!("not yet implemented: {}", t)
     }
 }
 
+/// Decodes the provided `field` from `rows`
+///
+/// # Safety
+///
+/// Rows must contain valid data for the provided field
+unsafe fn decode_column(
+    field: &SortField,
+    rows: &mut [&[u8]],
+    interner: Option<&OrderPreservingInterner>,
+) -> Result<ArrayRef> {
+    let options = field.options;
+    let array: ArrayRef = match &field.data_type {
+        DataType::Null => Arc::new(NullArray::new(rows.len())),
+        DataType::Boolean => Arc::new(decode_bool(rows, options)),
+        DataType::Int8 => Arc::new(decode_primitive::<Int8Type>(rows, options)),
+        DataType::Int16 => Arc::new(decode_primitive::<Int16Type>(rows, options)),
+        DataType::Int32 => Arc::new(decode_primitive::<Int32Type>(rows, options)),
+        DataType::Int64 => Arc::new(decode_primitive::<Int64Type>(rows, options)),
+        DataType::UInt8 => Arc::new(decode_primitive::<UInt8Type>(rows, options)),
+        DataType::UInt16 => Arc::new(decode_primitive::<UInt16Type>(rows, options)),
+        DataType::UInt32 => Arc::new(decode_primitive::<UInt32Type>(rows, options)),
+        DataType::UInt64 => Arc::new(decode_primitive::<UInt64Type>(rows, options)),
+        DataType::Float16 => Arc::new(decode_primitive::<Float16Type>(rows, options)),
+        DataType::Float32 => Arc::new(decode_primitive::<Float32Type>(rows, options)),
+        DataType::Float64 => Arc::new(decode_primitive::<Float64Type>(rows, options)),
+        DataType::Timestamp(TimeUnit::Second, _) => {
+            Arc::new(decode_primitive::<TimestampSecondType>(rows, options))
+        }
+        DataType::Timestamp(TimeUnit::Millisecond, _) => {
+            Arc::new(decode_primitive::<TimestampMillisecondType>(rows, options))
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, _) => {
+            Arc::new(decode_primitive::<TimestampMicrosecondType>(rows, options))
+        }
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+            Arc::new(decode_primitive::<TimestampNanosecondType>(rows, options))
+        }
+        DataType::Date32 => Arc::new(decode_primitive::<Date32Type>(rows, options)),
+        DataType::Date64 => Arc::new(decode_primitive::<Date64Type>(rows, options)),
+        DataType::Time32(t) => match t {
+            TimeUnit::Second => {
+                Arc::new(decode_primitive::<Time32SecondType>(rows, options))
+            }
+            TimeUnit::Millisecond => {
+                Arc::new(decode_primitive::<Time32MillisecondType>(rows, options))
+            }
+            _ => unreachable!(),
+        },
+        DataType::Time64(t) => match t {
+            TimeUnit::Microsecond => {
+                Arc::new(decode_primitive::<Time64MicrosecondType>(rows, options))
+            }
+            TimeUnit::Nanosecond => {
+                Arc::new(decode_primitive::<Time64NanosecondType>(rows, options))
+            }
+            _ => unreachable!(),
+        },
+        DataType::Duration(TimeUnit::Second) => {
+            Arc::new(decode_primitive::<DurationSecondType>(rows, options))
+        }
+        DataType::Duration(TimeUnit::Millisecond) => {
+            Arc::new(decode_primitive::<DurationMillisecondType>(rows, options))
+        }
+        DataType::Duration(TimeUnit::Microsecond) => {
+            Arc::new(decode_primitive::<DurationMicrosecondType>(rows, options))
+        }
+        DataType::Duration(TimeUnit::Nanosecond) => {
+            Arc::new(decode_primitive::<DurationNanosecondType>(rows, options))
+        }
+        DataType::Interval(IntervalUnit::DayTime) => {
+            Arc::new(decode_primitive::<IntervalDayTimeType>(rows, options))
+        }
+        DataType::Interval(IntervalUnit::MonthDayNano) => {
+            Arc::new(decode_primitive::<IntervalMonthDayNanoType>(rows, options))
+        }
+        DataType::Interval(IntervalUnit::YearMonth) => {
+            Arc::new(decode_primitive::<IntervalYearMonthType>(rows, options))
+        }
+        DataType::Binary => Arc::new(decode_binary::<i32>(rows, options)),
+        DataType::LargeBinary => Arc::new(decode_binary::<i64>(rows, options)),
+        DataType::Utf8 => Arc::new(decode_string::<i32>(rows, options)),
+        DataType::LargeUtf8 => Arc::new(decode_string::<i64>(rows, options)),
+        DataType::Decimal128(p, s) => {
+            Arc::new(decode_decimal::<16, Decimal128Type>(rows, options, *p, *s))
+        }
+        DataType::Decimal256(p, s) => {
+            Arc::new(decode_decimal::<32, Decimal256Type>(rows, options, *p, *s))
+        }
+        DataType::Dictionary(k, v) => match k.as_ref() {
+            DataType::Int8 => Arc::new(decode_dictionary::<Int8Type>(
+                interner.unwrap(),
+                v.as_ref(),
+                options,
+                rows,
+            )?),
+            DataType::Int16 => Arc::new(decode_dictionary::<Int16Type>(
+                interner.unwrap(),
+                v.as_ref(),
+                options,
+                rows,
+            )?),
+            DataType::Int32 => Arc::new(decode_dictionary::<Int32Type>(
+                interner.unwrap(),
+                v.as_ref(),
+                options,
+                rows,
+            )?),
+            DataType::Int64 => Arc::new(decode_dictionary::<Int64Type>(
+                interner.unwrap(),
+                v.as_ref(),
+                options,
+                rows,
+            )?),
+            DataType::UInt8 => Arc::new(decode_dictionary::<UInt8Type>(
+                interner.unwrap(),
+                v.as_ref(),
+                options,
+                rows,
+            )?),
+            DataType::UInt16 => Arc::new(decode_dictionary::<UInt16Type>(
+                interner.unwrap(),
+                v.as_ref(),
+                options,
+                rows,
+            )?),
+            DataType::UInt32 => Arc::new(decode_dictionary::<UInt32Type>(
+                interner.unwrap(),
+                v.as_ref(),
+                options,
+                rows,
+            )?),
+            DataType::UInt64 => Arc::new(decode_dictionary::<UInt64Type>(
+                interner.unwrap(),
+                v.as_ref(),
+                options,
+                rows,
+            )?),
+            _ => {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "{} is not a valid dictionary key type",
+                    field.data_type
+                )));
+            }
+        },
+        DataType::FixedSizeBinary(_)
+        | DataType::List(_)
+        | DataType::FixedSizeList(_, _)
+        | DataType::LargeList(_)
+        | DataType::Struct(_)
+        | DataType::Union(_, _, _)
+        | DataType::Map(_, _) => {
+            return Err(ArrowError::NotYetImplemented(format!(
+                "converting {} row is not supported",
+                field.data_type
+            )))
+        }
+    };
+    Ok(array)
+}
+
@@ -466,10 +746,8 @@
 #[cfg(test)]
 mod tests {
-    use super::*;
+    use std::sync::Arc;
+
+    use rand::distributions::uniform::SampleUniform;
+    use rand::distributions::{Distribution, Standard};
+    use rand::{thread_rng, Rng};
+
+    use arrow_array::NullArray;
+
     use crate::array::{
         BinaryArray, BooleanArray, DictionaryArray, Float32Array, GenericStringArray,
         Int16Array, Int32Array, OffsetSizeTrait, PrimitiveArray,
     };
     use crate::compute::{LexicographicalComparator, SortColumn};
     use crate::util::display::array_value_to_string;
-    use rand::distributions::uniform::SampleUniform;
-    use rand::distributions::{Distribution, Standard};
-    use rand::{thread_rng, Rng};
-    use std::sync::Arc;
+
+    use super::*;
 
     #[test]
     fn test_fixed_width() {
@@ -525,19 +803,29 @@ mod tests {
         assert!(rows.row(0) < rows.row(1));
         assert!(rows.row(3) < rows.row(0));
         assert!(rows.row(4) < rows.row(1));
-        assert!(rows.row(5) < rows.row(4))
+        assert!(rows.row(5) < rows.row(4));
+
+        let back = converter.convert_rows(&rows).unwrap();
+        for (expected, actual) in cols.iter().zip(&back) {
+            assert_eq!(expected, actual);
+        }
     }
 
     #[test]
     fn test_bool() {
         let mut converter = RowConverter::new(vec![SortField::new(DataType::Boolean)]);
 
-        let col = Arc::new(BooleanArray::from_iter([None, Some(false), Some(true)]));
-        let rows = converter.convert_columns(&[col]).unwrap();
+        let col = Arc::new(BooleanArray::from_iter([None, Some(false), Some(true)]))
+            as ArrayRef;
+
+        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
         assert!(rows.row(2) > rows.row(1));
         assert!(rows.row(2) > rows.row(0));
         assert!(rows.row(1) > rows.row(0));
 
+        let cols = converter.convert_rows(&rows).unwrap();
+        assert_eq!(&cols[0], &col);
+
         let mut converter = RowConverter::new(vec![SortField::new_with_options(
             DataType::Boolean,
             SortOptions {
                 descending: true,
                 nulls_first: false,
             },
         )]);
 
-        let col = Arc::new(BooleanArray::from_iter([None, Some(false), Some(true)]));
-        let rows = converter.convert_columns(&[col]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
         assert!(rows.row(2) < rows.row(1));
         assert!(rows.row(2) < rows.row(0));
         assert!(rows.row(1) < rows.row(0));
+
+        let cols = converter.convert_rows(&rows).unwrap();
+        assert_eq!(&cols[0], &col);
+    }
+
+    #[test]
+    fn test_null_encoding() {
+        let col = Arc::new(NullArray::new(10));
+        let mut converter = RowConverter::new(vec![SortField::new(DataType::Null)]);
+        let rows = converter.convert_columns(&[col]).unwrap();
+        assert_eq!(rows.num_rows(), 10);
+        assert_eq!(rows.row(1).data.len(), 0);
     }
 
     #[test]
@@ -561,16 +859,19 @@ mod tests {
             None,
             Some("foo"),
             Some(""),
-        ]));
+        ])) as ArrayRef;
 
         let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
-        let rows = converter.convert_columns(&[col]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
 
         assert!(rows.row(1) < rows.row(0));
         assert!(rows.row(2) < rows.row(4));
         assert!(rows.row(3) < rows.row(0));
         assert!(rows.row(3) < rows.row(1));
 
+        let cols = converter.convert_rows(&rows).unwrap();
+        assert_eq!(&cols[0], &col);
+
         let col = Arc::new(BinaryArray::from_iter([
             None,
             Some(vec![0_u8; 0]),
@@ -601,6 +902,9 @@ mod tests {
             }
         }
 
+        let cols = converter.convert_rows(&rows).unwrap();
+        assert_eq!(&cols[0], &col);
+
         let mut converter = RowConverter::new(vec![SortField::new_with_options(
             DataType::Binary,
             SortOptions {
                 descending: true,
                 nulls_first: false,
             },
         )]);
-        let rows = converter.convert_columns(&[col]).unwrap();
+        let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
 
         for i in 0..rows.num_rows() {
             for j in i + 1..rows.num_rows() {
                 assert!(
                     rows.row(i) > rows.row(j),
                     "{} > {} - {:?} > {:?}",
                     i,
                     j,
                     rows.row(i),
                     rows.row(j)
                 );
             }
         }
+
+        let cols = converter.convert_rows(&rows).unwrap();
+        assert_eq!(&cols[0], &col);
     }
 
     #[test]
@@ -650,17 +957,23 @@ mod tests {
         assert_eq!(rows_a.row(1), rows_a.row(6));
         assert_eq!(rows_a.row(1), rows_a.row(7));
 
+        let cols = converter.convert_rows(&rows_a).unwrap();
+        assert_eq!(&cols[0], &a);
+
         let b = Arc::new(DictionaryArray::<Int32Type>::from_iter([
             Some("hello"),
             None,
             Some("cupcakes"),
-        ]));
+        ])) as ArrayRef;
 
-        let rows_b = converter.convert_columns(&[b]).unwrap();
+        let rows_b = converter.convert_columns(&[Arc::clone(&b)]).unwrap();
         assert_eq!(rows_a.row(1), rows_b.row(0));
         assert_eq!(rows_a.row(3), rows_b.row(1));
         assert!(rows_b.row(2) < rows_a.row(0));
 
+        let cols = converter.convert_rows(&rows_b).unwrap();
+        assert_eq!(&cols[0], &b);
+
         let mut converter = RowConverter::new(vec![SortField::new_with_options(
             a.data_type().clone(),
             SortOptions {
                 descending: true,
                 nulls_first: false,
             },
         )]);
 
-        let rows_c = converter.convert_columns(&[a]).unwrap();
+        let rows_c = converter.convert_columns(&[Arc::clone(&a)]).unwrap();
         assert!(rows_c.row(3) > rows_c.row(5));
         assert!(rows_c.row(2) > rows_c.row(1));
         assert!(rows_c.row(0) > rows_c.row(1));
         assert!(rows_c.row(3) > rows_c.row(0));
+
+        let cols = converter.convert_rows(&rows_c).unwrap();
+        assert_eq!(&cols[0], &a);
     }
 
     #[test]
@@ -727,6 +1043,17 @@ mod tests {
         assert!(rows.row(3) < rows.row(0));
     }
 
+    #[test]
+    #[should_panic(expected = "rows were not produced by this RowConverter")]
+    fn test_different_converter() {
+        let values = Arc::new(Int32Array::from_iter([Some(1), Some(-1)]));
+        let mut converter = RowConverter::new(vec![SortField::new(DataType::Int32)]);
+        let rows = converter.convert_columns(&[values]).unwrap();
+
+        let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]);
+        let _ = converter.convert_rows(&rows);
+    }
+
     fn generate_primitive_array<K>(len: usize, valid_percent: f64) -> PrimitiveArray<K>
     where
         K: ArrowPrimitiveType,
@@ -888,6 +1215,11 @@ mod tests {
                     );
                 }
             }
+
+            let back = converter.convert_rows(&rows).unwrap();
+            for (actual, expected) in back.iter().zip(&arrays) {
+                assert_eq!(actual, expected)
+            }
         }
     }
 }
diff --git a/arrow/src/row/variable.rs b/arrow/src/row/variable.rs
index 2213dad9e78..36f337e658b 100644
--- a/arrow/src/row/variable.rs
+++ b/arrow/src/row/variable.rs
@@ -16,12 +16,26 @@
 // under the License.
 
 use crate::compute::SortOptions;
-use crate::row::Rows;
+use crate::row::{null_sentinel, Rows};
 use crate::util::bit_util::ceil;
+use arrow_array::builder::BufferBuilder;
+use arrow_array::{Array, GenericBinaryArray, GenericStringArray, OffsetSizeTrait};
+use arrow_buffer::MutableBuffer;
+use arrow_data::ArrayDataBuilder;
+use arrow_schema::DataType;
 
 /// The block size of the variable length encoding
 pub const BLOCK_SIZE: usize = 32;
 
+/// The continuation token
+pub const BLOCK_CONTINUATION: u8 = 0xFF;
+
+/// Indicates an empty string
+pub const EMPTY_SENTINEL: u8 = 1;
+
+/// Indicates a non-empty string
+pub const NON_EMPTY_SENTINEL: u8 = 2;
+
 /// Returns the length of the encoded representation of a byte array, including the null byte
 pub fn encoded_len(a: Option<&[u8]>) -> usize {
     match a {
@@ -50,8 +64,8 @@ pub fn encode<'a, I: Iterator<Item = Option<&'a [u8]>>>(
         match maybe_val {
             Some(val) if val.is_empty() => {
                 out.buffer[*offset] = match opts.descending {
-                    true => !1,
-                    false => 1,
+                    true => !EMPTY_SENTINEL,
+                    false => EMPTY_SENTINEL,
                 };
                 *offset += 1;
             }
@@ -61,7 +75,7 @@ pub fn encode<'a, I: Iterator<Item = Option<&'a [u8]>>>(
                 let to_write = &mut out.buffer[*offset..end_offset];
 
                 // Write `2_u8` to demarcate as non-empty, non-null string
-                to_write[0] = 2;
+                to_write[0] = NON_EMPTY_SENTINEL;
 
                 let chunks = val.chunks_exact(BLOCK_SIZE);
                 let remainder = chunks.remainder();
@@ -76,7 +90,7 @@ pub fn encode<'a, I: Iterator<Item = Option<&'a [u8]>>>(
                         *out_block = *input;
 
                         // Indicate that there are further blocks to follow
-                        output[BLOCK_SIZE] = u8::MAX;
+                        output[BLOCK_SIZE] = BLOCK_CONTINUATION;
                     }
 
                 if !remainder.is_empty() {
@@ -97,11 +111,121 @@ pub fn encode<'a, I: Iterator<Item = Option<&'a [u8]>>>(
             }
             None => {
-                if !opts.nulls_first {
-                    out.buffer[*offset] = 0xFF;
-                }
+                out.buffer[*offset] = null_sentinel(opts);
                 *offset += 1;
             }
         }
     }
 }
+
+/// Returns the number of bytes of encoded data
+fn decoded_len(row: &[u8], options: SortOptions) -> usize {
+    let (non_empty_sentinel, continuation) = match options.descending {
+        true => (!NON_EMPTY_SENTINEL, !BLOCK_CONTINUATION),
+        false => (NON_EMPTY_SENTINEL, BLOCK_CONTINUATION),
+    };
+
+    if row[0] != non_empty_sentinel {
+        // Empty or null string
+        return 0;
+    }
+
+    let mut str_len = 0;
+    let mut idx = 1;
+    loop {
+        let sentinel = row[idx + BLOCK_SIZE];
+        if sentinel == continuation {
+            idx += BLOCK_SIZE + 1;
+            str_len += BLOCK_SIZE;
+            continue;
+        }
+        let block_len = match options.descending {
+            true => !sentinel,
+            false => sentinel,
+        };
+        return str_len + block_len as usize;
+    }
+}
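Editor's note: a dependency-free sketch of the block scheme that `encode` and `decoded_len` implement, for the ascending, non-empty case only (an illustration, not part of the patch; the handling of an exactly block-sized final chunk is my reading of the scheme, not copied from the source).

```rust
const BLOCK_SIZE: usize = 32;
const NON_EMPTY_SENTINEL: u8 = 2;
const BLOCK_CONTINUATION: u8 = 0xFF;

/// Encode a non-empty byte string: a leading sentinel, then 32-byte
/// zero-padded blocks, each followed by either the continuation token
/// or the number of meaningful bytes in that final block
fn encode_var(val: &[u8]) -> Vec<u8> {
    let mut out = vec![NON_EMPTY_SENTINEL];
    let mut chunks = val.chunks(BLOCK_SIZE).peekable();
    while let Some(chunk) = chunks.next() {
        let mut block = [0u8; BLOCK_SIZE];
        block[..chunk.len()].copy_from_slice(chunk);
        out.extend_from_slice(&block);
        out.push(match chunks.peek() {
            Some(_) => BLOCK_CONTINUATION,
            None => chunk.len() as u8,
        });
    }
    out
}

fn main() {
    // Zero padding plus a length byte keeps prefixes ordered before extensions
    assert!(encode_var(b"he") < encode_var(b"hello"));
    assert!(encode_var(b"hello") < encode_var(b"hf"));
}
```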
+
+/// Decodes a binary array from `rows` with the provided `options`
+pub fn decode_binary<I: OffsetSizeTrait>(
+    rows: &mut [&[u8]],
+    options: SortOptions,
+) -> GenericBinaryArray<I> {
+    let len = rows.len();
+    let mut null_count = 0;
+    let nulls = MutableBuffer::collect_bool(len, |x| {
+        let valid = rows[x][0] != null_sentinel(options);
+        null_count += !valid as usize;
+        valid
+    });
+
+    let values_capacity = rows.iter().map(|row| decoded_len(row, options)).sum();
+    let mut offsets = BufferBuilder::<I>::new(len + 1);
+    offsets.append(I::zero());
+    let mut values = MutableBuffer::new(values_capacity);
+
+    for row in rows {
+        let str_length = decoded_len(row, options);
+        let mut to_read = str_length;
+        let mut offset = 1;
+        while to_read >= BLOCK_SIZE {
+            to_read -= BLOCK_SIZE;
+
+            values.extend_from_slice(&row[offset..offset + BLOCK_SIZE]);
+            offset += BLOCK_SIZE + 1;
+        }
+
+        if to_read != 0 {
+            values.extend_from_slice(&row[offset..offset + to_read]);
+            offset += BLOCK_SIZE + 1;
+        }
+        *row = &row[offset..];
+
+        offsets.append(I::from_usize(values.len()).expect("offset overflow"))
+    }
+
+    if options.descending {
+        values.as_slice_mut().iter_mut().for_each(|o| *o = !*o)
+    }
+
+    let d = match I::IS_LARGE {
+        true => DataType::LargeBinary,
+        false => DataType::Binary,
+    };
+
+    let builder = ArrayDataBuilder::new(d)
+        .len(len)
+        .null_count(null_count)
+        .null_bit_buffer(Some(nulls.into()))
+        .add_buffer(offsets.finish())
+        .add_buffer(values.into());
+
+    // SAFETY:
+    // Valid by construction above
+    unsafe { GenericBinaryArray::from(builder.build_unchecked()) }
+}
+
+/// Decodes a string array from `rows` with the provided `options`
+///
+/// # Safety
+///
+/// The row must contain valid UTF-8 data
+pub unsafe fn decode_string<I: OffsetSizeTrait>(
+    rows: &mut [&[u8]],
+    options: SortOptions,
+) -> GenericStringArray<I> {
+    let d = match I::IS_LARGE {
+        true => DataType::LargeUtf8,
+        false => DataType::Utf8,
+    };
+
+    let builder = decode_binary::<I>(rows, options)
+        .into_data()
+        .into_builder()
+        .data_type(d);
+
+    // SAFETY:
+    // Row data must have come from a valid UTF-8 array
+    GenericStringArray::from(builder.build_unchecked())
+}