Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyArrow integration test for C Stream Interface #1848

Merged
merged 3 commits into from Jun 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions arrow-pyarrow-integration-testing/src/lib.rs
Expand Up @@ -27,6 +27,7 @@ use arrow::array::{ArrayData, ArrayRef, Int64Array};
use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowConvert;
use arrow::record_batch::RecordBatch;

Expand Down Expand Up @@ -111,6 +112,13 @@ fn round_trip_record_batch(obj: RecordBatch) -> PyResult<RecordBatch> {
Ok(obj)
}

/// Identity round trip for a stream reader: accepts an
/// `ArrowArrayStreamReader` converted from Python and hands the very same
/// reader straight back, exercising both conversion directions of the
/// C Stream Interface.
#[pyfunction]
fn round_trip_record_batch_reader(obj: ArrowArrayStreamReader) -> PyResult<ArrowArrayStreamReader> {
    Ok(obj)
}

#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
Expand All @@ -122,5 +130,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
m.add_wrapped(wrap_pyfunction!(round_trip_schema))?;
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
Ok(())
}
16 changes: 16 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Expand Up @@ -303,3 +303,19 @@ def test_dictionary_python():
assert a == b
del a
del b

def test_record_batch_reader():
    """
    Python -> Rust -> Python
    """
    schema = pa.schema(
        [('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}
    )
    batches = [
        pa.record_batch([[[1], [2, 42]]], schema),
        pa.record_batch([[None, [], [5, 6]]], schema),
    ]

    # Round-trip the reader through the Rust extension and check that both
    # the schema and every batch survive unchanged.
    reader = pa.RecordBatchReader.from_batches(schema, batches)
    round_tripped = rust.round_trip_record_batch_reader(reader)

    assert round_tripped.schema == schema
    assert list(round_tripped) == batches
24 changes: 8 additions & 16 deletions arrow/src/ffi_stream.rs
Expand Up @@ -198,13 +198,6 @@ impl ExportedArrayStream {
}

pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
unsafe {
match (*out).release {
None => (),
Some(release) => release(out),
};
};

Comment on lines -201 to -207
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For get_schema and get_next, we were wrongly releasing the output struct ourselves. Releasing it should be handled by the caller.

let mut private_data = self.get_private_data();
let reader = &private_data.batch_reader;

Expand All @@ -224,18 +217,17 @@ impl ExportedArrayStream {
}

pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
unsafe {
match (*out).release {
None => (),
Some(release) => release(out),
};
};

let mut private_data = self.get_private_data();
let reader = &mut private_data.batch_reader;

let ret_code = match reader.next() {
None => 0,
None => {
// Marks ArrowArray released to indicate reaching the end of stream.
unsafe {
(*out).release = None;
}
0
}
Some(next_batch) => {
if let Ok(batch) = next_batch {
let struct_array = StructArray::from(batch);
Expand Down Expand Up @@ -275,7 +267,7 @@ fn get_error_code(err: &ArrowError) -> i32 {
/// Struct used to fetch `RecordBatch` from the C Stream Interface.
/// Its main responsibility is to expose `RecordBatchReader` functionality
/// that requires [FFI_ArrowArrayStream].
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ArrowArrayStreamReader {
stream: Arc<FFI_ArrowArrayStream>,
schema: SchemaRef,
Expand Down
42 changes: 41 additions & 1 deletion arrow/src/pyarrow.rs
Expand Up @@ -24,13 +24,16 @@ use std::sync::Arc;
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::types::PyList;
use pyo3::types::{PyList, PyTuple};

use crate::array::{Array, ArrayData, ArrayRef};
use crate::datatypes::{DataType, Field, Schema};
use crate::error::ArrowError;
use crate::ffi;
use crate::ffi::FFI_ArrowSchema;
use crate::ffi_stream::{
export_reader_into_raw, ArrowArrayStreamReader, FFI_ArrowArrayStream,
};
use crate::record_batch::RecordBatch;

import_exception!(pyarrow, ArrowException);
Expand Down Expand Up @@ -198,6 +201,42 @@ impl PyArrowConvert for RecordBatch {
}
}

impl PyArrowConvert for ArrowArrayStreamReader {
    /// Imports a `pyarrow.RecordBatchReader` into Rust via the Arrow
    /// C Stream Interface.
    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
        // Heap-allocate an empty stream struct for PyArrow to populate.
        // `Box::into_raw` already yields `*mut FFI_ArrowArrayStream`, so no
        // cast is needed.
        let stream_ptr = Box::into_raw(Box::new(FFI_ArrowArrayStream::empty()));

        // Make the conversion through PyArrow's private API.
        // This changes the pointer's memory and is thus unsafe.
        // In particular, `_export_to_c` can go out of bounds.
        let args = PyTuple::new(value.py(), &[stream_ptr as Py_uintptr_t]);
        value.call_method1("_export_to_c", args)?;

        // `from_raw` takes over the stream's contents; the box allocation
        // itself still has to be reclaimed below.
        let stream_reader =
            unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() };

        // SAFETY: `stream_ptr` was produced by `Box::into_raw` above and is
        // not used again after this point.
        unsafe {
            drop(Box::from_raw(stream_ptr));
        }

        Ok(stream_reader)
    }

    /// Exports this reader to a `pyarrow.RecordBatchReader` via the Arrow
    /// C Stream Interface.
    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
        let stream_ptr = Box::into_raw(Box::new(FFI_ArrowArrayStream::empty()));

        // SAFETY: `stream_ptr` points at a valid, empty FFI_ArrowArrayStream
        // freshly allocated above.
        unsafe { export_reader_into_raw(Box::new(self.clone()), stream_ptr) };

        let module = py.import("pyarrow")?;
        let class = module.getattr("RecordBatchReader")?;
        let args = PyTuple::new(py, &[stream_ptr as Py_uintptr_t]);
        let reader = class.call_method1("_import_from_c", args)?;

        // Reclaim the box so it is not leaked: the original code never freed
        // this allocation. `_import_from_c` moves the stream's contents into
        // the Python reader; per the C stream interface contract the consumer
        // marks the source struct released after importing it.
        // NOTE(review): confirm pyarrow marks the source released so this
        // drop cannot double-release the stream.
        // SAFETY: `stream_ptr` came from `Box::into_raw` above and is not
        // used again after this point.
        unsafe {
            drop(Box::from_raw(stream_ptr));
        }

        Ok(PyObject::from(reader))
    }
}

macro_rules! add_conversion {
($typ:ty) => {
impl<'source> FromPyObject<'source> for $typ {
Expand All @@ -219,3 +258,4 @@ add_conversion!(Field);
// Wire up the pyo3 conversions (e.g. `FromPyObject`, per the macro above)
// for each Arrow type that supports round-tripping through PyArrow.
add_conversion!(Schema);
add_conversion!(ArrayData);
add_conversion!(RecordBatch);
add_conversion!(ArrowArrayStreamReader);