Update for changes in apache/arrow-rs#2711
tustvold committed Sep 20, 2022
1 parent 259f2e4 commit 04cbd33
Showing 9 changed files with 111 additions and 91 deletions.
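
Every hunk below follows from the same upstream change: apache/arrow-rs#2711 stopped implementing pyo3's conversion traits directly on Arrow types and routes them through a public newtype wrapper instead. A minimal sketch of that wrapper, inferred from the call sites in this commit (the `.0` field accesses and `PyArrowType(...)` constructions); the real upstream definition may carry additional impls:

// Hedged sketch, not the upstream source: a transparent newtype whose
// FromPyObject/IntoPy impls delegate to the Arrow <-> pyarrow conversions,
// leaving the wrapped value reachable as `.0`.
pub struct PyArrowType<T>(pub T);

Signatures that used to take or return `Schema` or `RecordBatch` bare now wrap them in `PyArrowType<...>` at the Python boundary and unwrap with `.0` on the Rust side.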
102 changes: 41 additions & 61 deletions Cargo.lock

Some generated files are not rendered by default.

11 changes: 11 additions & 0 deletions Cargo.toml
@@ -52,3 +52,14 @@ name = "datafusion._internal"
 [profile.release]
 lto = true
 codegen-units = 1
+
+
+[patch.crates-io]
+arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
+parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
+arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
+arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
+
+datafusion = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" }
+datafusion-expr = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" }
+datafusion-common = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" }
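
A note on mechanics: `[patch.crates-io]` swaps the crates.io copies of these crates for the pinned git revisions across the whole dependency graph, transitive uses inside datafusion included, which is why the arrow family and the datafusion family each pin a single fork revision. After a `cargo update`, `cargo tree -i arrow` should list the git URL as the resolved source if the patch took effect.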
26 changes: 19 additions & 7 deletions src/context.rs
@@ -24,6 +24,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
 use pyo3::prelude::*;
 
 use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::datasource::TableProvider;
 use datafusion::datasource::MemTable;
@@ -99,9 +100,16 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }
 
-    fn create_dataframe(&mut self, partitions: Vec<Vec<RecordBatch>>) -> PyResult<PyDataFrame> {
-        let table = MemTable::try_new(partitions[0][0].schema(), partitions)
-            .map_err(DataFusionError::from)?;
+    fn create_dataframe(
+        &mut self,
+        partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
+    ) -> PyResult<PyDataFrame> {
+        let schema = partitions[0][0].0.schema();
+        let partitions = partitions
+            .into_iter()
+            .map(|x| x.into_iter().map(|x| x.0).collect())
+            .collect();
+        let table = MemTable::try_new(schema, partitions).map_err(DataFusionError::from)?;
 
         // generate a random (unique) name for this table
         // table name cannot start with numeric digit
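
The Python-facing API is unchanged here: pyo3 lifts `FromPyObject` over `Vec`, so a Python list of lists of `pyarrow.RecordBatch` still converts, just element-wise through the wrapper now. The unwrap back to plain batches is the nested-map idiom above; a standalone sketch of the same shape (the identical pattern appears again in `register_record_batches` below):

// From Vec<Vec<PyArrowType<RecordBatch>>> back to Vec<Vec<RecordBatch>>:
// the outer map walks partitions, the inner map strips each wrapper via `.0`.
let partitions: Vec<Vec<RecordBatch>> = partitions
    .into_iter()
    .map(|batches| batches.into_iter().map(|b| b.0).collect())
    .collect();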
@@ -136,9 +144,13 @@ impl PySessionContext {
     fn register_record_batches(
         &mut self,
         name: &str,
-        partitions: Vec<Vec<RecordBatch>>,
+        partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
     ) -> PyResult<()> {
-        let schema = partitions[0][0].schema();
+        let schema = partitions[0][0].0.schema();
+        let partitions = partitions
+            .into_iter()
+            .map(|x| x.into_iter().map(|x| x.0).collect())
+            .collect();
         let table = MemTable::try_new(schema, partitions)?;
         self.ctx
             .register_table(name, Arc::new(table))
@@ -182,7 +194,7 @@ impl PySessionContext {
         &mut self,
         name: &str,
         path: PathBuf,
-        schema: Option<Schema>,
+        schema: Option<PyArrowType<Schema>>,
         has_header: bool,
         delimiter: &str,
         schema_infer_max_records: usize,
@@ -204,7 +216,7 @@
             .delimiter(delimiter[0])
             .schema_infer_max_records(schema_infer_max_records)
             .file_extension(file_extension);
-        options.schema = schema.as_ref();
+        options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_csv(name, path, options);
         wait_for_future(py, result).map_err(DataFusionError::from)?;
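
`CsvReadOptions::schema` holds a borrowed `&Schema` rather than an owned one, so the wrapper is peeled by reference: `schema.as_ref()` turns `Option<PyArrowType<Schema>>` into `Option<&PyArrowType<Schema>>`, and `.map(|x| &x.0)` projects that to the `Option<&Schema>` the options struct wants, with no clone. The `schema` local just has to outlive the use of `options`, which it does since `register_csv` runs in the same scope.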
12 changes: 7 additions & 5 deletions src/dataframe.rs
@@ -18,7 +18,7 @@
 use crate::utils::wait_for_future;
 use crate::{errors::DataFusionError, expression::PyExpr};
 use datafusion::arrow::datatypes::Schema;
-use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::pyarrow::{PyArrowConvert, PyArrowException, PyArrowType};
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::DataFrame;
 use datafusion::prelude::*;
@@ -65,8 +65,8 @@ impl PyDataFrame {
     }
 
     /// Returns the schema from the logical plan
-    fn schema(&self) -> Schema {
-        self.df.schema().into()
+    fn schema(&self) -> PyArrowType<Schema> {
+        PyArrowType(self.df.schema().into())
     }
 
     #[args(args = "*")]
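
`self.df.schema()` still returns DataFusion's `DFSchema`, and `.into()` still converts that to an Arrow `Schema`; the only new step is wrapping the result in `PyArrowType(...)`, whose `IntoPy` impl (per the same upstream change) is what turns the return value into a `pyarrow.Schema` on the Python side rather than an opaque Rust object.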
@@ -144,7 +144,8 @@ impl PyDataFrame {
     fn show(&self, py: Python, num: usize) -> PyResult<()> {
         let df = self.df.limit(0, Some(num))?;
         let batches = wait_for_future(py, df.collect())?;
-        Ok(pretty::print_batches(&batches)?)
+        Ok(pretty::print_batches(&batches)
+            .map_err(|err| PyArrowException::new_err(err.to_string()))?)
     }
 
     /// Filter out duplicate rows
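
`pretty::print_batches` returns `Result<(), ArrowError>`, and after the upstream change the bare `?` no longer has an automatic `ArrowError -> PyErr` conversion to lean on, hence the explicit map into `PyArrowException`. Since `explain` below repeats the same mapping, a small helper would keep the call sites tidy; a hedged sketch, with a helper name of our choosing rather than anything in the repo:

// Hypothetical convenience function: fold an ArrowError into the
// PyArrowException error type this file now raises explicitly.
fn arrow_err_to_py(err: datafusion::arrow::error::ArrowError) -> pyo3::PyErr {
    PyArrowException::new_err(err.to_string())
}

// usage: Ok(pretty::print_batches(&batches).map_err(arrow_err_to_py)?)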
@@ -186,7 +187,8 @@ impl PyDataFrame {
     fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> {
         let df = self.df.explain(verbose, analyze)?;
         let batches = wait_for_future(py, df.collect())?;
-        Ok(pretty::print_batches(&batches)?)
+        Ok(pretty::print_batches(&batches)
+            .map_err(|err| PyArrowException::new_err(err.to_string()))?)
     }
 
     /// Repartition a `DataFrame` based on a logical partitioning scheme.
10 changes: 9 additions & 1 deletion src/dataset.rs
@@ -27,6 +27,7 @@ use std::sync::Arc;
 use async_trait::async_trait;
 
 use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::datasource::datasource::TableProviderFilterPushDown;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::{DataFusionError, Result as DFResult};
@@ -74,7 +75,14 @@ impl TableProvider for Dataset {
         Python::with_gil(|py| {
             let dataset = self.dataset.as_ref(py);
             // This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never
-            Arc::new(dataset.getattr("schema").unwrap().extract().unwrap())
+            Arc::new(
+                dataset
+                    .getattr("schema")
+                    .unwrap()
+                    .extract::<PyArrowType<_>>()
+                    .unwrap()
+                    .0,
+            )
         })
     }
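
The turbofish in `extract::<PyArrowType<_>>()` does the least annotation possible: `TableProvider::schema` must return `SchemaRef`, i.e. `Arc<Schema>`, so the `Arc::new(... .0)` around it forces the `_` to resolve to `Schema`. Under the old arrow-rs, `extract()` could stay fully implicit because `Schema` itself implemented `FromPyObject`; now only the wrapper does, so the extraction names it and unwraps with `.0`.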
