diff --git a/parquet/src/arrow/arrow_writer/byte_array.rs b/parquet/src/arrow/arrow_writer/byte_array.rs
index a7b6ccc3fc8..a25bd8d5c50 100644
--- a/parquet/src/arrow/arrow_writer/byte_array.rs
+++ b/parquet/src/arrow/arrow_writer/byte_array.rs
@@ -127,17 +127,10 @@ impl<'a> ByteArrayWriter<'a> {
     }
 
     pub fn close(self) -> Result<()> {
-        let (bytes_written, rows_written, metadata, column_index, offset_index) =
-            self.writer.close()?;
+        let r = self.writer.close()?;
 
         if let Some(on_close) = self.on_close {
-            on_close(
-                bytes_written,
-                rows_written,
-                metadata,
-                column_index,
-                offset_index,
-            )?;
+            on_close(r)?;
         }
         Ok(())
     }
diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs
index c7518c89e30..05e32f7e48a 100644
--- a/parquet/src/column/writer/mod.rs
+++ b/parquet/src/column/writer/mod.rs
@@ -145,13 +145,20 @@ pub fn get_typed_column_writer_mut<'a, 'b: 'a, T: DataType>(
     })
 }
 
-type ColumnCloseResult = (
-    u64,
-    u64,
-    ColumnChunkMetaData,
-    Option<ColumnIndex>,
-    Option<OffsetIndex>,
-);
+/// Metadata returned by [`GenericColumnWriter::close`]
+#[derive(Debug, Clone)]
+pub struct ColumnCloseResult {
+    /// The total number of bytes written
+    pub bytes_written: u64,
+    /// The total number of rows written
+    pub rows_written: u64,
+    /// Metadata for this column chunk
+    pub metadata: ColumnChunkMetaData,
+    /// Optional column index, for filtering
+    pub column_index: Option<ColumnIndex>,
+    /// Optional offset index, identifying page locations
+    pub offset_index: Option<OffsetIndex>,
+}
 
 // Metrics per page
 #[derive(Default)]
@@ -442,13 +449,13 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
             (None, None)
         };
 
-        Ok((
-            self.column_metrics.total_bytes_written,
-            self.column_metrics.total_rows_written,
+        Ok(ColumnCloseResult {
+            bytes_written: self.column_metrics.total_bytes_written,
+            rows_written: self.column_metrics.total_rows_written,
             metadata,
             column_index,
             offset_index,
-        ))
+        })
     }
 
     /// Writes mini batch of values, definition and repetition levels.
@@ -1201,11 +1208,13 @@ mod tests {
             .write_batch(&[true, false, true, false], None, None)
            .unwrap();
 
-        let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap();
+        let r = writer.close().unwrap();
         // PlainEncoder uses bit writer to write boolean values, which all fit into 1
         // byte.
-        assert_eq!(bytes_written, 1);
-        assert_eq!(rows_written, 4);
+        assert_eq!(r.bytes_written, 1);
+        assert_eq!(r.rows_written, 4);
+
+        let metadata = r.metadata;
         assert_eq!(metadata.encodings(), &vec![Encoding::PLAIN, Encoding::RLE]);
         assert_eq!(metadata.num_values(), 4); // just values
         assert_eq!(metadata.dictionary_page_offset(), None);
@@ -1474,9 +1483,11 @@ mod tests {
         let mut writer = get_test_column_writer::<Int32Type>(page_writer, 0, 0, props);
         writer.write_batch(&[1, 2, 3, 4], None, None).unwrap();
 
-        let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap();
-        assert_eq!(bytes_written, 20);
-        assert_eq!(rows_written, 4);
+        let r = writer.close().unwrap();
+        assert_eq!(r.bytes_written, 20);
+        assert_eq!(r.rows_written, 4);
+
+        let metadata = r.metadata;
         assert_eq!(
             metadata.encodings(),
             &vec![Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY]
@@ -1531,7 +1542,7 @@ mod tests {
                 None,
             )
             .unwrap();
-        let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap();
+        let metadata = writer.close().unwrap().metadata;
         if let Some(stats) = metadata.statistics() {
             assert!(stats.has_min_max_set());
             if let Statistics::ByteArray(stats) = stats {
@@ -1565,7 +1576,7 @@ mod tests {
             Int32Type,
         >(page_writer, 0, 0, props);
         writer.write_batch(&[0, 1, 2, 3, 4, 5], None, None).unwrap();
-        let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap();
+        let metadata = writer.close().unwrap().metadata;
         if let Some(stats) = metadata.statistics() {
             assert!(stats.has_min_max_set());
             if let Statistics::Int32(stats) = stats {
@@ -1599,9 +1610,11 @@ mod tests {
             )
             .unwrap();
 
-        let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap();
-        assert_eq!(bytes_written, 20);
-        assert_eq!(rows_written, 4);
+        let r = writer.close().unwrap();
+        assert_eq!(r.bytes_written, 20);
+        assert_eq!(r.rows_written, 4);
+
+        let metadata = r.metadata;
         assert_eq!(
             metadata.encodings(),
             &vec![Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY]
@@ -1646,9 +1659,9 @@ mod tests {
             )
             .unwrap();
 
-        let (_, rows_written, metadata, _, _) = writer.close().unwrap();
+        let r = writer.close().unwrap();
 
-        let stats = metadata.statistics().unwrap();
+        let stats = r.metadata.statistics().unwrap();
         assert_eq!(stats.min_bytes(), 1_i32.to_le_bytes());
         assert_eq!(stats.max_bytes(), 7_i32.to_le_bytes());
         assert_eq!(stats.null_count(), 0);
@@ -1656,8 +1669,8 @@ mod tests {
 
         let reader = SerializedPageReader::new(
             Arc::new(Bytes::from(buf)),
-            &metadata,
-            rows_written as usize,
+            &r.metadata,
+            r.rows_written as usize,
             None,
         )
         .unwrap();
@@ -1691,13 +1704,13 @@ mod tests {
             .write_batch(&[1, 2, 3, 4], Some(&[1, 0, 0, 1, 1, 1]), None)
             .unwrap();
 
-        let (_, rows_written, metadata, _, _) = writer.close().unwrap();
-        assert!(metadata.statistics().is_none());
+        let r = writer.close().unwrap();
+        assert!(r.metadata.statistics().is_none());
 
         let reader = SerializedPageReader::new(
             Arc::new(Bytes::from(buf)),
-            &metadata,
-            rows_written as usize,
+            &r.metadata,
+            r.rows_written as usize,
             None,
         )
         .unwrap();
@@ -1819,14 +1832,14 @@ mod tests {
         let data = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
         let mut writer = get_test_column_writer::<Int32Type>(page_writer, 0, 0, props);
         writer.write_batch(data, None, None).unwrap();
-        let (_, rows_written, metadata, _, _) = writer.close().unwrap();
+        let r = writer.close().unwrap();
 
         // Read pages and check the sequence
         let mut page_reader = Box::new(
             SerializedPageReader::new(
                 Arc::new(file),
-                &metadata,
-                rows_written as usize,
+                &r.metadata,
+                r.rows_written as usize,
                 None,
             )
            .unwrap(),
        );
@@ -2064,22 +2077,11 @@ mod tests {
         // second page
         writer.write_batch(&[4, 8, 2, -5], None, None).unwrap();
 
-        let (_, rows_written, metadata, column_index, offset_index) =
-            writer.close().unwrap();
-        let column_index = match column_index {
-            None => {
-                panic!("Can't fine the column index");
-            }
-            Some(column_index) => column_index,
-        };
-        let offset_index = match offset_index {
-            None => {
-                panic!("Can't find the offset index");
-            }
-            Some(offset_index) => offset_index,
-        };
+        let r = writer.close().unwrap();
+        let column_index = r.column_index.unwrap();
+        let offset_index = r.offset_index.unwrap();
 
-        assert_eq!(8, rows_written);
+        assert_eq!(8, r.rows_written);
 
         // column index
         assert_eq!(2, column_index.null_pages.len());
@@ -2090,7 +2092,7 @@ mod tests {
             assert_eq!(0, column_index.null_counts.as_ref().unwrap()[idx]);
         }
 
-        if let Some(stats) = metadata.statistics() {
+        if let Some(stats) = r.metadata.statistics() {
             assert!(stats.has_min_max_set());
             assert_eq!(stats.null_count(), 0);
             assert_eq!(stats.distinct_count(), None);
@@ -2201,13 +2203,13 @@ mod tests {
         let values_written = writer.write_batch(values, def_levels, rep_levels).unwrap();
         assert_eq!(values_written, values.len());
-        let (_, rows_written, column_metadata, _, _) = writer.close().unwrap();
+        let result = writer.close().unwrap();
 
         let page_reader = Box::new(
             SerializedPageReader::new(
                 Arc::new(file),
-                &column_metadata,
-                rows_written as usize,
+                &result.metadata,
+                result.rows_written as usize,
                 None,
             )
             .unwrap(),
@@ -2248,11 +2250,11 @@ mod tests {
                     actual_rows_written += 1;
                 }
             }
-            assert_eq!(actual_rows_written, rows_written);
+            assert_eq!(actual_rows_written, result.rows_written);
         } else if actual_def_levels.is_some() {
-            assert_eq!(levels_read as u64, rows_written);
+            assert_eq!(levels_read as u64, result.rows_written);
         } else {
-            assert_eq!(values_read as u64, rows_written);
+            assert_eq!(values_read as u64, result.rows_written);
         }
     }
 
@@ -2266,8 +2268,7 @@ mod tests {
         let props = Arc::new(props);
         let mut writer = get_test_column_writer::<T>(page_writer, 0, 0, props);
         writer.write_batch(values, None, None).unwrap();
-        let (_, _, metadata, _, _) = writer.close().unwrap();
-        metadata
+        writer.close().unwrap().metadata
     }
 
     // Function to use in tests for EncodingWriteSupport. This checks that dictionary
@@ -2378,7 +2379,7 @@ mod tests {
         let mut writer = get_test_column_writer::<T>(page_writer, 0, 0, props);
         writer.write_batch(values, None, None).unwrap();
 
-        let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap();
+        let metadata = writer.close().unwrap().metadata;
         if let Some(stats) = metadata.statistics() {
             stats.clone()
         } else {
diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs
index 87a9ae3e14e..b7bab189bb8 100644
--- a/parquet/src/file/writer.rs
+++ b/parquet/src/file/writer.rs
@@ -25,7 +25,9 @@ use parquet_format::{ColumnIndex, OffsetIndex, RowGroup};
 use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol};
 
 use crate::basic::PageType;
-use crate::column::writer::{get_typed_column_writer_mut, ColumnWriterImpl};
+use crate::column::writer::{
+    get_typed_column_writer_mut, ColumnCloseResult, ColumnWriterImpl,
+};
 use crate::column::{
     page::{CompressedPage, Page, PageWriteSpec, PageWriter},
     writer::{get_column_writer, ColumnWriter},
@@ -74,24 +76,8 @@ impl Write for TrackedWrite {
     }
 }
 
-/// Callback invoked on closing a column chunk, arguments are:
-///
-/// - the number of bytes written
-/// - the number of rows written
-/// - the column chunk metadata
-/// - the column index
-/// - the offset index
-///
-pub type OnCloseColumnChunk<'a> = Box<
-    dyn FnOnce(
-            u64,
-            u64,
-            ColumnChunkMetaData,
-            Option<ColumnIndex>,
-            Option<OffsetIndex>,
-        ) -> Result<()>
-        + 'a,
->;
+/// Callback invoked on closing a column chunk
+pub type OnCloseColumnChunk<'a> = Box<dyn FnOnce(ColumnCloseResult) -> Result<()> + 'a>;
 
 /// Callback invoked on closing a row group, arguments are:
 ///
@@ -388,28 +374,27 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
         let column_indexes = &mut self.column_indexes;
         let offset_indexes = &mut self.offset_indexes;
 
-        let on_close =
-            |bytes_written, rows_written, metadata, column_index, offset_index| {
-                // Update row group writer metrics
-                *total_bytes_written += bytes_written;
-                column_chunks.push(metadata);
-                column_indexes.push(column_index);
-                offset_indexes.push(offset_index);
-
-                if let Some(rows) = *total_rows_written {
-                    if rows != rows_written {
-                        return Err(general_err!(
-                            "Incorrect number of rows, expected {} != {} rows",
-                            rows,
-                            rows_written
-                        ));
-                    }
-                } else {
-                    *total_rows_written = Some(rows_written);
+        let on_close = |r: ColumnCloseResult| {
+            // Update row group writer metrics
+            *total_bytes_written += r.bytes_written;
+            column_chunks.push(r.metadata);
+            column_indexes.push(r.column_index);
+            offset_indexes.push(r.offset_index);
+
+            if let Some(rows) = *total_rows_written {
+                if rows != r.rows_written {
+                    return Err(general_err!(
+                        "Incorrect number of rows, expected {} != {} rows",
+                        rows,
+                        r.rows_written
+                    ));
                 }
+            } else {
+                *total_rows_written = Some(r.rows_written);
+            }
-                }
 
-                Ok(())
-            };
+            Ok(())
+        };
 
         let column = self.descr.column(self.column_index);
         self.column_index += 1;
@@ -502,26 +487,19 @@ impl<'a> SerializedColumnWriter<'a> {
 
     /// Close this [`SerializedColumnWriter`]
     pub fn close(mut self) -> Result<()> {
-        let (bytes_written, rows_written, metadata, column_index, offset_index) =
-            match self.inner {
-                ColumnWriter::BoolColumnWriter(typed) => typed.close()?,
-                ColumnWriter::Int32ColumnWriter(typed) => typed.close()?,
-                ColumnWriter::Int64ColumnWriter(typed) => typed.close()?,
-                ColumnWriter::Int96ColumnWriter(typed) => typed.close()?,
-                ColumnWriter::FloatColumnWriter(typed) => typed.close()?,
-                ColumnWriter::DoubleColumnWriter(typed) => typed.close()?,
-                ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?,
-                ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?,
-            };
+        let r = match self.inner {
+            ColumnWriter::BoolColumnWriter(typed) => typed.close()?,
+            ColumnWriter::Int32ColumnWriter(typed) => typed.close()?,
+            ColumnWriter::Int64ColumnWriter(typed) => typed.close()?,
+            ColumnWriter::Int96ColumnWriter(typed) => typed.close()?,
+            ColumnWriter::FloatColumnWriter(typed) => typed.close()?,
+            ColumnWriter::DoubleColumnWriter(typed) => typed.close()?,
+            ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?,
+            ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?,
+        };
 
         if let Some(on_close) = self.on_close.take() {
-            on_close(
-                bytes_written,
-                rows_written,
-                metadata,
-                column_index,
-                offset_index,
-            )?
+            on_close(r)?
        }
 
         Ok(())
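
For context, a minimal sketch of how downstream code might consume the new `ColumnCloseResult` struct after this change. The `log_close` helper is hypothetical, not part of the diff; it only illustrates that named fields replace the old positional tuple bindings:

```rust
use parquet::column::writer::ColumnCloseResult;
use parquet::errors::Result;

// Hypothetical consumer: named fields replace placeholder-heavy tuple
// destructuring like `let (_, rows, metadata, _, _) = writer.close()?;`.
fn log_close(r: ColumnCloseResult) -> Result<()> {
    println!(
        "column {}: {} bytes, {} rows",
        r.metadata.column_path(),
        r.bytes_written,
        r.rows_written
    );
    // The page index fields are optional and only present when enabled.
    if let Some(offset_index) = &r.offset_index {
        println!("{} data pages", offset_index.page_locations.len());
    }
    Ok(())
}
```

Because the new `OnCloseColumnChunk<'a>` alias is `Box<dyn FnOnce(ColumnCloseResult) -> Result<()> + 'a>`, a function with this shape boxed via `Box::new(log_close)` would satisfy the callback signature directly.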