diff --git a/Cargo.lock b/Cargo.lock
index 41cdd3a2..576e767c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -70,11 +70,12 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"

 [[package]]
 name = "arrow"
-version = "22.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c5936b4185aa57cb9790d8742aab22859045ce5cc6a3023796240cd101c19335"
+version = "23.0.0"
+source = "git+https://github.com/tustvold/arrow-rs.git?rev=6f62bb62f630cbd910ae5b1b04f97688af7c1b42#6f62bb62f630cbd910ae5b1b04f97688af7c1b42"
 dependencies = [
  "ahash 0.8.0",
+ "arrow-buffer",
+ "arrow-schema",
  "bitflags",
  "chrono",
  "comfy-table",
@@ -90,10 +91,23 @@ dependencies = [
  "pyo3",
  "regex",
  "regex-syntax",
- "serde",
  "serde_json",
 ]

+[[package]]
+name = "arrow-buffer"
+version = "23.0.0"
+source = "git+https://github.com/tustvold/arrow-rs.git?rev=6f62bb62f630cbd910ae5b1b04f97688af7c1b42#6f62bb62f630cbd910ae5b1b04f97688af7c1b42"
+dependencies = [
+ "half",
+ "num",
+]
+
+[[package]]
+name = "arrow-schema"
+version = "23.0.0"
+source = "git+https://github.com/tustvold/arrow-rs.git?rev=6f62bb62f630cbd910ae5b1b04f97688af7c1b42#6f62bb62f630cbd910ae5b1b04f97688af7c1b42"
+
 [[package]]
 name = "async-trait"
 version = "0.1.56"
@@ -325,8 +339,7 @@ dependencies = [
 [[package]]
 name = "datafusion"
 version = "12.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2aca80caa2b0f7fdf267799b8895ac8b6341ea879db6b1e2d361ec49b47bc676"
+source = "git+https://github.com/tustvold/arrow-datafusion.git?rev=9de354bf45c0cc4121af04ea8138df7fddab76ed#9de354bf45c0cc4121af04ea8138df7fddab76ed"
 dependencies = [
  "ahash 0.8.0",
  "arrow",
@@ -366,23 +379,20 @@ dependencies = [
 [[package]]
 name = "datafusion-common"
 version = "12.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7721fd550f6a28ad7235b62462aa51e9a43b08f8346d5cbe4d61f1e83f5df511"
+source = "git+https://github.com/tustvold/arrow-datafusion.git?rev=9de354bf45c0cc4121af04ea8138df7fddab76ed#9de354bf45c0cc4121af04ea8138df7fddab76ed"
 dependencies = [
  "arrow",
  "object_store",
  "ordered-float 3.0.0",
  "parquet",
  "pyo3",
- "serde_json",
  "sqlparser",
 ]

 [[package]]
 name = "datafusion-expr"
 version = "12.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2d81255d043dc594c0ded6240e8a9be6ce8d7c22777a5093357cdb97af3d29ce"
+source = "git+https://github.com/tustvold/arrow-datafusion.git?rev=9de354bf45c0cc4121af04ea8138df7fddab76ed#9de354bf45c0cc4121af04ea8138df7fddab76ed"
 dependencies = [
  "ahash 0.8.0",
  "arrow",
@@ -393,8 +403,7 @@
 [[package]]
 name = "datafusion-optimizer"
 version = "12.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "71b39f8c75163691fff72b4a71816ad5a912e7c6963ee55f29ed1910b5a6993f"
+source = "git+https://github.com/tustvold/arrow-datafusion.git?rev=9de354bf45c0cc4121af04ea8138df7fddab76ed#9de354bf45c0cc4121af04ea8138df7fddab76ed"
 dependencies = [
  "arrow",
  "async-trait",
@@ -409,8 +418,7 @@
 [[package]]
 name = "datafusion-physical-expr"
 version = "12.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "109c4138220a109feafb63bf05418b86b17a42ece4bf047c38e4fd417572a9f7"
+source = "git+https://github.com/tustvold/arrow-datafusion.git?rev=9de354bf45c0cc4121af04ea8138df7fddab76ed#9de354bf45c0cc4121af04ea8138df7fddab76ed"
 dependencies = [
  "ahash 0.8.0",
  "arrow",
@@ -450,8 +458,7 @@
 [[package]]
 name = "datafusion-row"
 version = "12.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "87a178fc0fd7693d9c9f608f7b605823eb982c6731ede0cccd99e2319cacabbc"
+source = "git+https://github.com/tustvold/arrow-datafusion.git?rev=9de354bf45c0cc4121af04ea8138df7fddab76ed#9de354bf45c0cc4121af04ea8138df7fddab76ed"
 dependencies = [
  "arrow",
  "datafusion-common",
@@ -462,8 +469,7 @@
 [[package]]
 name = "datafusion-sql"
 version = "12.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "148cb56e7635faff3b16019393c49b988188c3fdadd1ca90eadb322a80aa1128"
+source = "git+https://github.com/tustvold/arrow-datafusion.git?rev=9de354bf45c0cc4121af04ea8138df7fddab76ed#9de354bf45c0cc4121af04ea8138df7fddab76ed"
 dependencies = [
  "ahash 0.8.0",
  "arrow",
@@ -669,6 +675,9 @@ name = "half"
 version = "2.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c207b0ee023c7fce79daf01828163aaf53a1ddd0be8b1ef9541da7d41f6fa63a"
+dependencies = [
+ "num-traits",
+]

 [[package]]
 name = "hashbrown"
@@ -732,9 +741,9 @@ dependencies = [

 [[package]]
 name = "integer-encoding"
-version = "1.1.7"
+version = "3.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "48dc51180a9b377fd75814d0cc02199c20f8e99433d6762f650d39cdbbd3b56f"
+checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02"

 [[package]]
 name = "itertools"
@@ -842,6 +851,12 @@ version = "0.2.126"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836"

+[[package]]
+name = "libm"
+version = "0.2.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565"
+
 [[package]]
 name = "libmimalloc-sys"
 version = "0.1.25"
@@ -1032,6 +1047,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
 dependencies = [
  "autocfg",
+ "libm",
 ]

 [[package]]
@@ -1113,9 +1129,8 @@ dependencies = [

 [[package]]
 name = "parquet"
-version = "22.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "474c423be6f10921adab3b94b42ec7fe87c1b87e1360dee150976caee444224f"
+version = "23.0.0"
+source = "git+https://github.com/tustvold/arrow-rs.git?rev=6f62bb62f630cbd910ae5b1b04f97688af7c1b42#6f62bb62f630cbd910ae5b1b04f97688af7c1b42"
 dependencies = [
  "ahash 0.8.0",
  "arrow",
@@ -1129,7 +1144,6 @@ dependencies = [
  "lz4",
  "num",
  "num-bigint",
- "parquet-format",
  "rand 0.8.5",
  "seq-macro",
  "snap",
@@ -1138,15 +1152,6 @@ dependencies = [
  "zstd",
 ]

-[[package]]
-name = "parquet-format"
-version = "4.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1f0c06cdcd5460967c485f9c40a821746f5955ad81990533c7fae95dbd9bc0b5"
-dependencies = [
- "thrift",
-]
-
 [[package]]
 name = "paste"
 version = "1.0.7"
@@ -1411,20 +1416,6 @@ name = "serde"
 version = "1.0.137"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "61ea8d54c77f8315140a05f4c7237403bf38b72704d031543aa1d16abbf517d1"
-dependencies = [
- "serde_derive",
-]
-
-[[package]]
-name = "serde_derive"
-version = "1.0.137"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1f26faba0c3959972377d3b2d306ee9f71faee9714294e41bb777f83f88578be"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn",
-]

 [[package]]
 name = "serde_json"
@@ -1579,26 +1570,15 @@ dependencies = [
  "syn",
 ]

-[[package]]
-name = "threadpool"
-version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" -dependencies = [ - "num_cpus", -] - [[package]] name = "thrift" -version = "0.13.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c6d965454947cc7266d22716ebfd07b18d84ebaf35eec558586bbb2a8cb6b5b" +checksum = "09678c4cdbb4eed72e18b7c2af1329c69825ed16fcbac62d083fc3e2b0590ff0" dependencies = [ "byteorder", "integer-encoding", - "log", "ordered-float 1.1.1", - "threadpool", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 217ac1c7..cb105c6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,3 +52,14 @@ name = "datafusion._internal" [profile.release] lto = true codegen-units = 1 + + +[patch.crates-io] +arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" } +parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" } +arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" } +arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "6f62bb62f630cbd910ae5b1b04f97688af7c1b42" } + +datafusion = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" } +datafusion-expr = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed" } +datafusion-common = { git = "https://github.com/tustvold/arrow-datafusion.git", rev = "9de354bf45c0cc4121af04ea8138df7fddab76ed"} diff --git a/src/context.rs b/src/context.rs index 25d08ef8..64627321 100644 --- a/src/context.rs +++ b/src/context.rs @@ -24,6 +24,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; use datafusion::arrow::datatypes::Schema; +use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::datasource::TableProvider; use datafusion::datasource::MemTable; @@ -99,9 +100,16 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } - fn create_dataframe(&mut self, partitions: Vec>) -> PyResult { - let table = MemTable::try_new(partitions[0][0].schema(), partitions) - .map_err(DataFusionError::from)?; + fn create_dataframe( + &mut self, + partitions: Vec>>, + ) -> PyResult { + let schema = partitions[0][0].0.schema(); + let partitions = partitions + .into_iter() + .map(|x| x.into_iter().map(|x| x.0).collect()) + .collect(); + let table = MemTable::try_new(schema, partitions).map_err(DataFusionError::from)?; // generate a random (unique) name for this table // table name cannot start with numeric digit @@ -136,9 +144,13 @@ impl PySessionContext { fn register_record_batches( &mut self, name: &str, - partitions: Vec>, + partitions: Vec>>, ) -> PyResult<()> { - let schema = partitions[0][0].schema(); + let schema = partitions[0][0].0.schema(); + let partitions = partitions + .into_iter() + .map(|x| x.into_iter().map(|x| x.0).collect()) + .collect(); let table = MemTable::try_new(schema, partitions)?; self.ctx .register_table(name, Arc::new(table)) @@ -182,7 +194,7 @@ impl PySessionContext { &mut self, name: &str, path: PathBuf, - schema: Option, + schema: Option>, has_header: bool, delimiter: &str, schema_infer_max_records: usize, @@ -204,7 +216,7 @@ impl PySessionContext { .delimiter(delimiter[0]) .schema_infer_max_records(schema_infer_max_records) 
diff --git a/src/context.rs b/src/context.rs
index 25d08ef8..64627321 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -24,6 +24,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
 use pyo3::prelude::*;

 use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::datasource::TableProvider;
 use datafusion::datasource::MemTable;
@@ -99,9 +100,16 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }

-    fn create_dataframe(&mut self, partitions: Vec<Vec<RecordBatch>>) -> PyResult<PyDataFrame> {
-        let table = MemTable::try_new(partitions[0][0].schema(), partitions)
-            .map_err(DataFusionError::from)?;
+    fn create_dataframe(
+        &mut self,
+        partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
+    ) -> PyResult<PyDataFrame> {
+        let schema = partitions[0][0].0.schema();
+        let partitions = partitions
+            .into_iter()
+            .map(|x| x.into_iter().map(|x| x.0).collect())
+            .collect();
+        let table = MemTable::try_new(schema, partitions).map_err(DataFusionError::from)?;

         // generate a random (unique) name for this table
         // table name cannot start with numeric digit
@@ -136,9 +144,13 @@ impl PySessionContext {
     fn register_record_batches(
         &mut self,
         name: &str,
-        partitions: Vec<Vec<RecordBatch>>,
+        partitions: Vec<Vec<PyArrowType<RecordBatch>>>,
     ) -> PyResult<()> {
-        let schema = partitions[0][0].schema();
+        let schema = partitions[0][0].0.schema();
+        let partitions = partitions
+            .into_iter()
+            .map(|x| x.into_iter().map(|x| x.0).collect())
+            .collect();
         let table = MemTable::try_new(schema, partitions)?;
         self.ctx
             .register_table(name, Arc::new(table))
@@ -182,7 +194,7 @@ impl PySessionContext {
         &mut self,
         name: &str,
         path: PathBuf,
-        schema: Option<Schema>,
+        schema: Option<PyArrowType<Schema>>,
         has_header: bool,
         delimiter: &str,
         schema_infer_max_records: usize,
@@ -204,7 +216,7 @@ impl PySessionContext {
             .delimiter(delimiter[0])
             .schema_infer_max_records(schema_infer_max_records)
             .file_extension(file_extension);
-        options.schema = schema.as_ref();
+        options.schema = schema.as_ref().map(|x| &x.0);

         let result = self.ctx.register_csv(name, path, options);
         wait_for_future(py, result).map_err(DataFusionError::from)?;
diff --git a/src/dataframe.rs b/src/dataframe.rs
index c992d105..2538a857 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -18,7 +18,7 @@ use crate::utils::wait_for_future;
 use crate::{errors::DataFusionError, expression::PyExpr};

 use datafusion::arrow::datatypes::Schema;
-use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::pyarrow::{PyArrowConvert, PyArrowException, PyArrowType};
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::DataFrame;
 use datafusion::prelude::*;
@@ -65,8 +65,8 @@ impl PyDataFrame {
     }

     /// Returns the schema from the logical plan
-    fn schema(&self) -> Schema {
-        self.df.schema().into()
+    fn schema(&self) -> PyArrowType<Schema> {
+        PyArrowType(self.df.schema().into())
     }

     #[args(args = "*")]
@@ -144,7 +144,8 @@ impl PyDataFrame {
     fn show(&self, py: Python, num: usize) -> PyResult<()> {
         let df = self.df.limit(0, Some(num))?;
         let batches = wait_for_future(py, df.collect())?;
-        Ok(pretty::print_batches(&batches)?)
+        Ok(pretty::print_batches(&batches)
+            .map_err(|err| PyArrowException::new_err(err.to_string()))?)
     }

     /// Filter out duplicate rows
@@ -186,7 +187,8 @@ impl PyDataFrame {
     fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> {
         let df = self.df.explain(verbose, analyze)?;
         let batches = wait_for_future(py, df.collect())?;
-        Ok(pretty::print_batches(&batches)?)
+        Ok(pretty::print_batches(&batches)
+            .map_err(|err| PyArrowException::new_err(err.to_string()))?)
     }

     /// Repartition a `DataFrame` based on a logical partitioning scheme.
diff --git a/src/dataset.rs b/src/dataset.rs
index 952f2258..8208acc9 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -27,6 +27,7 @@ use std::sync::Arc;
 use async_trait::async_trait;

 use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::datasource::datasource::TableProviderFilterPushDown;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::{DataFusionError, Result as DFResult};
@@ -74,7 +75,14 @@ impl TableProvider for Dataset {
         Python::with_gil(|py| {
             let dataset = self.dataset.as_ref(py);
             // This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never
-            Arc::new(dataset.getattr("schema").unwrap().extract().unwrap())
+            Arc::new(
+                dataset
+                    .getattr("schema")
+                    .unwrap()
+                    .extract::<PyArrowType<Schema>>()
+                    .unwrap()
+                    .0,
+            )
         })
     }
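The `show`/`explain` changes in `src/dataframe.rs` above also stop relying on an automatic `ArrowError` to `PyErr` conversion and raise the error explicitly. A sketch of that mapping under stated assumptions: `RenderError` and `render` are made-up stand-ins for arrow's `ArrowError` and the rendering call, and pyo3's generic `PyException` stands in for the `PyArrowException` that arrow-rs exposes:

```rust
use pyo3::exceptions::PyException;
use pyo3::prelude::*;

// Hypothetical error type standing in for arrow's `ArrowError`.
#[derive(Debug)]
struct RenderError(String);

impl std::fmt::Display for RenderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "render failed: {}", self.0)
    }
}

// Same shape as the diff: stringify the Rust error and raise it as a
// Python exception instead of expecting an automatic `From` conversion.
fn render() -> PyResult<()> {
    let result: Result<(), RenderError> = Err(RenderError("batch".into()));
    result.map_err(|err| PyException::new_err(err.to_string()))
}

fn main() {
    assert!(render().is_err());
}
```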
diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs
index 54997fb6..91b9942a 100644
--- a/src/dataset_exec.rs
+++ b/src/dataset_exec.rs
@@ -28,6 +28,7 @@ use futures::stream;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::arrow::error::ArrowError;
 use datafusion::arrow::error::Result as ArrowResult;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::error::{DataFusionError as InnerDataFusionError, Result as DFResult};
 use datafusion::execution::context::TaskContext;
@@ -54,7 +55,7 @@ impl Iterator for PyArrowBatchesAdapter {
             Some(
                 batches
                     .next()?
-                    .and_then(|batch| batch.extract())
+                    .and_then(|batch| Ok(batch.extract::<PyArrowType<RecordBatch>>()?.0))
                     .map_err(|err| ArrowError::ExternalError(Box::new(err))),
             )
         })
@@ -109,7 +110,12 @@ impl DatasetExec {

         let scanner = dataset.call_method("scanner", (), Some(kwargs))?;

-        let schema = Arc::new(scanner.getattr("projected_schema")?.extract()?);
+        let schema = Arc::new(
+            scanner
+                .getattr("projected_schema")?
+                .extract::<PyArrowType<Schema>>()?
+                .0,
+        );

         let builtins = Python::import(py, "builtins")?;
         let pylist = builtins.getattr("list")?;
@@ -211,7 +217,7 @@ impl ExecutionPlan for DatasetExec {
         let schema: SchemaRef = Arc::new(
             scanner
                 .getattr("projected_schema")
-                .and_then(|schema| schema.extract())
+                .and_then(|schema| Ok(schema.extract::<PyArrowType<Schema>>()?.0))
                 .map_err(|err| InnerDataFusionError::External(Box::new(err)))?,
         );
         let record_batches: &PyIterator = scanner
diff --git a/src/expression.rs b/src/expression.rs
index b4019920..5d8e7d0b 100644
--- a/src/expression.rs
+++ b/src/expression.rs
@@ -19,6 +19,7 @@ use pyo3::{basic::CompareOp, prelude::*};
 use std::convert::{From, Into};

 use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::logical_plan::{col, lit, Expr};
 use datafusion::scalar::ScalarValue;
@@ -125,12 +126,12 @@ impl PyExpr {
         self.expr.clone().is_null().into()
     }

-    pub fn cast(&self, to: DataType) -> PyExpr {
+    pub fn cast(&self, to: PyArrowType<DataType>) -> PyExpr {
         // self.expr.cast_to() requires DFSchema to validate that the cast
         // is supported, omit that for now
         let expr = Expr::Cast {
             expr: Box::new(self.expr.clone()),
-            data_type: to,
+            data_type: to.0,
         };
         expr.into()
     }
diff --git a/src/udaf.rs b/src/udaf.rs
index 8bc2b594..d5c51f17 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -21,7 +21,7 @@ use pyo3::{prelude::*, types::PyTuple};

 use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::pyarrow::{PyArrowConvert, PyArrowType};
 use datafusion::common::ScalarValue;
 use datafusion::error::{DataFusionError, Result};
 use datafusion_expr::{
@@ -120,18 +120,18 @@ impl PyAggregateUDF {
     fn new(
         name: &str,
         accumulator: PyObject,
-        input_type: DataType,
-        return_type: DataType,
-        state_type: Vec<DataType>,
+        input_type: PyArrowType<DataType>,
+        return_type: PyArrowType<DataType>,
+        state_type: Vec<PyArrowType<DataType>>,
         volatility: &str,
     ) -> PyResult<Self> {
         let function = create_udaf(
             name,
-            input_type,
-            Arc::new(return_type),
+            input_type.0,
+            Arc::new(return_type.0),
             parse_volatility(volatility)?,
             to_rust_accumulator(accumulator),
-            Arc::new(state_type),
+            Arc::new(state_type.into_iter().map(|x| x.0).collect()),
         );
         Ok(Self { function })
     }
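`cast`, `PyAggregateUDF::new`, and the `extract` calls in `src/dataset_exec.rs` above all receive Python values through `FromPyObject` on the wrapper. A rough sketch of that extraction shape with a hypothetical `Wrapper` type (pyo3 0.17-era `FromPyObject` signature, matching this codebase; `projected_name` and the `name` attribute are illustrative, not part of the PR):

```rust
use pyo3::prelude::*;

// Hypothetical wrapper mirroring how `PyArrowType<T>` implements
// `FromPyObject`: convert the Python object, then store it in field .0.
struct Wrapper(pub String);

impl<'source> FromPyObject<'source> for Wrapper {
    fn extract(ob: &'source PyAny) -> PyResult<Self> {
        Ok(Wrapper(ob.extract::<String>()?))
    }
}

// Same shape as the diff's
// `scanner.getattr("projected_schema")?.extract::<PyArrowType<Schema>>()?.0`:
// fetch an attribute, extract through the wrapper, take the inner value.
fn projected_name(obj: &PyAny) -> PyResult<String> {
    Ok(obj.getattr("name")?.extract::<Wrapper>()?.0)
}
```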
diff --git a/src/udf.rs b/src/udf.rs
index b20eed59..5aac196c 100644
--- a/src/udf.rs
+++ b/src/udf.rs
@@ -21,7 +21,7 @@ use pyo3::{prelude::*, types::PyTuple};

 use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::pyarrow::{PyArrowConvert, PyArrowType};
 use datafusion::error::DataFusionError;
 use datafusion::physical_plan::functions::make_scalar_function;
 use datafusion::physical_plan::udf::ScalarUDF;
@@ -73,14 +73,14 @@ impl PyScalarUDF {
     fn new(
         name: &str,
         func: PyObject,
-        input_types: Vec<DataType>,
-        return_type: DataType,
+        input_types: Vec<PyArrowType<DataType>>,
+        return_type: PyArrowType<DataType>,
         volatility: &str,
     ) -> PyResult<Self> {
         let function = create_udf(
             name,
-            input_types,
-            Arc::new(return_type),
+            input_types.into_iter().map(|x| x.0).collect(),
+            Arc::new(return_type.0),
             parse_volatility(volatility)?,
             to_rust_function(func),
         );
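Since `#[pyfunction]` arguments now arrive wrapped, every `Vec<PyArrowType<DataType>>` (and the nested `Vec<Vec<PyArrowType<RecordBatch>>>` in `src/context.rs`) has to be unwrapped before it reaches DataFusion. The diff inlines the same `map(|x| x.0)` each time; here it is factored into a helper purely for illustration (this helper is not part of the PR):

```rust
// Stand-in wrapper, as in the earlier sketch.
pub struct PyArrowType<T>(pub T);

// The unwrap step repeated in create_dataframe, register_record_batches,
// PyScalarUDF::new, and PyAggregateUDF::new.
fn unwrap_all<T>(items: Vec<PyArrowType<T>>) -> Vec<T> {
    items.into_iter().map(|x| x.0).collect()
}

fn main() {
    let wrapped = vec![PyArrowType(1), PyArrowType(2), PyArrowType(3)];
    assert_eq!(unwrap_all(wrapped), vec![1, 2, 3]);
}
```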