diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 75bd6f6aa75..53b094a9e28 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -33,7 +33,7 @@ use super::schema::{ decimal_length_from_precision, }; -use crate::column::writer::ColumnWriter; +use crate::column::writer::{get_column_writer, ColumnWriter}; use crate::errors::{ParquetError, Result}; use crate::file::metadata::RowGroupMetaDataPtr; use crate::file::properties::WriterProperties; @@ -43,6 +43,44 @@ use levels::{calculate_array_levels, LevelInfo}; mod levels; +/// An object-safe API for writing an [`ArrayRef`] +trait ArrayWriter { + fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()>; + + fn close(&mut self) -> Result<()>; +} + +/// Fallback implementation for writing an [`ArrayRef`] that uses [`SerializedColumnWriter`] +struct ColumnArrayWriter<'a>(Option>); + +impl<'a> ArrayWriter for ColumnArrayWriter<'a> { + fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> { + write_leaf(self.0.as_mut().unwrap().untyped(), array, levels)?; + Ok(()) + } + + fn close(&mut self) -> Result<()> { + self.0.take().unwrap().close() + } +} + +fn get_writer<'a, W: Write>( + row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>, +) -> Result> { + let array_writer = row_group_writer + .next_column_with_factory(|descr, props, page_writer, on_close| { + // TODO: Special case array readers (#1764) + + let column_writer = get_column_writer(descr, props.clone(), page_writer); + let serialized_writer = + SerializedColumnWriter::new(column_writer, Some(on_close)); + + Ok(Box::new(ColumnArrayWriter(Some(serialized_writer)))) + })? + .expect("Unable to get column writer"); + Ok(array_writer) +} + /// Arrow writer /// /// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order @@ -229,17 +267,6 @@ impl ArrowWriter { } } -/// Convenience method to get the next ColumnWriter from the RowGroupWriter -#[inline] -fn get_col_writer<'a, W: Write>( - row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>, -) -> Result> { - let col_writer = row_group_writer - .next_column()? - .expect("Unable to get column writer"); - Ok(col_writer) -} - fn write_leaves( row_group_writer: &mut SerializedRowGroupWriter<'_, W>, arrays: &[ArrayRef], @@ -277,15 +304,14 @@ fn write_leaves( | ArrowDataType::LargeUtf8 | ArrowDataType::Decimal(_, _) | ArrowDataType::FixedSizeBinary(_) => { - let mut col_writer = get_col_writer(row_group_writer)?; + let mut writer = get_writer(row_group_writer)?; for (array, levels) in arrays.iter().zip(levels.iter_mut()) { - write_leaf( - col_writer.untyped(), + writer.write( array, levels.pop().expect("Levels exhausted"), )?; } - col_writer.close()?; + writer.close()?; Ok(()) } ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { @@ -338,17 +364,16 @@ fn write_leaves( Ok(()) } ArrowDataType::Dictionary(_, value_type) => { - let mut col_writer = get_col_writer(row_group_writer)?; + let mut writer = get_writer(row_group_writer)?; for (array, levels) in arrays.iter().zip(levels.iter_mut()) { // cast dictionary to a primitive let array = arrow::compute::cast(array, value_type)?; - write_leaf( - col_writer.untyped(), + writer.write( &array, levels.pop().expect("Levels exhausted"), )?; } - col_writer.close()?; + writer.close()?; Ok(()) } ArrowDataType::Float16 => Err(ParquetError::ArrowError( diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs index 10983c74135..467273aaab9 100644 --- a/parquet/src/file/writer.rs +++ b/parquet/src/file/writer.rs @@ -37,7 +37,9 @@ use crate::file::{ metadata::*, properties::WriterPropertiesPtr, statistics::to_thrift as statistics_to_thrift, FOOTER_SIZE, PARQUET_MAGIC, }; -use crate::schema::types::{self, SchemaDescPtr, SchemaDescriptor, TypePtr}; +use crate::schema::types::{ + self, ColumnDescPtr, SchemaDescPtr, SchemaDescriptor, TypePtr, +}; use crate::util::io::TryClone; /// A wrapper around a [`Write`] that keeps track of the number @@ -367,22 +369,26 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> { } } - /// Returns the next column writer, if available; otherwise returns `None`. - /// In case of any IO error or Thrift error, or if row group writer has already been - /// closed returns `Err`. - pub fn next_column(&mut self) -> Result>> { + /// Returns the next column writer, if available, using the factory function; + /// otherwise returns `None`. + pub(crate) fn next_column_with_factory<'b, F, C>( + &'b mut self, + factory: F, + ) -> Result> + where + F: FnOnce( + ColumnDescPtr, + &'b WriterPropertiesPtr, + Box, + OnCloseColumnChunk<'b>, + ) -> Result, + { self.assert_previous_writer_closed()?; if self.column_index >= self.descr.num_columns() { return Ok(None); } let page_writer = Box::new(SerializedPageWriter::new(self.buf)); - let column_writer = get_column_writer( - self.descr.column(self.column_index), - self.props.clone(), - page_writer, - ); - self.column_index += 1; let total_bytes_written = &mut self.total_bytes_written; let total_rows_written = &mut self.total_rows_written; @@ -413,10 +419,25 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> { Ok(()) }; - Ok(Some(SerializedColumnWriter::new( - column_writer, - Some(Box::new(on_close)), - ))) + let column = self.descr.column(self.column_index); + self.column_index += 1; + + Ok(Some(factory( + column, + &self.props, + page_writer, + Box::new(on_close), + )?)) + } + + /// Returns the next column writer, if available; otherwise returns `None`. + /// In case of any IO error or Thrift error, or if row group writer has already been + /// closed returns `Err`. + pub fn next_column(&mut self) -> Result>> { + self.next_column_with_factory(|descr, props, page_writer, on_close| { + let column_writer = get_column_writer(descr, props.clone(), page_writer); + Ok(SerializedColumnWriter::new(column_writer, Some(on_close))) + }) } /// Closes this row group writer and returns row group metadata.