Skip to content

Commit

Permalink
Add PyArrow integration test for C Stream Interface (#1848)
Browse files Browse the repository at this point in the history
* Add PyArrow integration test for ArrowArrayStream

* Trigger Build
  • Loading branch information
viirya committed Jun 13, 2022
1 parent 3073a26 commit cedaf8a
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 17 deletions.
9 changes: 9 additions & 0 deletions arrow-pyarrow-integration-testing/src/lib.rs
Expand Up @@ -27,6 +27,7 @@ use arrow::array::{ArrayData, ArrayRef, Int64Array};
use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowConvert;
use arrow::record_batch::RecordBatch;

Expand Down Expand Up @@ -111,6 +112,13 @@ fn round_trip_record_batch(obj: RecordBatch) -> PyResult<RecordBatch> {
Ok(obj)
}

#[pyfunction]
fn round_trip_record_batch_reader(
obj: ArrowArrayStreamReader,
) -> PyResult<ArrowArrayStreamReader> {
Ok(obj)
}

#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
Expand All @@ -122,5 +130,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
m.add_wrapped(wrap_pyfunction!(round_trip_schema))?;
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
Ok(())
}
16 changes: 16 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Expand Up @@ -303,3 +303,19 @@ def test_dictionary_python():
assert a == b
del a
del b

def test_record_batch_reader():
    """
    Python -> Rust -> Python
    """
    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
    expected = [
        pa.record_batch([[[1], [2, 42]]], schema),
        pa.record_batch([[None, [], [5, 6]]], schema),
    ]
    reader = pa.RecordBatchReader.from_batches(schema, expected)
    round_tripped = rust.round_trip_record_batch_reader(reader)

    # Schema (including metadata) must survive the FFI round trip...
    assert round_tripped.schema == schema
    # ...and so must every batch, in order.
    assert list(round_tripped) == expected
24 changes: 8 additions & 16 deletions arrow/src/ffi_stream.rs
Expand Up @@ -198,13 +198,6 @@ impl ExportedArrayStream {
}

pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
unsafe {
match (*out).release {
None => (),
Some(release) => release(out),
};
};

let mut private_data = self.get_private_data();
let reader = &private_data.batch_reader;

Expand All @@ -224,18 +217,17 @@ impl ExportedArrayStream {
}

pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
unsafe {
match (*out).release {
None => (),
Some(release) => release(out),
};
};

let mut private_data = self.get_private_data();
let reader = &mut private_data.batch_reader;

let ret_code = match reader.next() {
None => 0,
None => {
// Marks ArrowArray released to indicate reaching the end of stream.
unsafe {
(*out).release = None;
}
0
}
Some(next_batch) => {
if let Ok(batch) = next_batch {
let struct_array = StructArray::from(batch);
Expand Down Expand Up @@ -275,7 +267,7 @@ fn get_error_code(err: &ArrowError) -> i32 {
/// Struct used to fetch `RecordBatch` from the C Stream Interface.
/// Its main responsibility is to expose `RecordBatchReader` functionality
/// that requires [FFI_ArrowArrayStream].
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ArrowArrayStreamReader {
stream: Arc<FFI_ArrowArrayStream>,
schema: SchemaRef,
Expand Down
42 changes: 41 additions & 1 deletion arrow/src/pyarrow.rs
Expand Up @@ -24,13 +24,16 @@ use std::sync::Arc;
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::types::PyList;
use pyo3::types::{PyList, PyTuple};

use crate::array::{Array, ArrayData, ArrayRef};
use crate::datatypes::{DataType, Field, Schema};
use crate::error::ArrowError;
use crate::ffi;
use crate::ffi::FFI_ArrowSchema;
use crate::ffi_stream::{
export_reader_into_raw, ArrowArrayStreamReader, FFI_ArrowArrayStream,
};
use crate::record_batch::RecordBatch;

import_exception!(pyarrow, ArrowException);
Expand Down Expand Up @@ -198,6 +201,42 @@ impl PyArrowConvert for RecordBatch {
}
}

/// Converts between PyArrow's `RecordBatchReader` and Rust's
/// [`ArrowArrayStreamReader`] via the Arrow C Stream Interface
/// (`ArrowArrayStream`), using PyArrow's private `_export_to_c` /
/// `_import_from_c` methods.
impl PyArrowConvert for ArrowArrayStreamReader {
    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
        // prepare a pointer to receive the stream struct
        let stream = Box::new(FFI_ArrowArrayStream::empty());
        let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream;

        // make the conversion through PyArrow's private API
        // this changes the pointer's memory and is thus unsafe.
        // In particular, `_export_to_c` can go out of bounds
        let args = PyTuple::new(value.py(), &[stream_ptr as Py_uintptr_t]);
        value.call_method1("_export_to_c", args)?;

        // SAFETY: `stream_ptr` now points to a stream populated by PyArrow.
        // NOTE(review): `unwrap` panics across the Python FFI boundary if the
        // exported stream is invalid — consider mapping to a PyErr instead.
        let stream_reader =
            unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() };

        // Reclaim the heap allocation made above so it is freed when the Box
        // drops; `from_raw` has already taken over the stream's contents.
        unsafe {
            Box::from_raw(stream_ptr);
        }

        Ok(stream_reader)
    }

    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
        // Allocate an empty stream struct and hand its raw pointer to the
        // export routine, which fills in the C Stream Interface callbacks.
        let stream = Box::new(FFI_ArrowArrayStream::empty());
        let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream;

        // SAFETY: `stream_ptr` is a valid, uniquely-owned allocation created
        // just above; export installs the reader as the stream's private data.
        unsafe { export_reader_into_raw(Box::new(self.clone()), stream_ptr) };

        // PyArrow imports (moves) the stream; its release callback then owns
        // the stream's resources.
        // NOTE(review): the Box behind `stream_ptr` is never reclaimed with
        // `Box::from_raw` here, so the struct's heap allocation appears to
        // leak on every call (and also if an error returns early) — confirm.
        let module = py.import("pyarrow")?;
        let class = module.getattr("RecordBatchReader")?;
        let args = PyTuple::new(py, &[stream_ptr as Py_uintptr_t]);
        let reader = class.call_method1("_import_from_c", args)?;
        Ok(PyObject::from(reader))
    }
}

macro_rules! add_conversion {
($typ:ty) => {
impl<'source> FromPyObject<'source> for $typ {
Expand All @@ -219,3 +258,4 @@ add_conversion!(Field);
add_conversion!(Schema);
add_conversion!(ArrayData);
add_conversion!(RecordBatch);
add_conversion!(ArrowArrayStreamReader);

0 comments on commit cedaf8a

Please sign in to comment.