[SessionContext] - Add read_csv/read_parquet/read_avro functions to SessionContext (#57)
francis-du committed Oct 13, 2022
1 parent 0ac714a commit 55909a8
Showing 8 changed files with 230 additions and 10 deletions.
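
For orientation, a minimal sketch of the Python-facing API this commit adds, mirroring the tests further down. It assumes the package exposes SessionContext and column at the top level, and that the sample files come from the testing and parquet submodules added by this commit.

from datafusion import SessionContext, column

ctx = SessionContext()

# CSV: read, project a single column, and print it
csv_df = ctx.read_csv("testing/data/csv/aggregate_test_100.csv")
csv_df.select(column("c1")).show()

# Parquet and Avro readers return DataFrames the same way
parquet_df = ctx.read_parquet("parquet/data/alltypes_plain.parquet")
parquet_df.show()

avro_df = ctx.read_avro("testing/data/avro/alltypes_plain.avro")
avro_df.show()
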
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
@@ -101,6 +101,7 @@ jobs:
      - name: Run tests
        run: |
          git submodule update --init
          source venv/bin/activate
          maturin develop --locked
          RUST_BACKTRACE=1 pytest -v .
6 changes: 6 additions & 0 deletions .gitmodules
@@ -0,0 +1,6 @@
[submodule "testing"]
path = testing
url = https://github.com/apache/arrow-testing.git
[submodule "parquet"]
path = parquet
url = https://github.com/apache/parquet-testing.git
117 changes: 109 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -34,7 +34,7 @@ default = ["mimalloc"]
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.7"
pyo3 = { version = "~0.17.1", features = ["extension-module", "abi3", "abi3-py37"] }
-datafusion = { version = "^12.0.0", features = ["pyarrow"] }
+datafusion = { version = "^12.0.0", features = ["pyarrow", "avro"] }
datafusion-expr = { version = "^12.0.0" }
datafusion-common = { version = "^12.0.0", features = ["pyarrow"] }
uuid = { version = "0.8", features = ["v4"] }
15 changes: 15 additions & 0 deletions datafusion/tests/test_context.py
@@ -179,3 +179,18 @@ def test_table_exist(ctx):
    ctx.register_dataset("t", dataset)

    assert ctx.table_exist("t") is True


def test_read_csv(ctx):
    csv_df = ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv")
    csv_df.select(column("c1")).show()


def test_read_parquet(ctx):
    parquet_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet")
    parquet_df.show()


def test_read_avro(ctx):
    avro_df = ctx.read_avro(path="testing/data/avro/alltypes_plain.avro")
    avro_df.show()
1 change: 1 addition & 0 deletions parquet
Submodule parquet added at e13af1
97 changes: 96 additions & 1 deletion src/context.rs
@@ -28,7 +28,7 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::datasource::TableProvider;
use datafusion::datasource::MemTable;
use datafusion::execution::context::{SessionConfig, SessionContext};
-use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
+use datafusion::prelude::{AvroReadOptions, CsvReadOptions, ParquetReadOptions};

use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
@@ -264,4 +264,99 @@ impl PySessionContext {
    fn session_id(&self) -> PyResult<String> {
        Ok(self.ctx.session_id())
    }

    #[allow(clippy::too_many_arguments)]
    #[args(
        schema = "None",
        has_header = "true",
        delimiter = "\",\"",
        schema_infer_max_records = "1000",
        file_extension = "\".csv\"",
        table_partition_cols = "vec![]"
    )]
    fn read_csv(
        &self,
        path: PathBuf,
        schema: Option<Schema>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<String>,
        py: Python,
    ) -> PyResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(PyValueError::new_err(
                "Delimiter must be a single character",
            ));
        };

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .table_partition_cols(table_partition_cols);
        options.schema = schema.as_ref();

        let result = self.ctx.read_csv(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);

        Ok(df)
    }

    #[allow(clippy::too_many_arguments)]
    #[args(
        parquet_pruning = "true",
        file_extension = "\".parquet\"",
        table_partition_cols = "vec![]",
        skip_metadata = "true"
    )]
    fn read_parquet(
        &self,
        path: &str,
        table_partition_cols: Vec<String>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        py: Python,
    ) -> PyResult<PyDataFrame> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(table_partition_cols)
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;

        let result = self.ctx.read_parquet(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
        Ok(df)
    }

    #[allow(clippy::too_many_arguments)]
    #[args(
        schema = "None",
        file_extension = "\".avro\"",
        table_partition_cols = "vec![]"
    )]
    fn read_avro(
        &self,
        path: &str,
        schema: Option<Schema>,
        table_partition_cols: Vec<String>,
        file_extension: &str,
        py: Python,
    ) -> PyResult<PyDataFrame> {
        let mut options = AvroReadOptions::default().table_partition_cols(table_partition_cols);
        options.file_extension = file_extension;
        options.schema = schema.map(Arc::new);

        let result = self.ctx.read_avro(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
        Ok(df)
    }
}
1 change: 1 addition & 0 deletions testing
Submodule testing added at 5bab2f
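
The #[args] defaults declared in src/context.rs above become keyword arguments on the Python side. A rough sketch of passing them explicitly follows; the argument names are taken from the Rust signatures above, and the exact Python-side defaults are an assumption of this sketch, not a documented contract.

from datafusion import SessionContext

ctx = SessionContext()

# CSV options mirroring the defaults declared in #[args]
csv_df = ctx.read_csv(
    "testing/data/csv/aggregate_test_100.csv",
    has_header=True,
    delimiter=",",  # must be a single character, per the validation in read_csv
    schema_infer_max_records=1000,
    file_extension=".csv",
)
csv_df.show()

# Parquet with pruning and metadata handling made explicit
parquet_df = ctx.read_parquet(
    "parquet/data/alltypes_plain.parquet",
    parquet_pruning=True,
    skip_metadata=True,
)
parquet_df.show()

# Avro with an explicit file extension
avro_df = ctx.read_avro(
    "testing/data/avro/alltypes_plain.avro",
    file_extension=".avro",
)
avro_df.show()
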
