diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs
index a29888f70e4..ec4188890ef 100644
--- a/parquet/src/arrow/array_reader/byte_array.rs
+++ b/parquet/src/arrow/array_reader/byte_array.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::arrow::array_reader::{read_records, ArrayReader, set_column_reader};
+use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
 use crate::arrow::buffer::offset_buffer::OffsetBuffer;
 use crate::arrow::record_reader::buffer::ScalarValue;
 use crate::arrow::record_reader::GenericRecordReader;
@@ -120,8 +120,7 @@ impl<I: OffsetSizeTrait + ScalarValue> ArrayReader for ByteArrayReader<I> {
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
-        set_column_reader(&mut self.record_reader, self.pages.as_mut())?;
-        self.record_reader.skip_records(num_records)
+        skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
index eba9e578f55..51ef38d0d07 100644
--- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs
+++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
@@ -25,7 +25,7 @@ use arrow::buffer::Buffer;
 use arrow::datatypes::{ArrowNativeType, DataType as ArrowType};
 
 use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain};
-use crate::arrow::array_reader::{read_records, ArrayReader, set_column_reader};
+use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
 use crate::arrow::buffer::{
     dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer,
 };
@@ -181,8 +181,7 @@ where
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
-        set_column_reader(&mut self.record_reader, self.pages.as_mut())?;
-        self.record_reader.skip_records(num_records)
+        skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs
index a9d8cc0faa6..8bdd6c071c3 100644
--- a/parquet/src/arrow/array_reader/mod.rs
+++ b/parquet/src/arrow/array_reader/mod.rs
@@ -113,7 +113,7 @@ impl RowGroupCollection for Arc<dyn FileReader> {
 /// Uses `record_reader` to read up to `batch_size` records from `pages`
 ///
-/// Returns the number of records read, which can be less than batch_size if
+/// Returns the number of records read, which can be less than `batch_size` if
 /// pages is exhausted.
 fn read_records<V, CV>(
     record_reader: &mut GenericRecordReader<V, CV>,
     pages: &mut dyn PageIterator,
@@ -145,29 +145,36 @@ where
     Ok(records_read)
 }
 
-/// Uses `pages` to set up to `record_reader` 's `column_reader`
+/// Uses `record_reader` to skip up to `batch_size` records from `pages`
 ///
-/// If we skip records before all read operation,
-/// need set `column_reader` by `set_page_reader`
-/// for constructing `def_level_decoder` and `rep_level_decoder`.
-fn set_column_reader<V, CV>(
+/// Returns the number of records skipped, which can be less than `batch_size` if
+/// pages is exhausted.
+fn skip_records<V, CV>(
     record_reader: &mut GenericRecordReader<V, CV>,
     pages: &mut dyn PageIterator,
-) -> Result<bool>
+    batch_size: usize,
+) -> Result<usize>
 where
     V: ValuesBuffer + Default,
     CV: ColumnValueDecoder<Slice = V::Slice>,
 {
-    return if record_reader.column_reader().is_none() {
-        // If we skip records before all read operation
-        // we need set `column_reader` by `set_page_reader`
-        if let Some(page_reader) = pages.next() {
-            record_reader.set_page_reader(page_reader?)?;
-            Ok(true)
-        } else {
-            Ok(false)
+    let mut records_skipped = 0usize;
+    while records_skipped < batch_size {
+        let records_to_read = batch_size - records_skipped;
+
+        let records_skipped_once = record_reader.skip_records(records_to_read)?;
+        records_skipped += records_skipped_once;
+
+        // Record reader exhausted
+        if records_skipped_once < records_to_read {
+            if let Some(page_reader) = pages.next() {
+                // Read from new page reader (i.e. column chunk)
+                record_reader.set_page_reader(page_reader?)?;
+            } else {
+                // Page reader also exhausted
+                break;
+            }
         }
-    } else {
-        Ok(true)
-    };
+    }
+    Ok(records_skipped)
 }
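For illustration (outside the patch itself): a single `RecordReader::skip_records` call can only consume records from the currently loaded column chunk, so a skip that spans chunk boundaries must install the next page reader and continue, exactly as `read_records` does. Below is a minimal, self-contained sketch of that loop; `MockRecordReader` and the plain chunk-length iterator are invented stand-ins for the crate's `GenericRecordReader` and `PageIterator`.

```rust
/// Invented stand-in for `GenericRecordReader`: tracks how many records
/// remain in the currently loaded column chunk.
struct MockRecordReader {
    remaining: usize,
}

impl MockRecordReader {
    /// Skip up to `n` records from the current chunk, returning how many
    /// were actually skipped (fewer if the chunk runs out first).
    fn skip_records(&mut self, n: usize) -> usize {
        let skipped = n.min(self.remaining);
        self.remaining -= skipped;
        skipped
    }

    /// Stand-in for `set_page_reader`: start consuming the next chunk.
    fn set_page_reader(&mut self, chunk_len: usize) {
        self.remaining = chunk_len;
    }
}

/// Mirrors the loop structure of the new `skip_records` free function:
/// returns the number of records skipped, which can be less than
/// `batch_size` if `pages` is exhausted.
fn skip_records(
    record_reader: &mut MockRecordReader,
    pages: &mut impl Iterator<Item = usize>,
    batch_size: usize,
) -> usize {
    let mut records_skipped = 0;
    while records_skipped < batch_size {
        let records_to_read = batch_size - records_skipped;
        let skipped_once = record_reader.skip_records(records_to_read);
        records_skipped += skipped_once;

        // Record reader exhausted: advance to the next column chunk, if any
        if skipped_once < records_to_read {
            match pages.next() {
                Some(chunk_len) => record_reader.set_page_reader(chunk_len),
                None => break, // pages exhausted too
            }
        }
    }
    records_skipped
}

fn main() {
    // Three "column chunks" of 100 records each
    let mut pages = [100usize, 100, 100].into_iter();
    let mut reader = MockRecordReader { remaining: 0 };

    // A skip of 250 crosses two chunk boundaries and still completes in full
    assert_eq!(skip_records(&mut reader, &mut pages, 250), 250);
    // Only 50 records remain in total, so a further skip of 100 returns 50
    assert_eq!(skip_records(&mut reader, &mut pages, 100), 50);
}
```

The replaced `set_column_reader` helper only guaranteed that a first page reader was installed before delegating to a single `RecordReader::skip_records` call, so a skip larger than the current column chunk could come up short; the loop keeps going until either the requested count is skipped or `pages` runs out.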
diff --git a/parquet/src/arrow/array_reader/null_array.rs b/parquet/src/arrow/array_reader/null_array.rs
index a8c50b87f7e..63f73d41e4f 100644
--- a/parquet/src/arrow/array_reader/null_array.rs
+++ b/parquet/src/arrow/array_reader/null_array.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::arrow::array_reader::{read_records, ArrayReader, set_column_reader};
+use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
 use crate::arrow::record_reader::buffer::ScalarValue;
 use crate::arrow::record_reader::RecordReader;
 use crate::column::page::PageIterator;
@@ -97,8 +97,7 @@ where
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
-        set_column_reader(&mut self.record_reader, self.pages.as_mut())?;
-        self.record_reader.skip_records(num_records)
+        skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs
index 700b12b0a0b..2a59f0326d3 100644
--- a/parquet/src/arrow/array_reader/primitive_array.rs
+++ b/parquet/src/arrow/array_reader/primitive_array.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::arrow::array_reader::{read_records, set_column_reader, ArrayReader};
+use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
 use crate::arrow::record_reader::buffer::ScalarValue;
 use crate::arrow::record_reader::RecordReader;
 use crate::arrow::schema::parquet_to_arrow_field;
@@ -222,8 +222,7 @@ where
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
-        set_column_reader(&mut self.record_reader, self.pages.as_mut())?;
-        self.record_reader.skip_records(num_records)
+        skip_records(&mut self.record_reader, self.pages.as_mut(), num_records)
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs
index 8b4641d6814..5a864fd73c0 100644
--- a/parquet/src/arrow/arrow_reader.rs
+++ b/parquet/src/arrow/arrow_reader.rs
@@ -384,6 +384,7 @@ impl ParquetRecordBatchReader {
 mod tests {
     use bytes::Bytes;
     use std::cmp::min;
+    use std::collections::VecDeque;
     use std::convert::TryFrom;
     use std::fs::File;
     use std::io::Seek;
@@ -1605,154 +1606,105 @@ mod tests {
         test_row_group_batch(MIN_BATCH_SIZE - 1, MIN_BATCH_SIZE);
     }
 
-    #[test]
-    fn test_scan_row_with_selection() {
-        let testdata = arrow::util::test_util::parquet_test_data();
-        let path = format!("{}/alltypes_tiny_pages_plain.parquet", testdata);
-        let test_file = File::open(&path).unwrap();
+    /// Given a RecordBatch containing all the column data, return the expected batches given
+    /// a `batch_size` and `selection`
+    fn get_expected_batches(
+        column: &RecordBatch,
+        selection: &[RowSelection],
+        batch_size: usize,
+    ) -> Vec<RecordBatch> {
+        let mut expected_batches = vec![];
+
+        let mut selection: VecDeque<_> = selection.iter().cloned().collect();
+        let mut row_offset = 0;
+        let mut last_start = None;
+        while row_offset < column.num_rows() && !selection.is_empty() {
+            let mut batch_remaining = batch_size.min(column.num_rows() - row_offset);
+            while batch_remaining > 0 && !selection.is_empty() {
+                let (to_read, skip) = match selection.front_mut() {
+                    Some(selection) if selection.row_count > batch_remaining => {
+                        selection.row_count -= batch_remaining;
+                        (batch_remaining, selection.skip)
+                    }
+                    Some(_) => {
+                        let select = selection.pop_front().unwrap();
+                        (select.row_count, select.skip)
+                    }
+                    None => break,
+                };
 
-        // total row count 7300
-        // 1. test selection len more than one page row count
-        let batch_size = 1000;
-        let expected_data = create_expect_batch(&test_file, batch_size);
-
-        let selections = create_test_selection(batch_size, 7300, false);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-        let mut total_row_count = 0;
-        let mut index = 0;
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            assert_eq!(batch, expected_data.get(index).unwrap().clone());
-            index += 2;
-            let num = batch.num_rows();
-            assert!(num == batch_size || num == 300);
-            total_row_count += num;
-        }
-        assert_eq!(total_row_count, 4000);
+                batch_remaining -= to_read;
 
-        let selections = create_test_selection(batch_size, 7300, true);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-        let mut total_row_count = 0;
-        let mut index = 1;
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            assert_eq!(batch, expected_data.get(index).unwrap().clone());
-            index += 2;
-            let num = batch.num_rows();
-            //the lase batch will be 300
-            assert!(num == batch_size || num == 300);
-            total_row_count += num;
+                match skip {
+                    true => {
+                        if let Some(last_start) = last_start.take() {
+                            expected_batches
+                                .push(column.slice(last_start, row_offset - last_start))
+                        }
+                        row_offset += to_read
+                    }
+                    false => {
+                        last_start.get_or_insert(row_offset);
+                        row_offset += to_read
+                    }
+                }
+            }
         }
-        assert_eq!(total_row_count, 3300);
 
-        // 2. test selection len less than one page row count
-        let batch_size = 20;
-        let expected_data = create_expect_batch(&test_file, batch_size);
-        let selections = create_test_selection(batch_size, 7300, false);
-
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-        let mut total_row_count = 0;
-        let mut index = 0;
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            assert_eq!(batch, expected_data.get(index).unwrap().clone());
-            index += 2;
-            let num = batch.num_rows();
-            assert_eq!(num, batch_size);
-            total_row_count += num;
+        if let Some(last_start) = last_start.take() {
+            expected_batches.push(column.slice(last_start, row_offset - last_start))
         }
-        assert_eq!(total_row_count, 3660);
 
-        let selections = create_test_selection(batch_size, 7300, true);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-        let mut total_row_count = 0;
-        let mut index = 1;
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            assert_eq!(batch, expected_data.get(index).unwrap().clone());
-            index += 2;
-            let num = batch.num_rows();
-            assert_eq!(num, batch_size);
-            total_row_count += num;
+        // Sanity check, all batches except the final should be the batch size
+        for batch in &expected_batches[..expected_batches.len() - 1] {
+            assert_eq!(batch.num_rows(), batch_size);
         }
-        assert_eq!(total_row_count, 3640);
 
-        // 3. test selection_len less than batch_size
-        let batch_size = 20;
-        let selection_len = 5;
-        let expected_data_batch = create_expect_batch(&test_file, batch_size);
-        let expected_data_selection = create_expect_batch(&test_file, selection_len);
-        let selections = create_test_selection(selection_len, 7300, false);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
+        expected_batches
+    }
+
+    #[test]
+    fn test_scan_row_with_selection() {
+        let testdata = arrow::util::test_util::parquet_test_data();
+        let path = format!("{}/alltypes_tiny_pages_plain.parquet", testdata);
+        let test_file = File::open(&path).unwrap();
 
-        let mut total_row_count = 0;
+        let mut serial_arrow_reader =
+            ParquetFileArrowReader::try_new(File::open(path).unwrap()).unwrap();
+        let mut serial_reader = serial_arrow_reader.get_record_reader(7300).unwrap();
+        let data = serial_reader.next().unwrap().unwrap();
 
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            let num = batch.num_rows();
-            assert!(num == batch_size || num == selection_len);
-            if num == batch_size {
-                assert_eq!(
-                    batch,
-                    expected_data_batch
-                        .get(total_row_count / batch_size)
-                        .unwrap()
-                        .clone()
-                );
-                total_row_count += batch_size;
-            } else if num == selection_len {
+        let do_test = |batch_size: usize, selection_len: usize| {
+            for skip_first in [false, true] {
+                let selections =
+                    create_test_selection(batch_size, data.num_rows(), skip_first);
+
+                let expected = get_expected_batches(&data, &selections, batch_size);
+                let skip_reader = create_skip_reader(&test_file, batch_size, selections);
                 assert_eq!(
-                    batch,
-                    expected_data_selection
-                        .get(total_row_count / selection_len)
-                        .unwrap()
-                        .clone()
+                    skip_reader.collect::<ArrowResult<Vec<_>>>().unwrap(),
+                    expected,
+                    "batch_size: {}, selection_len: {}, skip_first: {}",
+                    batch_size,
+                    selection_len,
+                    skip_first
                 );
-                total_row_count += selection_len;
             }
-            // add skip offset
-            total_row_count += selection_len;
-        }
+        };
+
+        // total row count 7300
+        // 1. test selection len more than one page row count
+        do_test(1000, 1000);
+
+        // 2. test selection len less than one page row count
+        do_test(20, 20);
+
+        // 3. test selection_len less than batch_size
+        do_test(20, 5);
 
         // 4. test selection_len more than batch_size
-        // If batch_size < selection_len will divide selection(50, read) ->
-        // selection(20, read), selection(20, read), selection(10, read)
-        let batch_size = 20;
-        let selection_len = 50;
-        let another_batch_size = 10;
-        let expected_data_batch = create_expect_batch(&test_file, batch_size);
-        let expected_data_batch2 = create_expect_batch(&test_file, another_batch_size);
-        let selections = create_test_selection(selection_len, 7300, false);
-        let skip_reader = create_skip_reader(&test_file, batch_size, selections);
-
-        let mut total_row_count = 0;
-
-        for batch in skip_reader {
-            let batch = batch.unwrap();
-            let num = batch.num_rows();
-            assert!(num == batch_size || num == another_batch_size);
-            if num == batch_size {
-                assert_eq!(
-                    batch,
-                    expected_data_batch
-                        .get(total_row_count / batch_size)
-                        .unwrap()
-                        .clone()
-                );
-                total_row_count += batch_size;
-            } else if num == another_batch_size {
-                assert_eq!(
-                    batch,
-                    expected_data_batch2
-                        .get(total_row_count / another_batch_size)
-                        .unwrap()
-                        .clone()
-                );
-                total_row_count += 10;
-                // add skip offset
-                total_row_count += selection_len;
-            }
-        }
+        // If batch_size < selection_len, the selection is split across batches:
+        // selection(50, read) -> selection(20, read), selection(20, read), selection(10, read)
+        do_test(20, 50);
 
         fn create_skip_reader(
             test_file: &File,
             batch_size: usize,
             selections: Vec<RowSelection>,
         ) -> ParquetRecordBatchReader {
             let arrow_reader_options =
                 ArrowReaderOptions::new().with_row_selection(selections);
             let mut skip_arrow_reader = ParquetFileArrowReader::try_new_with_options(
                 test_file.try_clone().unwrap(),
                 arrow_reader_options,
             )
             .unwrap();
             skip_arrow_reader.get_record_reader(batch_size).unwrap()
         }
 
         fn create_test_selection(
             step_len: usize,
             total_len: usize,
             skip_first: bool,
         ) -> Vec<RowSelection> {
             let mut remaining = total_len;
             let mut skip = skip_first;
             let mut vec = vec![];
             while remaining != 0 {
                 let step = if remaining > step_len {
                     step_len
                 } else {
                     remaining
                 };
                 vec.push(RowSelection {
                     row_count: step,
                     skip,
                 });
                 remaining -= step;
                 skip = !skip;
             }
             vec
         }
-
-        fn create_expect_batch(test_file: &File, batch_size: usize) -> Vec<RecordBatch> {
-            let mut serial_arrow_reader =
-                ParquetFileArrowReader::try_new(test_file.try_clone().unwrap()).unwrap();
-            let serial_reader =
-                serial_arrow_reader.get_record_reader(batch_size).unwrap();
-            let mut expected_data = vec![];
-            for batch in serial_reader {
-                expected_data.push(batch.unwrap());
-            }
-            expected_data
-        }
     }
 }
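As a footnote to the test changes: the totals asserted by the old test (4000/3300 and 3660/3640) fall directly out of the alternating read/skip pattern that `create_test_selection` builds over the 7300-row file. A standalone sketch re-deriving them, with `Selection` as an invented stand-in for the crate's `RowSelection`:

```rust
/// Invented stand-in for the crate's `RowSelection` selector
#[derive(Clone, Copy)]
struct Selection {
    row_count: usize,
    skip: bool,
}

/// Same logic as `create_test_selection` in the diff above: alternate
/// read/skip runs of `step_len` rows until `total_len` rows are covered.
fn create_test_selection(step_len: usize, total_len: usize, skip_first: bool) -> Vec<Selection> {
    let mut remaining = total_len;
    let mut skip = skip_first;
    let mut runs = vec![];
    while remaining != 0 {
        let step = remaining.min(step_len);
        runs.push(Selection { row_count: step, skip });
        remaining -= step;
        skip = !skip;
    }
    runs
}

/// Total rows that survive the selection (read rather than skipped)
fn selected_rows(selection: &[Selection]) -> usize {
    selection.iter().filter(|s| !s.skip).map(|s| s.row_count).sum()
}

fn main() {
    // 7300 rows in runs of 1000: reading first gives 4 read runs of 1000;
    // skipping first gives 3 read runs of 1000 (the final 300-row run is skipped)
    assert_eq!(selected_rows(&create_test_selection(1000, 7300, false)), 4000);
    assert_eq!(selected_rows(&create_test_selection(1000, 7300, true)), 3300);

    // Runs of 20 divide 7300 into 365 runs: 183 read runs vs 182
    assert_eq!(selected_rows(&create_test_selection(20, 7300, false)), 3660);
    assert_eq!(selected_rows(&create_test_selection(20, 7300, true)), 3640);
}
```

The rewritten test no longer hard-codes these totals: `get_expected_batches` derives the expected output for any `batch_size`/selection combination by slicing the full 7300-row batch, which is what lets the four scenarios collapse into a single parameterized `do_test`.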