diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs index e1cd71b9e1f..3de5ed6f4a8 100644 --- a/parquet/src/arrow/array_reader/list_array.rs +++ b/parquet/src/arrow/array_reader/list_array.rs @@ -124,7 +124,7 @@ impl ArrayReader for ListArrayReader { // The output offsets for the computed ListArray let mut list_offsets: Vec = - Vec::with_capacity(next_batch_array.len()); + Vec::with_capacity(next_batch_array.len() + 1); // The validity mask of the computed ListArray if nullable let mut validity = self diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs index c5d1f66e5bf..2c47222bef6 100644 --- a/parquet/src/arrow/arrow_reader.rs +++ b/parquet/src/arrow/arrow_reader.rs @@ -1430,4 +1430,57 @@ mod tests { assert_eq!(total_rows, expected_rows); } + + #[test] + fn test_row_group_exact_multiple() { + let schema = Arc::new(Schema::new(vec![ + Field::new("int", ArrowDataType::Int32, false), + Field::new( + "list", + ArrowDataType::List(Box::new(Field::new( + "item", + ArrowDataType::Int32, + true, + ))), + true, + ), + ])); + + let mut buf = Vec::with_capacity(1024); + + let mut writer = ArrowWriter::try_new( + &mut buf, + schema.clone(), + Some( + WriterProperties::builder() + .set_max_row_group_size(8) + .build(), + ), + ) + .unwrap(); + for _ in 0..2 { + let mut int_builder = Int32Builder::new(10); + let mut list_builder = ListBuilder::new(Int32Builder::new(10)); + for i in 0..10 { + int_builder.append_value(i).unwrap(); + list_builder.append(true).unwrap(); + } + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(int_builder.finish()), + Arc::new(list_builder.finish()), + ], + ) + .unwrap(); + writer.write(&batch).unwrap(); + } + writer.close().unwrap(); + + let mut file_reader = ParquetFileArrowReader::try_new(Bytes::from(buf)).unwrap(); + let mut record_reader = file_reader.get_record_reader(8).unwrap(); + assert_eq!(8, record_reader.next().unwrap().unwrap().num_rows()); + assert_eq!(8, record_reader.next().unwrap().unwrap().num_rows()); + assert_eq!(4, record_reader.next().unwrap().unwrap().num_rows()); + } } diff --git a/parquet/src/arrow/record_reader/mod.rs b/parquet/src/arrow/record_reader/mod.rs index af75dbb4951..e8d72caf040 100644 --- a/parquet/src/arrow/record_reader/mod.rs +++ b/parquet/src/arrow/record_reader/mod.rs @@ -138,32 +138,19 @@ where let mut records_read = 0; - // Used to mark whether we have reached the end of current - // column chunk - let mut end_of_column = false; - loop { // Try to find some records from buffers that has been read into memory // but not counted as seen records. + let end_of_column = !self.column_reader.as_mut().unwrap().has_next()?; + let (record_count, value_count) = - self.count_records(num_records - records_read); + self.count_records(num_records - records_read, end_of_column); self.num_records += record_count; self.num_values += value_count; records_read += record_count; - if records_read == num_records { - break; - } - - if end_of_column { - // Since page reader contains complete records, if we reached end of a - // page reader, we should reach the end of a record - if self.rep_levels.is_some() { - self.num_records += 1; - self.num_values = self.values_written; - records_read += 1; - } + if records_read == num_records || end_of_column { break; } @@ -193,10 +180,7 @@ where }; // Try to more value from parquet pages - let values_read = self.read_one_batch(batch_size)?; - if values_read < batch_size { - end_of_column = true; - } + self.read_one_batch(batch_size)?; } Ok(records_read) @@ -302,8 +286,15 @@ where /// Inspects the buffered repetition levels in the range `self.num_values..self.values_written` /// and returns the number of "complete" records along with the corresponding number of values /// + /// If `end_of_column` is true it indicates that there are no further values for this + /// column chunk beyond what is currently in the buffers + /// /// A "complete" record is one where the buffer contains a subsequent repetition level of 0 - fn count_records(&self, records_to_read: usize) -> (usize, usize) { + fn count_records( + &self, + records_to_read: usize, + end_of_column: bool, + ) -> (usize, usize) { match self.rep_levels.as_ref() { Some(buf) => { let buf = buf.as_slice(); @@ -327,6 +318,15 @@ where } } + // If reached end of column chunk => end of a record + if records_read != records_to_read + && end_of_column + && self.values_written != 0 + { + records_read += 1; + end_of_last_record = self.values_written; + } + (records_read, end_of_last_record - self.num_values) } None => { @@ -699,4 +699,48 @@ mod tests { assert_eq!(5000, record_reader.num_values()); } } + + #[test] + fn test_row_group_boundary() { + // Construct column schema + let message_type = " + message test_schema { + REPEATED Group test_struct { + REPEATED INT32 leaf; + } + } + "; + + let desc = parse_message_type(message_type) + .map(|t| SchemaDescriptor::new(Arc::new(t))) + .map(|s| s.column(0)) + .unwrap(); + + let values = [1, 2, 3]; + let def_levels = [1i16, 0i16, 1i16, 2i16, 2i16, 1i16, 2i16]; + let rep_levels = [0i16, 0i16, 0i16, 1i16, 2i16, 0i16, 1i16]; + let mut pb = DataPageBuilderImpl::new(desc.clone(), 7, true); + pb.add_rep_levels(2, &rep_levels); + pb.add_def_levels(2, &def_levels); + pb.add_values::(Encoding::PLAIN, &values); + let page = pb.consume(); + + let mut record_reader = RecordReader::::new(desc.clone()); + let page_reader = Box::new(InMemoryPageReader::new(vec![page.clone()])); + record_reader.set_page_reader(page_reader).unwrap(); + assert_eq!(record_reader.read_records(4).unwrap(), 4); + assert_eq!(record_reader.num_records(), 4); + assert_eq!(record_reader.num_values(), 7); + + let mut record_reader = RecordReader::::new(desc.clone()); + let page_reader = Box::new(InMemoryPageReader::new(vec![page.clone()])); + record_reader.set_page_reader(page_reader).unwrap(); + assert_eq!(record_reader.read_records(3).unwrap(), 3); + assert_eq!(record_reader.num_records(), 3); + assert_eq!(record_reader.num_values(), 5); + + assert_eq!(record_reader.read_records(3).unwrap(), 1); + assert_eq!(record_reader.num_records(), 4); + assert_eq!(record_reader.num_values(), 7); + } } diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index a97787ccfa5..3b35f69cec4 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -407,7 +407,7 @@ where } #[inline] - fn has_next(&mut self) -> Result { + pub(crate) fn has_next(&mut self) -> Result { if self.num_buffered_values == 0 || self.num_buffered_values == self.num_decoded_values {