Add API to Retrieve Finished Writer from Parquet Writer (#2498)
* add into_inner to take inner writer out

* flush writer before into_inner

* Apply suggestions from code review

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>
jiacai2050 and tustvold committed Aug 18, 2022
1 parent 15f42b2 commit f3afdd2
Showing 2 changed files with 60 additions and 25 deletions.
72 changes: 47 additions & 25 deletions parquet/src/arrow/arrow_writer/mod.rs
@@ -223,6 +223,12 @@ impl<W: Write> ArrowWriter<W> {
Ok(())
}

/// Flushes any outstanding data, writes the file footer, and returns the underlying writer.
pub fn into_inner(mut self) -> Result<W> {
self.flush()?;
self.writer.into_inner()
}

/// Close and finalize the underlying Parquet writer
pub fn close(mut self) -> Result<parquet_format::FileMetaData> {
self.flush()?;
@@ -644,6 +650,25 @@ mod tests {
roundtrip(batch, Some(SMALL_SIZE / 2));
}

fn get_bytes_after_close(schema: SchemaRef, expected_batch: &RecordBatch) -> Vec<u8> {
let mut buffer = vec![];

let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
writer.write(expected_batch).unwrap();
writer.close().unwrap();

buffer
}

fn get_bytes_by_into_inner(
schema: SchemaRef,
expected_batch: &RecordBatch,
) -> Vec<u8> {
let mut writer = ArrowWriter::try_new(Vec::new(), schema, None).unwrap();
writer.write(expected_batch).unwrap();
writer.into_inner().unwrap()
}

#[test]
fn roundtrip_bytes() {
// define schema
@@ -660,31 +685,28 @@ mod tests {
let expected_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap();

let mut buffer = vec![];

{
let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
writer.write(&expected_batch).unwrap();
writer.close().unwrap();
}

let cursor = Bytes::from(buffer);
let mut record_batch_reader =
ParquetRecordBatchReader::try_new(cursor, 1024).unwrap();

let actual_batch = record_batch_reader
.next()
.expect("No batch found")
.expect("Unable to get batch");

assert_eq!(expected_batch.schema(), actual_batch.schema());
assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
for i in 0..expected_batch.num_columns() {
let expected_data = expected_batch.column(i).data().clone();
let actual_data = actual_batch.column(i).data().clone();

assert_eq!(expected_data, actual_data);
for buffer in vec![
get_bytes_after_close(schema.clone(), &expected_batch),
get_bytes_by_into_inner(schema, &expected_batch),
] {
let cursor = Bytes::from(buffer);
let mut record_batch_reader =
ParquetRecordBatchReader::try_new(cursor, 1024).unwrap();

let actual_batch = record_batch_reader
.next()
.expect("No batch found")
.expect("Unable to get batch");

assert_eq!(expected_batch.schema(), actual_batch.schema());
assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
for i in 0..expected_batch.num_columns() {
let expected_data = expected_batch.column(i).data().clone();
let actual_data = actual_batch.column(i).data().clone();

assert_eq!(expected_data, actual_data);
}
}
}

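For context, a minimal sketch of how the new `ArrowWriter::into_inner` API can be used downstream, mirroring the `get_bytes_by_into_inner` test helper above. The imports and error plumbing are assumptions based on the crate layout at this commit, not part of the diff:

```rust
use arrow::record_batch::RecordBatch;
use parquet::arrow::arrow_writer::ArrowWriter;
use parquet::errors::Result;

/// Encode a batch to an in-memory Parquet file and take back the buffer.
fn batch_to_parquet_bytes(batch: &RecordBatch) -> Result<Vec<u8>> {
    // The writer takes ownership of the Vec<u8> sink.
    let mut writer = ArrowWriter::try_new(Vec::new(), batch.schema(), None)?;
    writer.write(batch)?;
    // into_inner flushes buffered rows, writes the file footer, and
    // returns the owned buffer as a complete Parquet file, whereas
    // close() would return the file metadata instead.
    writer.into_inner()
}
```

This avoids the borrowed `&mut Vec<u8>` plus extra scope that the old `close()`-based test needed to get the bytes back out.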
13 changes: 13 additions & 0 deletions parquet/src/file/writer.rs
@@ -62,6 +62,11 @@ impl<W: Write> TrackedWrite<W> {
pub fn bytes_written(&self) -> usize {
self.bytes_written
}

/// Returns the underlying writer.
pub fn into_inner(self) -> W {
self.inner
}
}

impl<W: Write> Write for TrackedWrite<W> {
@@ -292,6 +297,14 @@ impl<W: Write> SerializedFileWriter<W> {
Ok(())
}
}

/// Writes the file footer and returns the underlying writer.
pub fn into_inner(mut self) -> Result<W> {
self.assert_previous_writer_closed()?;
let _ = self.write_metadata()?;

Ok(self.buf.into_inner())
}
}

/// Parquet row group writer API.
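Similarly, a hedged sketch of the lower-level `SerializedFileWriter::into_inner` path. The message-type schema and writer-properties setup here are illustrative assumptions; only the `into_inner` call reflects this diff:

```rust
use std::sync::Arc;

use parquet::file::properties::WriterProperties;
use parquet::file::writer::SerializedFileWriter;
use parquet::schema::parser::parse_message_type;

/// Produce a complete (empty) Parquet file in memory.
fn empty_parquet_bytes() -> parquet::errors::Result<Vec<u8>> {
    let schema = Arc::new(parse_message_type(
        "message schema { REQUIRED INT32 id; }",
    )?);
    let props = Arc::new(WriterProperties::builder().build());
    let writer = SerializedFileWriter::new(Vec::new(), schema, props)?;
    // into_inner asserts that no row group writer is still open, writes
    // the footer via write_metadata, and hands back the Vec<u8>.
    writer.into_inner()
}
```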
