Add API to Retrieve Finished Writer from Parquet Writer #2498

Merged
3 commits merged on Aug 18, 2022
72 changes: 47 additions & 25 deletions parquet/src/arrow/arrow_writer/mod.rs
@@ -223,6 +223,12 @@ impl<W: Write> ArrowWriter<W> {
        Ok(())
    }

    /// Flushes any outstanding data and returns the underlying writer.
    pub fn into_inner(mut self) -> Result<W> {
        self.flush()?;
        self.writer.into_inner()
    }

    /// Close and finalize the underlying Parquet writer
    pub fn close(mut self) -> Result<parquet_format::FileMetaData> {
        self.flush()?;
@@ -644,6 +650,25 @@ mod tests {
        roundtrip(batch, Some(SMALL_SIZE / 2));
    }

    fn get_bytes_after_close(schema: SchemaRef, expected_batch: &RecordBatch) -> Vec<u8> {
        let mut buffer = vec![];

        let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
        writer.write(expected_batch).unwrap();
        writer.close().unwrap();

        buffer
    }

    fn get_bytes_by_into_inner(
        schema: SchemaRef,
        expected_batch: &RecordBatch,
    ) -> Vec<u8> {
        let mut writer = ArrowWriter::try_new(Vec::new(), schema, None).unwrap();
        writer.write(expected_batch).unwrap();
        writer.into_inner().unwrap()
    }

    #[test]
    fn roundtrip_bytes() {
        // define schema
@@ -660,31 +685,28 @@
        let expected_batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap();

        let mut buffer = vec![];

        {
            let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
            writer.write(&expected_batch).unwrap();
            writer.close().unwrap();
        }

        let cursor = Bytes::from(buffer);
        let mut record_batch_reader =
            ParquetRecordBatchReader::try_new(cursor, 1024).unwrap();

        let actual_batch = record_batch_reader
            .next()
            .expect("No batch found")
            .expect("Unable to get batch");

        assert_eq!(expected_batch.schema(), actual_batch.schema());
        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
        for i in 0..expected_batch.num_columns() {
            let expected_data = expected_batch.column(i).data().clone();
            let actual_data = actual_batch.column(i).data().clone();

            assert_eq!(expected_data, actual_data);
        for buffer in vec![
            get_bytes_after_close(schema.clone(), &expected_batch),
            get_bytes_by_into_inner(schema, &expected_batch),
        ] {
            let cursor = Bytes::from(buffer);
            let mut record_batch_reader =
                ParquetRecordBatchReader::try_new(cursor, 1024).unwrap();

            let actual_batch = record_batch_reader
                .next()
                .expect("No batch found")
                .expect("Unable to get batch");

            assert_eq!(expected_batch.schema(), actual_batch.schema());
            assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
            assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
            for i in 0..expected_batch.num_columns() {
                let expected_data = expected_batch.column(i).data().clone();
                let actual_data = actual_batch.column(i).data().clone();

                assert_eq!(expected_data, actual_data);
            }
        }
    }

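For context, a minimal usage sketch (not part of this diff) of the new ArrowWriter::into_inner: unlike close(), which returns the finished file's metadata, into_inner() flushes any buffered rows, finalizes the file, and hands back the underlying writer, so callers writing to an in-memory Vec<u8> no longer need to keep a separate buffer alive. The schema, values, and function name below are illustrative assumptions, not code from this PR.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use parquet::errors::Result;

fn write_batch_to_vec() -> Result<Vec<u8>> {
    // Illustrative one-column schema and batch.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema.clone(), vec![column]).expect("valid batch");

    // The writer owns the output buffer instead of borrowing an external Vec.
    let mut writer = ArrowWriter::try_new(Vec::new(), schema, None)?;
    writer.write(&batch)?;

    // Flushes buffered rows, writes the file footer, and returns the Vec<u8>.
    writer.into_inner()
}

This mirrors the get_bytes_by_into_inner helper added to the roundtrip test above.
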
13 changes: 13 additions & 0 deletions parquet/src/file/writer.rs
@@ -60,6 +60,11 @@ impl<W: Write> TrackedWrite<W> {
    pub fn bytes_written(&self) -> usize {
        self.bytes_written
    }

    /// Returns the underlying writer.
    pub fn into_inner(self) -> W {
        self.inner
    }
}

impl<W: Write> Write for TrackedWrite<W> {
@@ -306,6 +311,14 @@ impl<W: Write> SerializedFileWriter<W> {
            Ok(())
        }
    }

    /// Writes the file footer and returns the underlying writer.
    pub fn into_inner(mut self) -> Result<W> {
        self.assert_previous_writer_closed()?;
        let _ = self.write_metadata()?;

        Ok(self.buf.into_inner())
    }
}

/// Parquet row group writer API.
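A similar minimal sketch (again, not part of this diff) at the SerializedFileWriter level, assuming the constructor, schema string, and function name shown here are illustrative: into_inner() first checks that any previously started row group writer was closed, then writes the file footer and returns the writer that was passed in, whereas close() returns the file metadata instead.

use std::sync::Arc;

use parquet::errors::Result;
use parquet::file::properties::WriterProperties;
use parquet::file::writer::SerializedFileWriter;
use parquet::schema::parser::parse_message_type;

fn write_file_to_vec() -> Result<Vec<u8>> {
    // Illustrative single-column Parquet schema.
    let schema = Arc::new(parse_message_type(
        "message schema { REQUIRED INT32 id; }",
    )?);
    let props = Arc::new(WriterProperties::builder().build());

    let writer = SerializedFileWriter::new(Vec::new(), schema, props)?;
    // Row groups obtained from next_row_group() would be written and closed
    // here; into_inner() errors if a row group writer is still open.

    // Writes the footer and returns the underlying Vec<u8>, rather than the
    // file metadata that close() returns.
    writer.into_inner()
}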