Upgrade to DataFusion 13 #54

Closed · wants to merge 6 commits
102 changes: 41 additions & 61 deletions Cargo.lock

(Cargo.lock is a generated file; its diff is not rendered.)

11 changes: 11 additions & 0 deletions Cargo.toml
@@ -52,3 +52,14 @@ name = "datafusion._internal"
 [profile.release]
 lto = true
 codegen-units = 1
+
+
+[patch.crates-io]
+arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
+parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
+arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
+arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" }
+
+datafusion = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" }
+datafusion-expr = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" }
+datafusion-common = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" }
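
Note: the `[patch.crates-io]` table above redirects every dependency on these crates (including transitive ones pulled in through DataFusion) to the pinned git revisions, presumably as a stopgap until matching arrow-rs and DataFusion releases reach crates.io.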
26 changes: 19 additions & 7 deletions src/context.rs
@@ -24,6 +24,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
 use pyo3::prelude::*;
 
 use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::datasource::TableProvider;
 use datafusion::datasource::MemTable;
@@ -99,9 +100,16 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }
 
-    fn create_dataframe(&mut self, partitions: Vec<Vec<RecordBatch>>) -> PyResult<PyDataFrame> {
-        let table = MemTable::try_new(partitions[0][0].schema(), partitions)
-            .map_err(DataFusionError::from)?;
+    fn create_dataframe(
+        &mut self,
+        partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
+    ) -> PyResult<PyDataFrame> {
+        let schema = partitions[0][0].0.schema();
+        let partitions = partitions
+            .into_iter()
+            .map(|x| x.into_iter().map(|x| x.0).collect())
+            .collect();
+        let table = MemTable::try_new(schema, partitions).map_err(DataFusionError::from)?;
 
         // generate a random (unique) name for this table
         // table name cannot start with numeric digit
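
The `PyArrowType` wrapper threaded through these signatures comes from arrow-rs's pyo3 integration. A minimal sketch of the pattern, assuming the real definition in `datafusion::arrow::pyarrow` differs only in detail:

// Sketch: a local tuple struct wrapping a foreign type, so the crate can
// implement pyo3's conversion traits for arrow types without hitting the
// orphan rule. Call sites recover the inner value with `.0`, as in the
// diff above.
pub struct PyArrowType<T>(pub T);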
@@ -136,9 +144,13 @@ impl PySessionContext {
     fn register_record_batches(
         &mut self,
         name: &str,
-        partitions: Vec<Vec<RecordBatch>>,
+        partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
     ) -> PyResult<()> {
-        let schema = partitions[0][0].schema();
+        let schema = partitions[0][0].0.schema();
+        let partitions = partitions
+            .into_iter()
+            .map(|x| x.into_iter().map(|x| x.0).collect())
+            .collect();
         let table = MemTable::try_new(schema, partitions)?;
         self.ctx
             .register_table(name, Arc::new(table))
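
The same unwrap dance over nested partitions appears in both `create_dataframe` and `register_record_batches`; a hypothetical helper (not part of this PR) could factor it out:

use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;

// Hypothetical helper, not in the PR: strip the PyArrowType wrapper from a
// nested Vec of partitions, yielding the plain RecordBatches MemTable expects.
fn unwrap_partitions(
    partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
) -> Vec<Vec<RecordBatch>> {
    partitions
        .into_iter()
        .map(|part| part.into_iter().map(|batch| batch.0).collect())
        .collect()
}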
@@ -182,7 +194,7 @@ impl PySessionContext {
         &mut self,
         name: &str,
         path: PathBuf,
-        schema: Option<Schema>,
+        schema: Option<PyArrowType<Schema>>,
         has_header: bool,
         delimiter: &str,
         schema_infer_max_records: usize,
@@ -204,7 +216,7 @@
             .delimiter(delimiter[0])
             .schema_infer_max_records(schema_infer_max_records)
             .file_extension(file_extension);
-        options.schema = schema.as_ref();
+        options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_csv(name, path, options);
         wait_for_future(py, result).map_err(DataFusionError::from)?;
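
Because `CsvReadOptions` borrows its schema rather than owning it, the wrapped value is projected to a reference of the inner `Schema` with `.map(|x| &x.0)` instead of being moved out of the wrapper.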
12 changes: 7 additions & 5 deletions src/dataframe.rs
@@ -18,7 +18,7 @@
 use crate::utils::wait_for_future;
 use crate::{errors::DataFusionError, expression::PyExpr};
 use datafusion::arrow::datatypes::Schema;
-use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::pyarrow::{PyArrowConvert, PyArrowException, PyArrowType};
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::DataFrame;
 use datafusion::prelude::*;
@@ -65,8 +65,8 @@ impl PyDataFrame {
     }
 
     /// Returns the schema from the logical plan
-    fn schema(&self) -> Schema {
-        self.df.schema().into()
+    fn schema(&self) -> PyArrowType<Schema> {
+        PyArrowType(self.df.schema().into())
     }
 
     #[args(args = "*")]
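
Returning `PyArrowType<Schema>` rather than a bare `Schema` should leave the Python-facing behavior unchanged (callers still receive a `pyarrow.Schema`), since the wrapper exists only to carry the conversion.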
@@ -144,7 +144,8 @@ impl PyDataFrame {
     fn show(&self, py: Python, num: usize) -> PyResult<()> {
         let df = self.df.limit(0, Some(num))?;
         let batches = wait_for_future(py, df.collect())?;
-        Ok(pretty::print_batches(&batches)?)
+        Ok(pretty::print_batches(&batches)
+            .map_err(|err| PyArrowException::new_err(err.to_string()))?)
     }
 
     /// Filter out duplicate rows
@@ -186,7 +187,8 @@ impl PyDataFrame {
     fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> {
         let df = self.df.explain(verbose, analyze)?;
         let batches = wait_for_future(py, df.collect())?;
-        Ok(pretty::print_batches(&batches)?)
+        Ok(pretty::print_batches(&batches)
+            .map_err(|err| PyArrowException::new_err(err.to_string()))?)
     }
 
     /// Repartition a `DataFrame` based on a logical partitioning scheme.
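
In both `show` and `explain`, the bare `?` on `pretty::print_batches` no longer suffices, presumably because the upgraded arrow dropped its blanket `ArrowError` to `PyErr` conversion; the error is now mapped explicitly into a `PyArrowException` that Python callers can catch.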
10 changes: 9 additions & 1 deletion src/dataset.rs
@@ -27,6 +27,7 @@ use std::sync::Arc;
 use async_trait::async_trait;
 
 use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::datasource::datasource::TableProviderFilterPushDown;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::{DataFusionError, Result as DFResult};
@@ -74,7 +75,14 @@ impl TableProvider for Dataset {
         Python::with_gil(|py| {
             let dataset = self.dataset.as_ref(py);
             // This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never
-            Arc::new(dataset.getattr("schema").unwrap().extract().unwrap())
+            Arc::new(
+                dataset
+                    .getattr("schema")
+                    .unwrap()
+                    .extract::<PyArrowType<_>>()
+                    .unwrap()
+                    .0,
+            )
         })
     }
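
With the wrapper in play, `extract()` goes through `PyArrowType` and the inner schema is unwrapped with `.0`. A standalone sketch of the same conversion (names here are illustrative, not from the PR):

use std::sync::Arc;
use datafusion::arrow::datatypes::{Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use pyo3::prelude::*;

// Illustrative: read a `schema` attribute off a Python object as a pyarrow
// schema, then unwrap to the plain arrow Schema behind an Arc.
fn schema_from_pyarrow(obj: &PyAny) -> PyResult<SchemaRef> {
    let wrapped: PyArrowType<Schema> = obj.getattr("schema")?.extract()?;
    Ok(Arc::new(wrapped.0))
}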
