Return structured ColumnCloseResult (apache#2465)
tustvold committed Aug 16, 2022
1 parent 3b59adc commit c7b5cc5
Showing 3 changed files with 89 additions and 118 deletions.
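
In short, `GenericColumnWriter::close` (and the callbacks fed by it) now returns a single `ColumnCloseResult` struct instead of a five-element tuple. A minimal before/after sketch of the calling convention; the `writer` variable is illustrative:

    // Before: five positional values, easy to misorder or misname
    let (bytes_written, rows_written, metadata, column_index, offset_index) =
        writer.close()?;

    // After: one struct with named fields
    let r = writer.close()?;
    println!("{} rows in {} bytes", r.rows_written, r.bytes_written);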
11 changes: 2 additions & 9 deletions parquet/src/arrow/arrow_writer/byte_array.rs
@@ -127,17 +127,10 @@ impl<'a> ByteArrayWriter<'a> {
     }

     pub fn close(self) -> Result<()> {
-        let (bytes_written, rows_written, metadata, column_index, offset_index) =
-            self.writer.close()?;
+        let r = self.writer.close()?;

         if let Some(on_close) = self.on_close {
-            on_close(
-                bytes_written,
-                rows_written,
-                metadata,
-                column_index,
-                offset_index,
-            )?;
+            on_close(r)?;
         }
         Ok(())
     }
104 changes: 52 additions & 52 deletions parquet/src/column/writer/mod.rs
@@ -145,13 +145,20 @@ pub fn get_typed_column_writer_mut<'a, 'b: 'a, T: DataType>(
     })
 }

-type ColumnCloseResult = (
-    u64,
-    u64,
-    ColumnChunkMetaData,
-    Option<ColumnIndex>,
-    Option<OffsetIndex>,
-);
+/// Metadata returned by [`GenericColumnWriter::close`]
+#[derive(Debug, Clone)]
+pub struct ColumnCloseResult {
+    /// The total number of bytes written
+    pub bytes_written: u64,
+    /// The total number of rows written
+    pub rows_written: u64,
+    /// Metadata for this column chunk
+    pub metadata: ColumnChunkMetaData,
+    /// Optional column index, for filtering
+    pub column_index: Option<ColumnIndex>,
+    /// Optional offset index, identifying page locations
+    pub offset_index: Option<OffsetIndex>,
+}

 // Metrics per page
 #[derive(Default)]
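
With named fields, callers that need only part of the result can project it directly rather than binding four placeholders. A short sketch of the patterns the updated tests below adopt:

    // Keep only the chunk metadata
    let metadata = writer.close()?.metadata;

    // Or keep the whole result and read fields by name
    let r = writer.close()?;
    assert_eq!(r.rows_written, 4);
    let column_index = r.column_index.unwrap();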
@@ -442,13 +449,13 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
             (None, None)
         };

-        Ok((
-            self.column_metrics.total_bytes_written,
-            self.column_metrics.total_rows_written,
+        Ok(ColumnCloseResult {
+            bytes_written: self.column_metrics.total_bytes_written,
+            rows_written: self.column_metrics.total_rows_written,
             metadata,
             column_index,
             offset_index,
-        ))
+        })
     }

     /// Writes mini batch of values, definition and repetition levels.
@@ -1200,11 +1207,13 @@ mod tests {
             .write_batch(&[true, false, true, false], None, None)
             .unwrap();

-        let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap();
+        let r = writer.close().unwrap();
         // PlainEncoder uses bit writer to write boolean values, which all fit into 1
         // byte.
-        assert_eq!(bytes_written, 1);
-        assert_eq!(rows_written, 4);
+        assert_eq!(r.bytes_written, 1);
+        assert_eq!(r.rows_written, 4);
+
+        let metadata = r.metadata;
         assert_eq!(metadata.encodings(), &vec![Encoding::PLAIN, Encoding::RLE]);
         assert_eq!(metadata.num_values(), 4); // just values
         assert_eq!(metadata.dictionary_page_offset(), None);
@@ -1473,9 +1482,11 @@ mod tests {
         let mut writer = get_test_column_writer::<Int32Type>(page_writer, 0, 0, props);
         writer.write_batch(&[1, 2, 3, 4], None, None).unwrap();

-        let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap();
-        assert_eq!(bytes_written, 20);
-        assert_eq!(rows_written, 4);
+        let r = writer.close().unwrap();
+        assert_eq!(r.bytes_written, 20);
+        assert_eq!(r.rows_written, 4);
+
+        let metadata = r.metadata;
         assert_eq!(
             metadata.encodings(),
             &vec![Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY]
@@ -1530,7 +1541,7 @@ mod tests {
                 None,
             )
             .unwrap();
-        let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap();
+        let metadata = writer.close().unwrap().metadata;
         if let Some(stats) = metadata.statistics() {
             assert!(stats.has_min_max_set());
             if let Statistics::ByteArray(stats) = stats {
@@ -1564,7 +1575,7 @@ mod tests {
             Int32Type,
         >(page_writer, 0, 0, props);
         writer.write_batch(&[0, 1, 2, 3, 4, 5], None, None).unwrap();
-        let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap();
+        let metadata = writer.close().unwrap().metadata;
         if let Some(stats) = metadata.statistics() {
             assert!(stats.has_min_max_set());
             if let Statistics::Int32(stats) = stats {
@@ -1598,9 +1609,11 @@ mod tests {
             )
             .unwrap();

-        let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap();
-        assert_eq!(bytes_written, 20);
-        assert_eq!(rows_written, 4);
+        let r = writer.close().unwrap();
+        assert_eq!(r.bytes_written, 20);
+        assert_eq!(r.rows_written, 4);
+
+        let metadata = r.metadata;
         assert_eq!(
             metadata.encodings(),
             &vec![Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY]
@@ -1645,7 +1658,7 @@ mod tests {
             )
             .unwrap();

-        let (_, _, metadata, _, _) = writer.close().unwrap();
+        let metadata = writer.close().unwrap().metadata;

         let stats = metadata.statistics().unwrap();
         assert_eq!(stats.min_bytes(), 1_i32.to_le_bytes());
@@ -1690,7 +1703,7 @@ mod tests {
             .write_batch(&[1, 2, 3, 4], Some(&[1, 0, 0, 1, 1, 1]), None)
             .unwrap();

-        let (_, _, metadata, _, _) = writer.close().unwrap();
+        let metadata = writer.close().unwrap().metadata;
         assert!(metadata.statistics().is_none());

         let reader = SerializedPageReader::new(
@@ -1818,7 +1831,7 @@ mod tests {
         let data = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
         let mut writer = get_test_column_writer::<Int32Type>(page_writer, 0, 0, props);
         writer.write_batch(data, None, None).unwrap();
-        let (bytes_written, _, _, _, _) = writer.close().unwrap();
+        let bytes_written = writer.close().unwrap().bytes_written;

         // Read pages and check the sequence
         let source = FileSource::new(&file, 0, bytes_written as usize);
@@ -2064,22 +2077,11 @@ mod tests {
         // second page
         writer.write_batch(&[4, 8, 2, -5], None, None).unwrap();

-        let (_, rows_written, metadata, column_index, offset_index) =
-            writer.close().unwrap();
-        let column_index = match column_index {
-            None => {
-                panic!("Can't fine the column index");
-            }
-            Some(column_index) => column_index,
-        };
-        let offset_index = match offset_index {
-            None => {
-                panic!("Can't find the offset index");
-            }
-            Some(offset_index) => offset_index,
-        };
+        let r = writer.close().unwrap();
+        let column_index = r.column_index.unwrap();
+        let offset_index = r.offset_index.unwrap();

-        assert_eq!(8, rows_written);
+        assert_eq!(8, r.rows_written);

         // column index
         assert_eq!(2, column_index.null_pages.len());
@@ -2090,7 +2092,7 @@
             assert_eq!(0, column_index.null_counts.as_ref().unwrap()[idx]);
         }

-        if let Some(stats) = metadata.statistics() {
+        if let Some(stats) = r.metadata.statistics() {
             assert!(stats.has_min_max_set());
             assert_eq!(stats.null_count(), 0);
             assert_eq!(stats.distinct_count(), None);
@@ -2201,15 +2203,14 @@ mod tests {

         let values_written = writer.write_batch(values, def_levels, rep_levels).unwrap();
         assert_eq!(values_written, values.len());
-        let (bytes_written, rows_written, column_metadata, _, _) =
-            writer.close().unwrap();
+        let result = writer.close().unwrap();

-        let source = FileSource::new(&file, 0, bytes_written as usize);
+        let source = FileSource::new(&file, 0, result.bytes_written as usize);
         let page_reader = Box::new(
             SerializedPageReader::new(
                 source,
-                column_metadata.num_values(),
-                column_metadata.compression(),
+                result.metadata.num_values(),
+                result.metadata.compression(),
                 T::get_physical_type(),
             )
             .unwrap(),
@@ -2250,11 +2251,11 @@ mod tests {
                     actual_rows_written += 1;
                 }
             }
-            assert_eq!(actual_rows_written, rows_written);
+            assert_eq!(actual_rows_written, result.rows_written);
         } else if actual_def_levels.is_some() {
-            assert_eq!(levels_read as u64, rows_written);
+            assert_eq!(levels_read as u64, result.rows_written);
         } else {
-            assert_eq!(values_read as u64, rows_written);
+            assert_eq!(values_read as u64, result.rows_written);
         }
     }
@@ -2268,8 +2269,7 @@
         let props = Arc::new(props);
         let mut writer = get_test_column_writer::<T>(page_writer, 0, 0, props);
         writer.write_batch(values, None, None).unwrap();
-        let (_, _, metadata, _, _) = writer.close().unwrap();
-        metadata
+        writer.close().unwrap().metadata
     }

     // Function to use in tests for EncodingWriteSupport. This checks that dictionary
@@ -2380,7 +2380,7 @@ mod tests {
         let mut writer = get_test_column_writer::<T>(page_writer, 0, 0, props);
         writer.write_batch(values, None, None).unwrap();

-        let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap();
+        let metadata = writer.close().unwrap().metadata;
         if let Some(stats) = metadata.statistics() {
             stats.clone()
         } else {
92 changes: 35 additions & 57 deletions parquet/src/file/writer.rs
@@ -26,7 +26,9 @@ use parquet_format::{ColumnIndex, OffsetIndex, RowGroup};
 use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol};

 use crate::basic::PageType;
-use crate::column::writer::{get_typed_column_writer_mut, ColumnWriterImpl};
+use crate::column::writer::{
+    get_typed_column_writer_mut, ColumnCloseResult, ColumnWriterImpl,
+};
 use crate::column::{
     page::{CompressedPage, Page, PageWriteSpec, PageWriter},
     writer::{get_column_writer, ColumnWriter},
@@ -75,24 +77,8 @@ impl<W: Write> Write for TrackedWrite<W> {
     }
 }

-/// Callback invoked on closing a column chunk, arguments are:
-///
-/// - the number of bytes written
-/// - the number of rows written
-/// - the column chunk metadata
-/// - the column index
-/// - the offset index
-///
-pub type OnCloseColumnChunk<'a> = Box<
-    dyn FnOnce(
-        u64,
-        u64,
-        ColumnChunkMetaData,
-        Option<ColumnIndex>,
-        Option<OffsetIndex>,
-    ) -> Result<()>
-    + 'a,
->;
+/// Callback invoked on closing a column chunk
+pub type OnCloseColumnChunk<'a> = Box<dyn FnOnce(ColumnCloseResult) -> Result<()> + 'a>;

 /// Callback invoked on closing a row group, arguments are:
 ///
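
For illustration only (this snippet is not part of the commit), any closure over a `ColumnCloseResult` now satisfies the alias; the real consumer is the `on_close` closure in `SerializedRowGroupWriter` below:

    let on_close: OnCloseColumnChunk<'static> = Box::new(|r: ColumnCloseResult| {
        // Inspect the structured result instead of five positional arguments
        println!("chunk closed: {} rows, {} bytes", r.rows_written, r.bytes_written);
        Ok(())
    });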
@@ -390,28 +376,27 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
         let column_indexes = &mut self.column_indexes;
         let offset_indexes = &mut self.offset_indexes;

-        let on_close =
-            |bytes_written, rows_written, metadata, column_index, offset_index| {
-                // Update row group writer metrics
-                *total_bytes_written += bytes_written;
-                column_chunks.push(metadata);
-                column_indexes.push(column_index);
-                offset_indexes.push(offset_index);
-
-                if let Some(rows) = *total_rows_written {
-                    if rows != rows_written {
-                        return Err(general_err!(
-                            "Incorrect number of rows, expected {} != {} rows",
-                            rows,
-                            rows_written
-                        ));
-                    }
-                } else {
-                    *total_rows_written = Some(rows_written);
-                }
-
-                Ok(())
-            };
+        let on_close = |r: ColumnCloseResult| {
+            // Update row group writer metrics
+            *total_bytes_written += r.bytes_written;
+            column_chunks.push(r.metadata);
+            column_indexes.push(r.column_index);
+            offset_indexes.push(r.offset_index);
+
+            if let Some(rows) = *total_rows_written {
+                if rows != r.rows_written {
+                    return Err(general_err!(
+                        "Incorrect number of rows, expected {} != {} rows",
+                        rows,
+                        r.rows_written
+                    ));
+                }
+            } else {
+                *total_rows_written = Some(r.rows_written);
+            }
+
+            Ok(())
+        };

         let column = self.descr.column(self.column_index);
         self.column_index += 1;
@@ -504,26 +489,19 @@ impl<'a> SerializedColumnWriter<'a> {

     /// Close this [`SerializedColumnWriter`]
     pub fn close(mut self) -> Result<()> {
-        let (bytes_written, rows_written, metadata, column_index, offset_index) =
-            match self.inner {
-                ColumnWriter::BoolColumnWriter(typed) => typed.close()?,
-                ColumnWriter::Int32ColumnWriter(typed) => typed.close()?,
-                ColumnWriter::Int64ColumnWriter(typed) => typed.close()?,
-                ColumnWriter::Int96ColumnWriter(typed) => typed.close()?,
-                ColumnWriter::FloatColumnWriter(typed) => typed.close()?,
-                ColumnWriter::DoubleColumnWriter(typed) => typed.close()?,
-                ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?,
-                ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?,
-            };
+        let r = match self.inner {
+            ColumnWriter::BoolColumnWriter(typed) => typed.close()?,
+            ColumnWriter::Int32ColumnWriter(typed) => typed.close()?,
+            ColumnWriter::Int64ColumnWriter(typed) => typed.close()?,
+            ColumnWriter::Int96ColumnWriter(typed) => typed.close()?,
+            ColumnWriter::FloatColumnWriter(typed) => typed.close()?,
+            ColumnWriter::DoubleColumnWriter(typed) => typed.close()?,
+            ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?,
+            ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?,
+        };

         if let Some(on_close) = self.on_close.take() {
-            on_close(
-                bytes_written,
-                rows_written,
-                metadata,
-                column_index,
-                offset_index,
-            )?
+            on_close(r)?
         }

         Ok(())
