diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs index ec4188890ef..172aeb96d6d 100644 --- a/parquet/src/arrow/array_reader/byte_array.rs +++ b/parquet/src/arrow/array_reader/byte_array.rs @@ -108,8 +108,11 @@ impl ArrayReader for ByteArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { - read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } + + fn consume_batch(&mut self) -> Result { let buffer = self.record_reader.consume_record_data(); let null_buffer = self.record_reader.consume_bitmap_buffer(); self.def_levels_buffer = self.record_reader.consume_def_levels(); diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs index 51ef38d0d07..0a5d94fa6ae 100644 --- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs +++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs @@ -25,7 +25,7 @@ use arrow::buffer::Buffer; use arrow::datatypes::{ArrowNativeType, DataType as ArrowType}; use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain}; -use crate::arrow::array_reader::{read_records, ArrayReader, skip_records}; +use crate::arrow::array_reader::{read_records, skip_records, ArrayReader}; use crate::arrow::buffer::{ dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer, }; @@ -167,8 +167,11 @@ where &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { - read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } + + fn consume_batch(&mut self) -> Result { let buffer = self.record_reader.consume_record_data(); let null_buffer = self.record_reader.consume_bitmap_buffer(); let array = buffer.into_array(null_buffer, &self.data_type)?; diff --git a/parquet/src/arrow/array_reader/complex_object_array.rs b/parquet/src/arrow/array_reader/complex_object_array.rs index 1390866cf6a..79b53733176 100644 --- a/parquet/src/arrow/array_reader/complex_object_array.rs +++ b/parquet/src/arrow/array_reader/complex_object_array.rs @@ -39,9 +39,13 @@ where pages: Box, def_levels_buffer: Option>, rep_levels_buffer: Option>, + data_buffer: Vec, column_desc: ColumnDescPtr, column_reader: Option>, converter: C, + in_progress_def_levels_buffer: Option>, + in_progress_rep_levels_buffer: Option>, + before_consume: bool, _parquet_type_marker: PhantomData, _converter_marker: PhantomData, } @@ -59,7 +63,10 @@ where &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { + fn read_records(&mut self, batch_size: usize) -> Result { + if !self.before_consume { + self.before_consume = true; + } // Try to initialize column reader if self.column_reader.is_none() { self.next_column_reader()?; @@ -126,7 +133,6 @@ where break; } } - data_buffer.truncate(num_read); def_levels_buffer .iter_mut() @@ -135,13 +141,35 @@ where .iter_mut() .for_each(|buf| buf.truncate(num_read)); - self.def_levels_buffer = def_levels_buffer; - self.rep_levels_buffer = rep_levels_buffer; + if let Some(mut def_levels_buffer) = def_levels_buffer { + match &mut self.in_progress_def_levels_buffer { + None => { + self.in_progress_def_levels_buffer = Some(def_levels_buffer); + } + Some(buf) => buf.append(&mut def_levels_buffer), + } + } + + if let Some(mut rep_levels_buffer) = rep_levels_buffer { + match &mut self.in_progress_rep_levels_buffer { + None => { + self.in_progress_rep_levels_buffer = Some(rep_levels_buffer); + } + Some(buf) => buf.append(&mut rep_levels_buffer), + } + } + + self.data_buffer.append(&mut data_buffer); + + Ok(num_read) + } - let data: Vec> = if self.def_levels_buffer.is_some() { + fn consume_batch(&mut self) -> Result { + let data: Vec> = if self.in_progress_def_levels_buffer.is_some() { + let data_buffer = std::mem::take(&mut self.data_buffer); data_buffer .into_iter() - .zip(self.def_levels_buffer.as_ref().unwrap().iter()) + .zip(self.in_progress_def_levels_buffer.as_ref().unwrap().iter()) .map(|(t, def_level)| { if *def_level == self.column_desc.max_def_level() { Some(t) @@ -151,7 +179,7 @@ where }) .collect() } else { - data_buffer.into_iter().map(Some).collect() + self.data_buffer.iter().map(|x| Some(x.clone())).collect() }; let mut array = self.converter.convert(data)?; @@ -160,6 +188,11 @@ where array = arrow::compute::cast(&array, &self.data_type)?; } + self.data_buffer = vec![]; + self.def_levels_buffer = std::mem::take(&mut self.in_progress_def_levels_buffer); + self.rep_levels_buffer = std::mem::take(&mut self.in_progress_rep_levels_buffer); + self.before_consume = false; + Ok(array) } @@ -168,8 +201,11 @@ where Some(reader) => reader.skip_records(num_records), None => { if self.next_column_reader()? { - self.column_reader.as_mut().unwrap().skip_records(num_records) - }else { + self.column_reader + .as_mut() + .unwrap() + .skip_records(num_records) + } else { Ok(0) } } @@ -177,11 +213,19 @@ where } fn get_def_levels(&self) -> Option<&[i16]> { - self.def_levels_buffer.as_deref() + if self.before_consume { + self.in_progress_def_levels_buffer.as_deref() + } else { + self.def_levels_buffer.as_deref() + } } fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_levels_buffer.as_deref() + if self.before_consume { + self.in_progress_rep_levels_buffer.as_deref() + } else { + self.rep_levels_buffer.as_deref() + } } } @@ -208,9 +252,13 @@ where pages, def_levels_buffer: None, rep_levels_buffer: None, + data_buffer: vec![], column_desc, column_reader: None, converter, + in_progress_def_levels_buffer: None, + in_progress_rep_levels_buffer: None, + before_consume: true, _parquet_type_marker: PhantomData, _converter_marker: PhantomData, }) @@ -349,30 +397,32 @@ mod tests { let mut accu_len: usize = 0; - let array = array_reader.next_batch(values_per_page / 2).unwrap(); - assert_eq!(array.len(), values_per_page / 2); + let len = array_reader.read_records(values_per_page / 2).unwrap(); + assert_eq!(len, values_per_page / 2); assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), + Some(&def_levels[accu_len..(accu_len + len)]), array_reader.get_def_levels() ); assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), + Some(&rep_levels[accu_len..(accu_len + len)]), array_reader.get_rep_levels() ); - accu_len += array.len(); + accu_len += len; + array_reader.consume_batch().unwrap(); // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, // and the last values_per_page/2 ones are from the second column chunk - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page); + let len = array_reader.read_records(values_per_page).unwrap(); + assert_eq!(len, values_per_page); assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), + Some(&def_levels[accu_len..(accu_len + len)]), array_reader.get_def_levels() ); assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), + Some(&rep_levels[accu_len..(accu_len + len)]), array_reader.get_rep_levels() ); + let array = array_reader.consume_batch().unwrap(); let strings = array.as_any().downcast_ref::().unwrap(); for i in 0..array.len() { if array.is_valid(i) { @@ -384,19 +434,20 @@ mod tests { assert_eq!(all_values[i + accu_len], None) } } - accu_len += array.len(); + accu_len += len; // Try to read values_per_page values, however there are only values_per_page/2 values - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page / 2); + let len = array_reader.read_records(values_per_page).unwrap(); + assert_eq!(len, values_per_page / 2); assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), + Some(&def_levels[accu_len..(accu_len + len)]), array_reader.get_def_levels() ); assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), + Some(&rep_levels[accu_len..(accu_len + len)]), array_reader.get_rep_levels() ); + array_reader.consume_batch().unwrap(); } #[test] @@ -491,31 +542,34 @@ mod tests { let mut accu_len: usize = 0; // println!("---------- reading a batch of {} values ----------", values_per_page / 2); - let array = array_reader.next_batch(values_per_page / 2).unwrap(); - assert_eq!(array.len(), values_per_page / 2); + let len = array_reader.read_records(values_per_page / 2).unwrap(); + assert_eq!(len, values_per_page / 2); assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), + Some(&def_levels[accu_len..(accu_len + len)]), array_reader.get_def_levels() ); assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), + Some(&rep_levels[accu_len..(accu_len + len)]), array_reader.get_rep_levels() ); - accu_len += array.len(); + accu_len += len; + array_reader.consume_batch().unwrap(); // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, // and the last values_per_page/2 ones are from the second column chunk // println!("---------- reading a batch of {} values ----------", values_per_page); - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page); + //let array = array_reader.next_batch(values_per_page).unwrap(); + let len = array_reader.read_records(values_per_page).unwrap(); + assert_eq!(len, values_per_page); assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), + Some(&def_levels[accu_len..(accu_len + len)]), array_reader.get_def_levels() ); assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), + Some(&rep_levels[accu_len..(accu_len + len)]), array_reader.get_rep_levels() ); + let array = array_reader.consume_batch().unwrap(); let strings = array.as_any().downcast_ref::().unwrap(); for i in 0..array.len() { if array.is_valid(i) { @@ -527,19 +581,20 @@ mod tests { assert_eq!(all_values[i + accu_len], None) } } - accu_len += array.len(); + accu_len += len; // Try to read values_per_page values, however there are only values_per_page/2 values // println!("---------- reading a batch of {} values ----------", values_per_page); - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page / 2); + let len = array_reader.read_records(values_per_page).unwrap(); + assert_eq!(len, values_per_page / 2); assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), + Some(&def_levels[accu_len..(accu_len + len)]), array_reader.get_def_levels() ); assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), + Some(&rep_levels[accu_len..(accu_len + len)]), array_reader.get_rep_levels() ); + array_reader.consume_batch().unwrap(); } } diff --git a/parquet/src/arrow/array_reader/empty_array.rs b/parquet/src/arrow/array_reader/empty_array.rs index b06646cc1c6..abe839b9dc2 100644 --- a/parquet/src/arrow/array_reader/empty_array.rs +++ b/parquet/src/arrow/array_reader/empty_array.rs @@ -33,6 +33,7 @@ pub fn make_empty_array_reader(row_count: usize) -> Box { struct EmptyArrayReader { data_type: ArrowType, remaining_rows: usize, + need_consume_records: usize, } impl EmptyArrayReader { @@ -40,6 +41,7 @@ impl EmptyArrayReader { Self { data_type: ArrowType::Struct(vec![]), remaining_rows: row_count, + need_consume_records: 0, } } } @@ -53,15 +55,19 @@ impl ArrayReader for EmptyArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { + fn read_records(&mut self, batch_size: usize) -> Result { let len = self.remaining_rows.min(batch_size); self.remaining_rows -= len; + self.need_consume_records += len; + Ok(len) + } + fn consume_batch(&mut self) -> Result { let data = ArrayDataBuilder::new(self.data_type.clone()) - .len(len) + .len(self.need_consume_records) .build() .unwrap(); - + self.need_consume_records = 0; Ok(Arc::new(StructArray::from(data))) } diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs index 33bd9772a16..c245c61312f 100644 --- a/parquet/src/arrow/array_reader/list_array.rs +++ b/parquet/src/arrow/array_reader/list_array.rs @@ -78,9 +78,13 @@ impl ArrayReader for ListArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { - let next_batch_array = self.item_reader.next_batch(batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + let size = self.item_reader.read_records(batch_size)?; + Ok(size) + } + fn consume_batch(&mut self) -> Result { + let next_batch_array = self.item_reader.consume_batch()?; if next_batch_array.len() == 0 { return Ok(new_empty_array(&self.data_type)); } diff --git a/parquet/src/arrow/array_reader/map_array.rs b/parquet/src/arrow/array_reader/map_array.rs index 00c3db41a37..83ba63ca170 100644 --- a/parquet/src/arrow/array_reader/map_array.rs +++ b/parquet/src/arrow/array_reader/map_array.rs @@ -62,9 +62,21 @@ impl ArrayReader for MapArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { - let key_array = self.key_reader.next_batch(batch_size)?; - let value_array = self.value_reader.next_batch(batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + let key_len = self.key_reader.read_records(batch_size)?; + let value_len = self.value_reader.read_records(batch_size)?; + // Check that key and value have the same lengths + if key_len != value_len { + return Err(general_err!( + "Map key and value should have the same lengths." + )); + } + Ok(key_len) + } + + fn consume_batch(&mut self) -> Result { + let key_array = self.key_reader.consume_batch()?; + let value_array = self.value_reader.consume_batch()?; // Check that key and value have the same lengths let key_length = key_array.len(); diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs index 8bdd6c071c3..d7665ef0f6b 100644 --- a/parquet/src/arrow/array_reader/mod.rs +++ b/parquet/src/arrow/array_reader/mod.rs @@ -62,7 +62,20 @@ pub trait ArrayReader: Send { fn get_data_type(&self) -> &ArrowType; /// Reads at most `batch_size` records into an arrow array and return it. - fn next_batch(&mut self, batch_size: usize) -> Result; + fn next_batch(&mut self, batch_size: usize) -> Result { + self.read_records(batch_size)?; + self.consume_batch() + } + + /// Reads at most `batch_size` records' bytes into buffer + /// + /// Returns the number of records read, which can be less than `batch_size` if + /// pages is exhausted. + fn read_records(&mut self, batch_size: usize) -> Result; + + /// Consume all currently stored buffer data + /// into an arrow array and return it. + fn consume_batch(&mut self) -> Result; /// Skips over `num_records` records, returning the number of rows skipped fn skip_records(&mut self, num_records: usize) -> Result; diff --git a/parquet/src/arrow/array_reader/null_array.rs b/parquet/src/arrow/array_reader/null_array.rs index 63f73d41e4f..682d15f8a17 100644 --- a/parquet/src/arrow/array_reader/null_array.rs +++ b/parquet/src/arrow/array_reader/null_array.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::arrow::array_reader::{read_records, ArrayReader, skip_records}; +use crate::arrow::array_reader::{read_records, skip_records, ArrayReader}; use crate::arrow::record_reader::buffer::ScalarValue; use crate::arrow::record_reader::RecordReader; use crate::column::page::PageIterator; @@ -78,10 +78,11 @@ where &self.data_type } - /// Reads at most `batch_size` records into array. - fn next_batch(&mut self, batch_size: usize) -> Result { - read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } + fn consume_batch(&mut self) -> Result { // convert to arrays let array = arrow::array::NullArray::new(self.record_reader.num_values()); diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs index 89f2ce51bef..59526f093af 100644 --- a/parquet/src/arrow/array_reader/primitive_array.rs +++ b/parquet/src/arrow/array_reader/primitive_array.rs @@ -95,10 +95,11 @@ where &self.data_type } - /// Reads at most `batch_size` records into array. - fn next_batch(&mut self, batch_size: usize) -> Result { - read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } + fn consume_batch(&mut self) -> Result { let target_type = self.get_data_type().clone(); let arrow_data_type = match T::get_physical_type() { PhysicalType::BOOLEAN => ArrowType::Boolean, diff --git a/parquet/src/arrow/array_reader/struct_array.rs b/parquet/src/arrow/array_reader/struct_array.rs index 602c598f826..b333c66cb21 100644 --- a/parquet/src/arrow/array_reader/struct_array.rs +++ b/parquet/src/arrow/array_reader/struct_array.rs @@ -63,7 +63,27 @@ impl ArrayReader for StructArrayReader { &self.data_type } - /// Read `batch_size` struct records. + fn read_records(&mut self, batch_size: usize) -> Result { + let mut read = None; + for child in self.children.iter_mut() { + let child_read = child.read_records(batch_size)?; + match read { + Some(expected) => { + if expected != child_read { + return Err(general_err!( + "StructArrayReader out of sync in read_records, expected {} skipped, got {}", + expected, + child_read + )); + } + } + None => read = Some(child_read), + } + } + Ok(read.unwrap_or(0)) + } + + /// Consume struct records. /// /// Definition levels of struct array is calculated as following: /// ```ignore @@ -80,7 +100,8 @@ impl ArrayReader for StructArrayReader { /// ```ignore /// null_bitmap[i] = (def_levels[i] >= self.def_level); /// ``` - fn next_batch(&mut self, batch_size: usize) -> Result { + /// + fn consume_batch(&mut self) -> Result { if self.children.is_empty() { return Ok(Arc::new(StructArray::from(Vec::new()))); } @@ -88,7 +109,7 @@ impl ArrayReader for StructArrayReader { let children_array = self .children .iter_mut() - .map(|reader| reader.next_batch(batch_size)) + .map(|reader| reader.consume_batch()) .collect::>>()?; // check that array child data has same size diff --git a/parquet/src/arrow/array_reader/test_util.rs b/parquet/src/arrow/array_reader/test_util.rs index 04c0f6c68f3..da9b8d3bf9b 100644 --- a/parquet/src/arrow/array_reader/test_util.rs +++ b/parquet/src/arrow/array_reader/test_util.rs @@ -101,6 +101,7 @@ pub struct InMemoryArrayReader { rep_levels: Option>, last_idx: usize, cur_idx: usize, + need_consume_records: usize, } impl InMemoryArrayReader { @@ -127,6 +128,7 @@ impl InMemoryArrayReader { rep_levels, cur_idx: 0, last_idx: 0, + need_consume_records: 0, } } } @@ -140,7 +142,7 @@ impl ArrayReader for InMemoryArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { + fn read_records(&mut self, batch_size: usize) -> Result { assert_ne!(batch_size, 0); // This replicates the logical normally performed by // RecordReader to delimit semantic records @@ -164,10 +166,17 @@ impl ArrayReader for InMemoryArrayReader { } None => batch_size.min(self.array.len() - self.cur_idx), }; + self.need_consume_records += read; + Ok(read) + } + fn consume_batch(&mut self) -> Result { + let batch_size = self.need_consume_records; + assert_ne!(batch_size, 0); self.last_idx = self.cur_idx; - self.cur_idx += read; - Ok(self.array.slice(self.last_idx, read)) + self.cur_idx += batch_size; + self.need_consume_records = 0; + Ok(self.array.slice(self.last_idx, batch_size)) } fn skip_records(&mut self, num_records: usize) -> Result { diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs index 26305cd41ba..3cd5cb9d4ed 100644 --- a/parquet/src/arrow/arrow_reader.rs +++ b/parquet/src/arrow/arrow_reader.rs @@ -769,6 +769,44 @@ mod tests { assert_eq!(&written.slice(6, 2), &read[2]); } + #[test] + fn test_int32_nullable_struct() { + let int32 = Int32Array::from_iter_values([1, 2, 3, 4, 5, 6, 7, 8]); + let data = ArrayDataBuilder::new(ArrowDataType::Struct(vec![Field::new( + "int32", + int32.data_type().clone(), + false, + )])) + .len(8) + .null_bit_buffer(Some(Buffer::from(&[0b11101111]))) + .child_data(vec![int32.into_data()]) + .build() + .unwrap(); + + let written = RecordBatch::try_from_iter([( + "struct", + Arc::new(StructArray::from(data)) as ArrayRef, + )]) + .unwrap(); + + let mut buffer = Vec::with_capacity(1024); + let mut writer = + ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); + writer.write(&written).unwrap(); + writer.close().unwrap(); + + let read = ParquetFileArrowReader::try_new(Bytes::from(buffer)) + .unwrap() + .get_record_reader(3) + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(&written.slice(0, 3), &read[0]); + assert_eq!(&written.slice(3, 3), &read[1]); + assert_eq!(&written.slice(6, 2), &read[2]); + } + #[test] #[ignore] // https://github.com/apache/arrow-rs/issues/2253 fn test_decimal_list() { diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 1c95fcc27c1..49531d9724a 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -1161,7 +1161,7 @@ mod tests { Some(props), ) .expect("Unable to write file"); - writer.write(&expected_batch).unwrap(); + writer.write(expected_batch).unwrap(); writer.close().unwrap(); let mut arrow_reader =