Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array writer indirection #2091

Merged
merged 1 commit into from Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 45 additions & 20 deletions parquet/src/arrow/arrow_writer/mod.rs
Expand Up @@ -33,7 +33,7 @@ use super::schema::{
decimal_length_from_precision,
};

use crate::column::writer::ColumnWriter;
use crate::column::writer::{get_column_writer, ColumnWriter};
use crate::errors::{ParquetError, Result};
use crate::file::metadata::RowGroupMetaDataPtr;
use crate::file::properties::WriterProperties;
Expand All @@ -43,6 +43,44 @@ use levels::{calculate_array_levels, LevelInfo};

mod levels;

/// An object-safe API for writing an [`ArrayRef`]
trait ArrayWriter {
/// Writes `array` together with the level information in `levels`,
/// returning an error if the write fails.
fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()>;

/// Finishes this writer.
///
/// Takes `&mut self` rather than `self` so that it can be invoked through a
/// `Box<dyn ArrayWriter>`.
fn close(&mut self) -> Result<()>;
}

/// Fallback implementation for writing an [`ArrayRef`] that uses [`SerializedColumnWriter`]
struct ColumnArrayWriter<'a>(Option<SerializedColumnWriter<'a>>);

impl<'a> ArrayWriter for ColumnArrayWriter<'a> {
fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> {
write_leaf(self.0.as_mut().unwrap().untyped(), array, levels)?;
Ok(())
}

fn close(&mut self) -> Result<()> {
self.0.take().unwrap().close()
}
}

/// Fetches the next column from `row_group_writer` and wraps it in a boxed
/// [`ArrayWriter`].
///
/// # Panics
///
/// Panics if the row group writer has no further column to hand out.
fn get_writer<'a, W: Write>(
    row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
) -> Result<Box<dyn ArrayWriter + 'a>> {
    let next = row_group_writer.next_column_with_factory(
        |column_descr, writer_props, pages, close_callback| {
            // TODO: Special case array readers (#1764)

            let inner = SerializedColumnWriter::new(
                get_column_writer(column_descr, writer_props.clone(), pages),
                Some(close_callback),
            );

            Ok(Box::new(ColumnArrayWriter(Some(inner))))
        },
    )?;

    match next {
        Some(writer) => Ok(writer),
        None => panic!("Unable to get column writer"),
    }
}

/// Arrow writer
///
/// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order
Expand Down Expand Up @@ -229,17 +267,6 @@ impl<W: Write> ArrowWriter<W> {
}
}

/// Convenience helper that pulls the next [`SerializedColumnWriter`] out of
/// the given row group writer.
///
/// # Panics
///
/// Panics if the row group writer has no further column to hand out.
#[inline]
fn get_col_writer<'a, W: Write>(
    row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
) -> Result<SerializedColumnWriter<'a>> {
    let next = row_group_writer.next_column()?;
    Ok(next.expect("Unable to get column writer"))
}

fn write_leaves<W: Write>(
row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
arrays: &[ArrayRef],
Expand Down Expand Up @@ -277,15 +304,14 @@ fn write_leaves<W: Write>(
| ArrowDataType::LargeUtf8
| ArrowDataType::Decimal(_, _)
| ArrowDataType::FixedSizeBinary(_) => {
let mut col_writer = get_col_writer(row_group_writer)?;
let mut writer = get_writer(row_group_writer)?;
for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
write_leaf(
col_writer.untyped(),
writer.write(
array,
levels.pop().expect("Levels exhausted"),
)?;
}
col_writer.close()?;
writer.close()?;
Ok(())
}
ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
Expand Down Expand Up @@ -338,17 +364,16 @@ fn write_leaves<W: Write>(
Ok(())
}
ArrowDataType::Dictionary(_, value_type) => {
let mut col_writer = get_col_writer(row_group_writer)?;
let mut writer = get_writer(row_group_writer)?;
for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
// cast dictionary to a primitive
let array = arrow::compute::cast(array, value_type)?;
write_leaf(
col_writer.untyped(),
writer.write(
&array,
levels.pop().expect("Levels exhausted"),
)?;
}
col_writer.close()?;
writer.close()?;
Ok(())
}
ArrowDataType::Float16 => Err(ParquetError::ArrowError(
Expand Down
51 changes: 36 additions & 15 deletions parquet/src/file/writer.rs
Expand Up @@ -37,7 +37,9 @@ use crate::file::{
metadata::*, properties::WriterPropertiesPtr,
statistics::to_thrift as statistics_to_thrift, FOOTER_SIZE, PARQUET_MAGIC,
};
use crate::schema::types::{self, SchemaDescPtr, SchemaDescriptor, TypePtr};
use crate::schema::types::{
self, ColumnDescPtr, SchemaDescPtr, SchemaDescriptor, TypePtr,
};
use crate::util::io::TryClone;

/// A wrapper around a [`Write`] that keeps track of the number
Expand Down Expand Up @@ -367,22 +369,26 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
}
}

/// Returns the next column writer, if available; otherwise returns `None`.
/// In case of any IO error or Thrift error, or if row group writer has already been
/// closed returns `Err`.
pub fn next_column(&mut self) -> Result<Option<SerializedColumnWriter<'_>>> {
/// Returns the next column writer, if available, using the factory function;
/// otherwise returns `None`.
pub(crate) fn next_column_with_factory<'b, F, C>(
&'b mut self,
factory: F,
) -> Result<Option<C>>
where
F: FnOnce(
ColumnDescPtr,
&'b WriterPropertiesPtr,
Box<dyn PageWriter + 'b>,
OnCloseColumnChunk<'b>,
) -> Result<C>,
{
self.assert_previous_writer_closed()?;

if self.column_index >= self.descr.num_columns() {
return Ok(None);
}
let page_writer = Box::new(SerializedPageWriter::new(self.buf));
let column_writer = get_column_writer(
self.descr.column(self.column_index),
self.props.clone(),
page_writer,
);
self.column_index += 1;

let total_bytes_written = &mut self.total_bytes_written;
let total_rows_written = &mut self.total_rows_written;
Expand Down Expand Up @@ -413,10 +419,25 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
Ok(())
};

Ok(Some(SerializedColumnWriter::new(
column_writer,
Some(Box::new(on_close)),
)))
let column = self.descr.column(self.column_index);
self.column_index += 1;

Ok(Some(factory(
column,
&self.props,
page_writer,
Box::new(on_close),
)?))
}

/// Returns the next column writer, if available; otherwise returns `None`.
/// In case of any IO error or Thrift error, or if row group writer has already been
/// closed returns `Err`.
///
/// Delegates to `next_column_with_factory`, supplying a factory that builds a
/// plain [`SerializedColumnWriter`].
pub fn next_column(&mut self) -> Result<Option<SerializedColumnWriter<'_>>> {
    self.next_column_with_factory(|descr, props, page_writer, on_close| {
        Ok(SerializedColumnWriter::new(
            get_column_writer(descr, props.clone(), page_writer),
            Some(on_close),
        ))
    })
}

/// Closes this row group writer and returns row group metadata.
Expand Down