From cedaf8a6ab55826c34f3b1bc9a21dbaf3e0328bc Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 13 Jun 2022 15:29:46 -0700
Subject: [PATCH] Add PyArrow integration test for C Stream Interface (#1848)

* Add PyArrow integration test for ArrowArrayStream

* Trigger Build
---
 arrow-pyarrow-integration-testing/src/lib.rs |  9 ++++
 .../tests/test_sql.py                        | 16 +++++++
 arrow/src/ffi_stream.rs                      | 24 ++++-------
 arrow/src/pyarrow.rs                         | 42 ++++++++++++++++++-
 4 files changed, 74 insertions(+), 17 deletions(-)

diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 26c09d64d5d..086b2183465 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -27,6 +27,7 @@
 use arrow::array::{ArrayData, ArrayRef, Int64Array};
 use arrow::compute::kernels;
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::error::ArrowError;
+use arrow::ffi_stream::ArrowArrayStreamReader;
 use arrow::pyarrow::PyArrowConvert;
 use arrow::record_batch::RecordBatch;
@@ -111,6 +112,13 @@ fn round_trip_record_batch(obj: RecordBatch) -> PyResult<RecordBatch> {
     Ok(obj)
 }
 
+#[pyfunction]
+fn round_trip_record_batch_reader(
+    obj: ArrowArrayStreamReader,
+) -> PyResult<ArrowArrayStreamReader> {
+    Ok(obj)
+}
+
 #[pymodule]
 fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(double))?;
@@ -122,5 +130,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
     m.add_wrapped(wrap_pyfunction!(round_trip_schema))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
+    m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
     Ok(())
 }
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 324956c9c6a..a17ba6d0613 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -303,3 +303,19 @@ def test_dictionary_python():
     assert a == b
     del a
     del b
+
+def test_record_batch_reader():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    a = pa.RecordBatchReader.from_batches(schema, batches)
+    b = rust.round_trip_record_batch_reader(a)
+
+    assert b.schema == schema
+    got_batches = list(b)
+    assert got_batches == batches
diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs
index ab4caea36f8..bfc62b8888c 100644
--- a/arrow/src/ffi_stream.rs
+++ b/arrow/src/ffi_stream.rs
@@ -198,13 +198,6 @@ impl ExportedArrayStream {
     }
 
     pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
-        unsafe {
-            match (*out).release {
-                None => (),
-                Some(release) => release(out),
-            };
-        };
-
         let mut private_data = self.get_private_data();
         let reader = &private_data.batch_reader;
 
@@ -224,18 +217,17 @@ impl ExportedArrayStream {
     }
 
     pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
-        unsafe {
-            match (*out).release {
-                None => (),
-                Some(release) => release(out),
-            };
-        };
-
         let mut private_data = self.get_private_data();
         let reader = &mut private_data.batch_reader;
 
         let ret_code = match reader.next() {
-            None => 0,
+            None => {
+                // Marks ArrowArray released to indicate reaching the end of stream.
+                unsafe {
+                    (*out).release = None;
+                }
+                0
+            }
             Some(next_batch) => {
                 if let Ok(batch) = next_batch {
                     let struct_array = StructArray::from(batch);
@@ -275,7 +267,7 @@ fn get_error_code(err: &ArrowError) -> i32 {
 /// Struct used to fetch `RecordBatch` from the C Stream Interface.
 /// Its main responsibility is to expose `RecordBatchReader` functionality
 /// that requires [FFI_ArrowArrayStream].
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct ArrowArrayStreamReader {
     stream: Arc<FFI_ArrowArrayStream>,
     schema: SchemaRef,
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index 62e6316b621..3ae5b3b9987 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -24,13 +24,16 @@
 use std::sync::Arc;
 
 use pyo3::ffi::Py_uintptr_t;
 use pyo3::import_exception;
 use pyo3::prelude::*;
-use pyo3::types::PyList;
+use pyo3::types::{PyList, PyTuple};
 
 use crate::array::{Array, ArrayData, ArrayRef};
 use crate::datatypes::{DataType, Field, Schema};
 use crate::error::ArrowError;
 use crate::ffi;
 use crate::ffi::FFI_ArrowSchema;
+use crate::ffi_stream::{
+    export_reader_into_raw, ArrowArrayStreamReader, FFI_ArrowArrayStream,
+};
 use crate::record_batch::RecordBatch;
@@ -198,6 +201,42 @@ impl PyArrowConvert for RecordBatch {
     }
 }
 
+impl PyArrowConvert for ArrowArrayStreamReader {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // prepare a pointer to receive the stream struct
+        let stream = Box::new(FFI_ArrowArrayStream::empty());
+        let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream;
+
+        // make the conversion through PyArrow's private API
+        // this changes the pointer's memory and is thus unsafe.
+        // In particular, `_export_to_c` can go out of bounds
+        let args = PyTuple::new(value.py(), &[stream_ptr as Py_uintptr_t]);
+        value.call_method1("_export_to_c", args)?;
+
+        let stream_reader =
+            unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() };
+
+        unsafe {
+            Box::from_raw(stream_ptr);
+        }
+
+        Ok(stream_reader)
+    }
+
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        let stream = Box::new(FFI_ArrowArrayStream::empty());
+        let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream;
+
+        unsafe { export_reader_into_raw(Box::new(self.clone()), stream_ptr) };
+
+        let module = py.import("pyarrow")?;
+        let class = module.getattr("RecordBatchReader")?;
+        let args = PyTuple::new(py, &[stream_ptr as Py_uintptr_t]);
+        let reader = class.call_method1("_import_from_c", args)?;
+        Ok(PyObject::from(reader))
+    }
+}
+
 macro_rules! add_conversion {
     ($typ:ty) => {
         impl<'source> FromPyObject<'source> for $typ {
@@ -219,3 +258,4 @@ add_conversion!(Field);
 add_conversion!(Schema);
 add_conversion!(ArrayData);
 add_conversion!(RecordBatch);
+add_conversion!(ArrowArrayStreamReader);
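
Usage sketch (illustrative, not part of the patch): the test_record_batch_reader test added above drives round_trip_record_batch_reader through pytest; the standalone snippet below shows the same Python -> Rust -> Python round trip over the Arrow C Stream Interface, assuming the arrow_pyarrow_integration_testing extension module from this crate has been built and installed.

    import pyarrow as pa
    import arrow_pyarrow_integration_testing as rust

    # Build a small stream of record batches on the Python side.
    schema = pa.schema([("ints", pa.int32())])
    batches = [
        pa.record_batch([[1, 2, 3]], schema=schema),
        pa.record_batch([[4, None, 6]], schema=schema),
    ]
    reader = pa.RecordBatchReader.from_batches(schema, batches)

    # The reader crosses the C Stream Interface twice: pyarrow exports it to the
    # Rust ArrowArrayStreamReader, and the result is imported back into Python
    # as a new pyarrow.RecordBatchReader.
    round_tripped = rust.round_trip_record_batch_reader(reader)

    assert round_tripped.schema == schema
    assert list(round_tripped) == batches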