diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs index 3d612facdd6..33bd9772a16 100644 --- a/parquet/src/arrow/array_reader/list_array.rs +++ b/parquet/src/arrow/array_reader/list_array.rs @@ -124,7 +124,7 @@ impl ArrayReader for ListArrayReader { // The output offsets for the computed ListArray let mut list_offsets: Vec = - Vec::with_capacity(next_batch_array.len()); + Vec::with_capacity(next_batch_array.len() + 1); // The validity mask of the computed ListArray if nullable let mut validity = self diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs index 6a3270762f8..f93488f75f7 100644 --- a/parquet/src/arrow/arrow_reader.rs +++ b/parquet/src/arrow/arrow_reader.rs @@ -1528,4 +1528,45 @@ mod tests { assert_eq!(total_rows, expected_rows); } + + #[test] + fn test_row_group_exact_multiple() { + let schema = Arc::new(Schema::new(vec![Field::new( + "list", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + true, + )])); + + let mut buf = Vec::with_capacity(1024); + + let mut writer = ArrowWriter::try_new( + &mut buf, + schema.clone(), + Some( + WriterProperties::builder() + .set_max_row_group_size(8) + .build(), + ), + ) + .unwrap(); + for _ in 0..2 { + let mut list_builder = ListBuilder::new(Int32Builder::new(10)); + for _ in 0..10 { + list_builder.append(true).unwrap(); + } + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(list_builder.finish())], + ) + .unwrap(); + writer.write(&batch).unwrap(); + } + writer.close().unwrap(); + + let mut file_reader = ParquetFileArrowReader::try_new(Bytes::from(buf)).unwrap(); + let mut record_reader = file_reader.get_record_reader(8).unwrap(); + assert_eq!(8, record_reader.next().unwrap().unwrap().num_rows()); + assert_eq!(8, record_reader.next().unwrap().unwrap().num_rows()); + assert_eq!(4, record_reader.next().unwrap().unwrap().num_rows()); + } } diff --git a/parquet/src/arrow/record_reader/mod.rs b/parquet/src/arrow/record_reader/mod.rs index 046e01d4672..30324fbe3e3 100644 --- a/parquet/src/arrow/record_reader/mod.rs +++ b/parquet/src/arrow/record_reader/mod.rs @@ -139,32 +139,19 @@ where let mut records_read = 0; - // Used to mark whether we have reached the end of current - // column chunk - let mut end_of_column = false; - loop { // Try to find some records from buffers that has been read into memory // but not counted as seen records. + let end_of_column = !self.column_reader.as_mut().unwrap().has_next()?; + let (record_count, value_count) = - self.count_records(num_records - records_read); + self.count_records(num_records - records_read, end_of_column); self.num_records += record_count; self.num_values += value_count; records_read += record_count; - if records_read == num_records { - break; - } - - if end_of_column { - // Since page reader contains complete records, if we reached end of a - // page reader, we should reach the end of a record - if self.rep_levels.is_some() { - self.num_records += 1; - self.num_values = self.values_written; - records_read += 1; - } + if records_read == num_records || end_of_column { break; } @@ -194,10 +181,7 @@ where }; // Try to more value from parquet pages - let values_read = self.read_one_batch(batch_size)?; - if values_read < batch_size { - end_of_column = true; - } + self.read_one_batch(batch_size)?; } Ok(records_read) @@ -210,7 +194,14 @@ where /// Number of records skipped pub fn skip_records(&mut self, num_records: usize) -> Result { // First need to clear the buffer - let (buffered_records, buffered_values) = self.count_records(num_records); + let end_of_column = match self.column_reader.as_mut() { + Some(reader) => !reader.has_next()?, + None => return Ok(0), + }; + + let (buffered_records, buffered_values) = + self.count_records(num_records, end_of_column); + self.num_records += buffered_records; self.num_values += buffered_values; @@ -226,10 +217,11 @@ where return Ok(buffered_records); } - let skipped = match self.column_reader.as_mut() { - Some(column_reader) => column_reader.skip_records(remaining)?, - None => 0, - }; + let skipped = self + .column_reader + .as_mut() + .unwrap() + .skip_records(remaining)?; Ok(skipped + buffered_records) } @@ -334,8 +326,15 @@ where /// Inspects the buffered repetition levels in the range `self.num_values..self.values_written` /// and returns the number of "complete" records along with the corresponding number of values /// + /// If `end_of_column` is true it indicates that there are no further values for this + /// column chunk beyond what is currently in the buffers + /// /// A "complete" record is one where the buffer contains a subsequent repetition level of 0 - fn count_records(&self, records_to_read: usize) -> (usize, usize) { + fn count_records( + &self, + records_to_read: usize, + end_of_column: bool, + ) -> (usize, usize) { match self.rep_levels.as_ref() { Some(buf) => { let buf = buf.as_slice(); @@ -359,6 +358,15 @@ where } } + // If reached end of column chunk => end of a record + if records_read != records_to_read + && end_of_column + && self.values_written != self.num_values + { + records_read += 1; + end_of_last_record = self.values_written; + } + (records_read, end_of_last_record - self.num_values) } None => { @@ -731,4 +739,54 @@ mod tests { assert_eq!(5000, record_reader.num_values()); } } + + #[test] + fn test_row_group_boundary() { + // Construct column schema + let message_type = " + message test_schema { + REPEATED Group test_struct { + REPEATED INT32 leaf; + } + } + "; + + let desc = parse_message_type(message_type) + .map(|t| SchemaDescriptor::new(Arc::new(t))) + .map(|s| s.column(0)) + .unwrap(); + + let values = [1, 2, 3]; + let def_levels = [1i16, 0i16, 1i16, 2i16, 2i16, 1i16, 2i16]; + let rep_levels = [0i16, 0i16, 0i16, 1i16, 2i16, 0i16, 1i16]; + let mut pb = DataPageBuilderImpl::new(desc.clone(), 7, true); + pb.add_rep_levels(2, &rep_levels); + pb.add_def_levels(2, &def_levels); + pb.add_values::(Encoding::PLAIN, &values); + let page = pb.consume(); + + let mut record_reader = RecordReader::::new(desc); + let page_reader = Box::new(InMemoryPageReader::new(vec![page.clone()])); + record_reader.set_page_reader(page_reader).unwrap(); + assert_eq!(record_reader.read_records(4).unwrap(), 4); + assert_eq!(record_reader.num_records(), 4); + assert_eq!(record_reader.num_values(), 7); + + assert_eq!(record_reader.read_records(4).unwrap(), 0); + assert_eq!(record_reader.num_records(), 4); + assert_eq!(record_reader.num_values(), 7); + + record_reader.read_records(4).unwrap(); + + let page_reader = Box::new(InMemoryPageReader::new(vec![page])); + record_reader.set_page_reader(page_reader).unwrap(); + + assert_eq!(record_reader.read_records(4).unwrap(), 4); + assert_eq!(record_reader.num_records(), 8); + assert_eq!(record_reader.num_values(), 14); + + assert_eq!(record_reader.read_records(4).unwrap(), 0); + assert_eq!(record_reader.num_records(), 8); + assert_eq!(record_reader.num_values(), 14); + } } diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index 35e725b1959..80174d75679 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -474,7 +474,7 @@ where } #[inline] - fn has_next(&mut self) -> Result { + pub(crate) fn has_next(&mut self) -> Result { if self.num_buffered_values == 0 || self.num_buffered_values == self.num_decoded_values {