Separate ArrayReader::next_batch with read_records and consume_batch (#2237)

* replace ArrayReader::next_batch with ArrayReader::read_records and ArrayReader::consume_batch.

* fix ut

* fix comment

* avoid clone.

* fix new ut

* fix comment

Co-authored-by: Raphael Taylor-Davies <r.taylordavies@googlemail.com>
Ted-Jiang and tustvold committed Aug 3, 2022
1 parent 6b2c757 commit 1f9973c
Showing 13 changed files with 233 additions and 67 deletions.
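
At a glance, the change splits the single-step next_batch into a fill step (read_records) and a drain step (consume_batch), so a caller can buffer several reads before materializing one Arrow array. The following self-contained Rust sketch mirrors the new signatures from the hunks below; the trait name and the two type aliases are illustrative stand-ins, not the crate's real definitions:

    use std::sync::Arc;

    // Hypothetical stand-ins so the sketch compiles on its own; the real crate
    // uses arrow::array::ArrayRef and parquet::errors::Result.
    type ArrayRef = Arc<dyn std::any::Any>;
    type Result<T> = std::result::Result<T, String>;

    // Sketch of the split API (signatures mirror the diff below).
    trait ArrayReaderSketch {
        /// Step 1: buffer up to `batch_size` records; returns how many were read.
        /// Returning fewer than `batch_size` means the column is exhausted.
        fn read_records(&mut self, batch_size: usize) -> Result<usize>;

        /// Step 2: convert everything buffered since the last call into an
        /// array and reset the internal buffers.
        fn consume_batch(&mut self) -> Result<ArrayRef>;

        /// The old single-step call is recovered by running the two in sequence.
        fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
            self.read_records(batch_size)?;
            self.consume_batch()
        }
    }

Separating the two steps lets a reader accumulate records across multiple read_records calls (for example, several row selections) and pay the array-conversion cost once.
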
7 changes: 5 additions & 2 deletions parquet/src/arrow/array_reader/byte_array.rs
@@ -108,8 +108,11 @@ impl<I: OffsetSizeTrait + ScalarValue> ArrayReader for ByteArrayReader<I> {
&self.data_type
}

fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
fn read_records(&mut self, batch_size: usize) -> Result<usize> {
read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
}

fn consume_batch(&mut self) -> Result<ArrayRef> {
let buffer = self.record_reader.consume_record_data();
let null_buffer = self.record_reader.consume_bitmap_buffer();
self.def_levels_buffer = self.record_reader.consume_def_levels();
9 changes: 6 additions & 3 deletions parquet/src/arrow/array_reader/byte_array_dictionary.rs
@@ -25,7 +25,7 @@ use arrow::buffer::Buffer;
use arrow::datatypes::{ArrowNativeType, DataType as ArrowType};

use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain};
use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
use crate::arrow::buffer::{
dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer,
};
@@ -167,8 +167,11 @@ where
&self.data_type
}

fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
fn read_records(&mut self, batch_size: usize) -> Result<usize> {
read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
}

fn consume_batch(&mut self) -> Result<ArrayRef> {
let buffer = self.record_reader.consume_record_data();
let null_buffer = self.record_reader.consume_bitmap_buffer();
let array = buffer.into_array(null_buffer, &self.data_type)?;
133 changes: 94 additions & 39 deletions parquet/src/arrow/array_reader/complex_object_array.rs
@@ -39,9 +39,13 @@ where
pages: Box<dyn PageIterator>,
def_levels_buffer: Option<Vec<i16>>,
rep_levels_buffer: Option<Vec<i16>>,
data_buffer: Vec<T::T>,
column_desc: ColumnDescPtr,
column_reader: Option<ColumnReaderImpl<T>>,
converter: C,
in_progress_def_levels_buffer: Option<Vec<i16>>,
in_progress_rep_levels_buffer: Option<Vec<i16>>,
before_consume: bool,
_parquet_type_marker: PhantomData<T>,
_converter_marker: PhantomData<C>,
}
@@ -59,7 +63,10 @@ where
&self.data_type
}

fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
fn read_records(&mut self, batch_size: usize) -> Result<usize> {
if !self.before_consume {
self.before_consume = true;
}
// Try to initialize column reader
if self.column_reader.is_none() {
self.next_column_reader()?;
@@ -126,7 +133,6 @@ where
break;
}
}

data_buffer.truncate(num_read);
def_levels_buffer
.iter_mut()
@@ -135,13 +141,35 @@
.iter_mut()
.for_each(|buf| buf.truncate(num_read));

self.def_levels_buffer = def_levels_buffer;
self.rep_levels_buffer = rep_levels_buffer;
if let Some(mut def_levels_buffer) = def_levels_buffer {
match &mut self.in_progress_def_levels_buffer {
None => {
self.in_progress_def_levels_buffer = Some(def_levels_buffer);
}
Some(buf) => buf.append(&mut def_levels_buffer),
}
}

if let Some(mut rep_levels_buffer) = rep_levels_buffer {
match &mut self.in_progress_rep_levels_buffer {
None => {
self.in_progress_rep_levels_buffer = Some(rep_levels_buffer);
}
Some(buf) => buf.append(&mut rep_levels_buffer),
}
}

self.data_buffer.append(&mut data_buffer);

Ok(num_read)
}

let data: Vec<Option<T::T>> = if self.def_levels_buffer.is_some() {
fn consume_batch(&mut self) -> Result<ArrayRef> {
let data: Vec<Option<T::T>> = if self.in_progress_def_levels_buffer.is_some() {
let data_buffer = std::mem::take(&mut self.data_buffer);
data_buffer
.into_iter()
.zip(self.def_levels_buffer.as_ref().unwrap().iter())
.zip(self.in_progress_def_levels_buffer.as_ref().unwrap().iter())
.map(|(t, def_level)| {
if *def_level == self.column_desc.max_def_level() {
Some(t)
@@ -151,7 +179,7 @@ where
})
.collect()
} else {
data_buffer.into_iter().map(Some).collect()
self.data_buffer.iter().map(|x| Some(x.clone())).collect()
};

let mut array = self.converter.convert(data)?;
@@ -160,6 +188,11 @@ where
array = arrow::compute::cast(&array, &self.data_type)?;
}

self.data_buffer = vec![];
self.def_levels_buffer = std::mem::take(&mut self.in_progress_def_levels_buffer);
self.rep_levels_buffer = std::mem::take(&mut self.in_progress_rep_levels_buffer);
self.before_consume = false;

Ok(array)
}

@@ -168,20 +201,31 @@ where
Some(reader) => reader.skip_records(num_records),
None => {
if self.next_column_reader()? {
self.column_reader.as_mut().unwrap().skip_records(num_records)
}else {
self.column_reader
.as_mut()
.unwrap()
.skip_records(num_records)
} else {
Ok(0)
}
}
}
}

fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer.as_deref()
if self.before_consume {
self.in_progress_def_levels_buffer.as_deref()
} else {
self.def_levels_buffer.as_deref()
}
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer.as_deref()
if self.before_consume {
self.in_progress_rep_levels_buffer.as_deref()
} else {
self.rep_levels_buffer.as_deref()
}
}
}

@@ -208,9 +252,13 @@ where
pages,
def_levels_buffer: None,
rep_levels_buffer: None,
data_buffer: vec![],
column_desc,
column_reader: None,
converter,
in_progress_def_levels_buffer: None,
in_progress_rep_levels_buffer: None,
before_consume: true,
_parquet_type_marker: PhantomData,
_converter_marker: PhantomData,
})
@@ -349,30 +397,32 @@ mod tests {

let mut accu_len: usize = 0;

let array = array_reader.next_batch(values_per_page / 2).unwrap();
assert_eq!(array.len(), values_per_page / 2);
let len = array_reader.read_records(values_per_page / 2).unwrap();
assert_eq!(len, values_per_page / 2);
assert_eq!(
Some(&def_levels[accu_len..(accu_len + array.len())]),
Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
Some(&rep_levels[accu_len..(accu_len + array.len())]),
Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
accu_len += array.len();
accu_len += len;
array_reader.consume_batch().unwrap();

// Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk,
// and the last values_per_page/2 ones are from the second column chunk
let array = array_reader.next_batch(values_per_page).unwrap();
assert_eq!(array.len(), values_per_page);
let len = array_reader.read_records(values_per_page).unwrap();
assert_eq!(len, values_per_page);
assert_eq!(
Some(&def_levels[accu_len..(accu_len + array.len())]),
Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
Some(&rep_levels[accu_len..(accu_len + array.len())]),
Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
let array = array_reader.consume_batch().unwrap();
let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
for i in 0..array.len() {
if array.is_valid(i) {
Expand All @@ -384,19 +434,20 @@ mod tests {
assert_eq!(all_values[i + accu_len], None)
}
}
accu_len += array.len();
accu_len += len;

// Try to read values_per_page values, however there are only values_per_page/2 values
let array = array_reader.next_batch(values_per_page).unwrap();
assert_eq!(array.len(), values_per_page / 2);
let len = array_reader.read_records(values_per_page).unwrap();
assert_eq!(len, values_per_page / 2);
assert_eq!(
Some(&def_levels[accu_len..(accu_len + array.len())]),
Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
Some(&rep_levels[accu_len..(accu_len + array.len())]),
Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
array_reader.consume_batch().unwrap();
}

#[test]
@@ -491,31 +542,34 @@ mod tests {
let mut accu_len: usize = 0;

// println!("---------- reading a batch of {} values ----------", values_per_page / 2);
let array = array_reader.next_batch(values_per_page / 2).unwrap();
assert_eq!(array.len(), values_per_page / 2);
let len = array_reader.read_records(values_per_page / 2).unwrap();
assert_eq!(len, values_per_page / 2);
assert_eq!(
Some(&def_levels[accu_len..(accu_len + array.len())]),
Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
Some(&rep_levels[accu_len..(accu_len + array.len())]),
Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
accu_len += array.len();
accu_len += len;
array_reader.consume_batch().unwrap();

// Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk,
// and the last values_per_page/2 ones are from the second column chunk
// println!("---------- reading a batch of {} values ----------", values_per_page);
let array = array_reader.next_batch(values_per_page).unwrap();
assert_eq!(array.len(), values_per_page);
//let array = array_reader.next_batch(values_per_page).unwrap();
let len = array_reader.read_records(values_per_page).unwrap();
assert_eq!(len, values_per_page);
assert_eq!(
Some(&def_levels[accu_len..(accu_len + array.len())]),
Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
Some(&rep_levels[accu_len..(accu_len + array.len())]),
Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
let array = array_reader.consume_batch().unwrap();
let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
for i in 0..array.len() {
if array.is_valid(i) {
Expand All @@ -527,19 +581,20 @@ mod tests {
assert_eq!(all_values[i + accu_len], None)
}
}
accu_len += array.len();
accu_len += len;

// Try to read values_per_page values, however there are only values_per_page/2 values
// println!("---------- reading a batch of {} values ----------", values_per_page);
let array = array_reader.next_batch(values_per_page).unwrap();
assert_eq!(array.len(), values_per_page / 2);
let len = array_reader.read_records(values_per_page).unwrap();
assert_eq!(len, values_per_page / 2);
assert_eq!(
Some(&def_levels[accu_len..(accu_len + array.len())]),
Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
Some(&rep_levels[accu_len..(accu_len + array.len())]),
Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
array_reader.consume_batch().unwrap();
}
}
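
The subtle part of this file is the two-state level reporting: records accumulate in the in_progress_* buffers between read_records calls, and get_def_levels/get_rep_levels must expose the in-progress data before a consume but the previously consumed data after it. A compilable sketch of that bookkeeping, with hypothetical names and only the def-level side shown:

    // Minimal sketch (hypothetical names) of the two-phase level reporting in
    // ComplexObjectArrayReader: levels accumulate while reading and are moved
    // to the "consumed" buffer when the batch is materialized.
    struct LevelState {
        def_levels: Option<Vec<i16>>,             // levels of the last consumed batch
        in_progress_def_levels: Option<Vec<i16>>, // levels buffered since last consume
        before_consume: bool, // true between read_records and consume_batch
    }

    impl LevelState {
        fn push_levels(&mut self, mut new_levels: Vec<i16>) {
            self.before_consume = true;
            match &mut self.in_progress_def_levels {
                None => self.in_progress_def_levels = Some(new_levels),
                Some(buf) => buf.append(&mut new_levels),
            }
        }

        fn consume(&mut self) {
            // Move the in-progress levels into the consumed slot and reset.
            self.def_levels = std::mem::take(&mut self.in_progress_def_levels);
            self.before_consume = false;
        }

        fn get_def_levels(&self) -> Option<&[i16]> {
            if self.before_consume {
                self.in_progress_def_levels.as_deref()
            } else {
                self.def_levels.as_deref()
            }
        }
    }
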
12 changes: 9 additions & 3 deletions parquet/src/arrow/array_reader/empty_array.rs
@@ -33,13 +33,15 @@ pub fn make_empty_array_reader(row_count: usize) -> Box<dyn ArrayReader> {
struct EmptyArrayReader {
data_type: ArrowType,
remaining_rows: usize,
need_consume_records: usize,
}

impl EmptyArrayReader {
pub fn new(row_count: usize) -> Self {
Self {
data_type: ArrowType::Struct(vec![]),
remaining_rows: row_count,
need_consume_records: 0,
}
}
}
@@ -53,15 +55,19 @@ impl ArrayReader for EmptyArrayReader {
&self.data_type
}

fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
fn read_records(&mut self, batch_size: usize) -> Result<usize> {
let len = self.remaining_rows.min(batch_size);
self.remaining_rows -= len;
self.need_consume_records += len;
Ok(len)
}

fn consume_batch(&mut self) -> Result<ArrayRef> {
let data = ArrayDataBuilder::new(self.data_type.clone())
.len(len)
.len(self.need_consume_records)
.build()
.unwrap();

self.need_consume_records = 0;
Ok(Arc::new(StructArray::from(data)))
}

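
EmptyArrayReader is the smallest illustration of the fill/drain contract: read_records only tallies rows, and consume_batch drains the tally. A runnable sketch of the same counter pattern, with hypothetical names:

    // Tally-then-drain pattern mirroring EmptyArrayReader (names hypothetical).
    struct CountingReader {
        remaining_rows: usize,
        need_consume_records: usize,
    }

    impl CountingReader {
        fn read_records(&mut self, batch_size: usize) -> usize {
            // Read at most what is left, and remember it for the next consume.
            let len = self.remaining_rows.min(batch_size);
            self.remaining_rows -= len;
            self.need_consume_records += len;
            len
        }

        fn consume_batch(&mut self) -> usize {
            // Drain everything buffered since the last consume.
            std::mem::take(&mut self.need_consume_records)
        }
    }

    fn main() {
        let mut r = CountingReader { remaining_rows: 10, need_consume_records: 0 };
        assert_eq!(r.read_records(4), 4);
        assert_eq!(r.read_records(4), 4);
        // Two fills are drained by a single consume.
        assert_eq!(r.consume_batch(), 8);
        assert_eq!(r.read_records(4), 2); // only 2 rows remained
        assert_eq!(r.consume_batch(), 2);
    }
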
8 changes: 6 additions & 2 deletions parquet/src/arrow/array_reader/list_array.rs
@@ -78,9 +78,13 @@ impl<OffsetSize: OffsetSizeTrait> ArrayReader for ListArrayReader<OffsetSize> {
&self.data_type
}

fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
let next_batch_array = self.item_reader.next_batch(batch_size)?;
fn read_records(&mut self, batch_size: usize) -> Result<usize> {
let size = self.item_reader.read_records(batch_size)?;
Ok(size)
}

fn consume_batch(&mut self) -> Result<ArrayRef> {
let next_batch_array = self.item_reader.consume_batch()?;
if next_batch_array.len() == 0 {
return Ok(new_empty_array(&self.data_type));
}
18 changes: 15 additions & 3 deletions parquet/src/arrow/array_reader/map_array.rs
@@ -62,9 +62,21 @@ impl ArrayReader for MapArrayReader {
&self.data_type
}

fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
let key_array = self.key_reader.next_batch(batch_size)?;
let value_array = self.value_reader.next_batch(batch_size)?;
fn read_records(&mut self, batch_size: usize) -> Result<usize> {
let key_len = self.key_reader.read_records(batch_size)?;
let value_len = self.value_reader.read_records(batch_size)?;
// Check that key and value have the same lengths
if key_len != value_len {
return Err(general_err!(
"Map key and value should have the same lengths."
));
}
Ok(key_len)
}

fn consume_batch(&mut self) -> Result<ArrayRef> {
let key_array = self.key_reader.consume_batch()?;
let value_array = self.value_reader.consume_batch()?;

// Check that key and value have the same lengths
let key_length = key_array.len();
