Upgrade to DataFusion 13 (784f10bb) / Arrow 25.0.0 #176

Merged: 12 commits, Oct 27, 2022
594 changes: 361 additions & 233 deletions Cargo.lock

Large diffs are not rendered by default.

26 changes: 16 additions & 10 deletions Cargo.toml
@@ -25,7 +25,11 @@ frontend-postgres = ["convergence", "convergence-arrow"]
object-store-s3 = ["object_store/aws"]

[dependencies]
arrow = { version = "22.0.0", features = ["prettyprint"] }
arrow = "25.0.0"
# For the JSON format support
# https://github.com/apache/arrow-rs/pull/2868
# https://github.com/apache/arrow-rs/pull/2724
arrow-integration-test = "25.0.0"
async-trait = "0.1.41"
base64 = "0.13.0"

@@ -35,13 +39,15 @@ clap = { version = "3.2.19", features = [ "derive" ] }
config = "0.13.1"

# PG wire protocol support
convergence = { git = "https://github.com/splitgraph/convergence", branch = "datafusion-12-upgrade", optional = true }
convergence-arrow = { git = "https://github.com/splitgraph/convergence", branch = "datafusion-12-upgrade", package = "convergence-arrow", optional = true }
datafusion = "12"
datafusion-expr = "12"
datafusion-proto = "12"
convergence = { git = "https://github.com/splitgraph/convergence", branch = "datafusion-13-upgrade", optional = true }
convergence-arrow = { git = "https://github.com/splitgraph/convergence", branch = "datafusion-13-upgrade", package = "convergence-arrow", optional = true }

# DataFusion post-13 update that picks up Arrow 25.0.0
datafusion = { git = "https://github.com/apache/arrow-datafusion", rev = "784f10bb57f86a4db2e01a6cb51da742af0dd9d9" }
datafusion-expr = { git = "https://github.com/apache/arrow-datafusion", rev = "784f10bb57f86a4db2e01a6cb51da742af0dd9d9" }
datafusion-proto = { git = "https://github.com/apache/arrow-datafusion", rev = "784f10bb57f86a4db2e01a6cb51da742af0dd9d9" }

futures = "0.3"
hashbrown = { version = "0.12", features = ["raw"] }
hex = ">=0.4.0"
itertools = ">=0.10.0"
log = "0.4"
@@ -59,7 +65,7 @@ reqwest = { version = "0.11.11", features = [ "stream" ] }
serde = "1.0.138"
serde_json = "1.0.81"
sha2 = ">=0.10.1"
sqlparser = "0.23"
sqlparser = "0.25"
sqlx = { version = "0.6.2", features = [ "runtime-tokio-rustls", "sqlite" ] }
strum = ">=0.24"
strum_macros = ">=0.24"
@@ -72,8 +78,8 @@ warp = "0.3"
wasmtime = "0.40.0"

[patch.crates-io]
# Pick up https://github.com/apache/arrow-rs/pull/2731
object_store = { git = "https://github.com/apache/arrow-rs", rev = "5f441eedff2b7621c46aded8b1caf3b665b8e8a9", package = "object_store" }
datafusion = { git = "https://github.com/apache/arrow-datafusion", rev = "784f10bb57f86a4db2e01a6cb51da742af0dd9d9" }


[dev-dependencies]
mockall = "0.11.1"
142 changes: 87 additions & 55 deletions src/context.rs
@@ -4,20 +4,20 @@ use async_trait::async_trait;
use base64::decode;
use bytes::BytesMut;

use datafusion::datasource::TableProvider;
use datafusion::datasource::{provider_as_source, TableProvider};
use datafusion::sql::ResolvedTableReference;
use itertools::Itertools;
use object_store::local::LocalFileSystem;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use tokio::fs::File as AsyncFile;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};

use std::fs::File;

use datafusion::datasource::file_format::avro::{AvroFormat, DEFAULT_AVRO_EXTENSION};
use datafusion::datasource::file_format::csv::{CsvFormat, DEFAULT_CSV_EXTENSION};
use datafusion::datasource::file_format::json::{JsonFormat, DEFAULT_JSON_EXTENSION};
use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION;
use datafusion::datasource::file_format::avro::AvroFormat;
use datafusion::datasource::file_format::csv::CsvFormat;
use datafusion::datasource::file_format::json::JsonFormat;

use datafusion::datasource::listing::ListingOptions;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::execution::context::SessionState;
@@ -32,7 +32,6 @@ use crate::object_store::wrapped::InternalObjectStore;
use crate::utils::{gc_partitions, group_partitions, hash_file};
use crate::wasm_udf::wasm::create_udf_from_wasm;
use futures::{StreamExt, TryStreamExt};
use hashbrown::HashMap;

#[cfg(test)]
use mockall::automock;
@@ -42,10 +41,13 @@ use sqlparser::ast::{
AlterTableOperation, ObjectType, Statement, TableFactor, TableWithJoins,
};

use arrow_integration_test::field_to_json;
use std::iter::zip;
use std::str::FromStr;
use std::sync::Arc;

use datafusion::common::{Column, DFField, DFSchema, ToDFSchema};
use datafusion::datasource::file_format::file_type::{FileCompressionType, FileType};
pub use datafusion::error::{DataFusionError as Error, Result};
use datafusion::physical_expr::create_physical_expr;
use datafusion::physical_expr::execution_props::ExecutionProps;
@@ -67,11 +69,12 @@ use datafusion::{
prelude::SessionContext,
sql::{planner::SqlToRel, TableReference},
};

use datafusion_expr::logical_plan::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable,
DropTable, Extension, LogicalPlan, Projection,
};
use datafusion_expr::Expr;
use datafusion_expr::{cast, Expr, LogicalPlanBuilder};
use log::{debug, info, warn};
use prost::Message;
use tempfile::TempPath;
@@ -204,7 +207,7 @@ fn build_partition_columns(

PartitionColumn {
name: Arc::from(column.name().to_string()),
r#type: Arc::from(column.to_json().to_string()),
r#type: Arc::from(field_to_json(column).to_string()),
min_value: Arc::new(min_value),
max_value: Arc::new(max_value),
null_count: stats.null_count.map(|nc| nc as i32),
@@ -216,7 +219,7 @@
.iter()
.map(|column| PartitionColumn {
name: Arc::from(column.name().to_string()),
r#type: Arc::from(column.to_json().to_string()),
r#type: Arc::from(field_to_json(column).to_string()),
min_value: Arc::new(None),
max_value: Arc::new(None),
null_count: None,
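
Note on the change above: with the Arrow 25 upgrade, the field-to-JSON serialization used for partition column types appears to come from the new arrow-integration-test crate (added in Cargo.toml above) rather than a method on the Field itself. A minimal sketch of the replacement call, assuming the 25.0.0 API used in this hunk:

use arrow::datatypes::{DataType, Field};
use arrow_integration_test::field_to_json;

/// Serialize a field's schema (name, nullability, type) to its Arrow JSON
/// description, as the partition-column builder above does per column.
fn field_type_json(field: &Field) -> String {
    field_to_json(field).to_string()
}

fn main() {
    let field = Field::new("value", DataType::Float64, true);
    // Prints the field's Arrow JSON schema description (name, nullability, type).
    println!("{}", field_type_json(&field));
}
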
@@ -763,36 +766,43 @@ impl DefaultSeafowlContext {
})
}

// Copied from DataFUsion's source code (private functions)
// Copied from DataFusion's source code (private functions)
async fn create_listing_table(
&self,
cmd: &CreateExternalTable,
) -> Result<Arc<dyn ExecutionPlan>> {
let (file_format, file_extension) = match cmd.file_type.as_str() {
"CSV" => (
Arc::new(
CsvFormat::default()
.with_has_header(cmd.has_header)
.with_delimiter(cmd.delimiter as u8),
) as Arc<dyn FileFormat>,
DEFAULT_CSV_EXTENSION,
),
"PARQUET" => (
Arc::new(ParquetFormat::default()) as Arc<dyn FileFormat>,
DEFAULT_PARQUET_EXTENSION,
),
"AVRO" => (
Arc::new(AvroFormat::default()) as Arc<dyn FileFormat>,
DEFAULT_AVRO_EXTENSION,
),
"JSON" => (
Arc::new(JsonFormat::default()) as Arc<dyn FileFormat>,
DEFAULT_JSON_EXTENSION,
),
_ => Err(DataFusionError::Execution(
let file_compression_type =
match FileCompressionType::from_str(cmd.file_compression_type.as_str()) {
Ok(t) => t,
Err(_) => Err(DataFusionError::Execution(
"Only known FileCompressionTypes can be ListingTables!".to_string(),
))?,
};

let file_type = match FileType::from_str(cmd.file_type.as_str()) {
Ok(t) => t,
Err(_) => Err(DataFusionError::Execution(
"Only known FileTypes can be ListingTables!".to_string(),
))?,
};

let file_extension =
file_type.get_ext_with_compression(file_compression_type.to_owned())?;

let file_format: Arc<dyn FileFormat> = match file_type {
FileType::CSV => Arc::new(
CsvFormat::default()
.with_has_header(cmd.has_header)
.with_delimiter(cmd.delimiter as u8)
.with_file_compression_type(file_compression_type),
),
FileType::PARQUET => Arc::new(ParquetFormat::default()),
FileType::AVRO => Arc::new(AvroFormat::default()),
FileType::JSON => Arc::new(
JsonFormat::default().with_file_compression_type(file_compression_type),
),
};

let table = self.inner.table(cmd.name.as_str());
match (cmd.if_not_exists, table) {
(true, Ok(_)) => Ok(make_dummy_exec()),
@@ -805,7 +815,7 @@ impl DefaultSeafowlContext {
};
let options = ListingOptions {
format: file_format,
collect_stat: false,
collect_stat: self.inner.copied_config().collect_statistics,
file_extension: file_extension.to_owned(),
target_partitions: self.inner.copied_config().target_partitions,
table_partition_cols: cmd.table_partition_cols.clone(),
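
The refactor above replaces the hard-coded extension constants with DataFusion 13's FileType and FileCompressionType enums. A standalone sketch of the same extension resolution, assuming the API at the pinned revision (the helper name is illustrative, not part of the PR):

use std::str::FromStr;

use datafusion::datasource::file_format::file_type::{FileCompressionType, FileType};
use datafusion::error::{DataFusionError, Result};

// Illustrative helper: map the textual file type and compression of a
// CREATE EXTERNAL TABLE statement to the extension a ListingTable should match.
fn listing_extension(file_type: &str, compression: &str) -> Result<String> {
    let compression_type = FileCompressionType::from_str(compression)
        .map_err(|_| DataFusionError::Execution("unknown file compression type".to_string()))?;
    let file_type = FileType::from_str(file_type)
        .map_err(|_| DataFusionError::Execution("unknown file type".to_string()))?;
    // e.g. ("CSV", "GZIP") resolves to an extension like ".csv.gz"
    file_type.get_ext_with_compression(compression_type)
}
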
@@ -1014,10 +1024,10 @@ impl SeafowlContext for DefaultSeafowlContext {
expr: target_schema.fields().iter().zip(plan.schema().field_names()).map(|(table_field, query_field_name)| {
// Generate CAST (source_col AS table_col_type) AS table_col
// If the type is the same, this will be optimized out.
Expr::Cast{
expr: Box::new(Expr::Column(Column::from_name(query_field_name))),
data_type: table_field.data_type().clone()
}.alias(table_field.name())
cast(
Expr::Column(Column::from_name(query_field_name)),
table_field.data_type().clone()).alias(table_field.name()
)
}).collect(),
input: Arc::new(plan),
schema: Arc::new(target_schema),
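
The projection above now builds the cast with the datafusion_expr cast helper instead of constructing Expr::Cast by hand. A minimal sketch of producing one such projected column, assuming the datafusion-expr API at the pinned revision:

use arrow::datatypes::DataType;
use datafusion::common::Column;
use datafusion_expr::{cast, Expr};

// Build `CAST(source_col AS target_type) AS target_name`; if the types already
// match, the optimizer removes the cast.
fn cast_column(source_col: &str, target_type: DataType, target_name: &str) -> Expr {
    cast(Expr::Column(Column::from_name(source_col)), target_type).alias(target_name)
}
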
@@ -1051,7 +1061,7 @@ impl SeafowlContext for DefaultSeafowlContext {
// Get the actual table schema, since DF needs to validate unqualified columns
// (i.e. ones referenced only by column name, lacking the relation name)
let table_name = name.to_string();
let seafowl_table = self.try_get_seafowl_table(&table_name)?;
let seafowl_table = Arc::new(self.try_get_seafowl_table(&table_name)?);
let table_schema = seafowl_table.schema.arrow_schema.clone().to_dfschema()?;

let selection_expr = match selection {
@@ -1066,14 +1076,22 @@
))
}).collect::<Result<Vec<(String, Expr)>>>()?;

Ok(LogicalPlan::Extension(Extension {
let logical_plan = LogicalPlan::Extension(Extension {
node: Arc::new(SeafowlExtensionNode::Update(Update {
table: Arc::new(seafowl_table),
table: seafowl_table.clone(),
table_plan: Arc::new(LogicalPlanBuilder::scan(table_name,
provider_as_source(seafowl_table),
None,
)?.build()?),
selection: selection_expr,
assignments: HashMap::from_iter(assignment_exprs),
assignments: assignment_exprs,
output_schema: Arc::new(DFSchema::empty())
})),
}))
});

// Run the optimizer in order to apply required transformations to the query plan
// (e.g. type coercions for the WHERE clause)
self.inner.optimize(&logical_plan)
}
Statement::Delete {
table_name,
@@ -1083,21 +1101,30 @@
// Get the actual table schema, since DF needs to validate unqualified columns
// (i.e. ones referenced only by column name, lacking the relation name)
let table_name = table_name.to_string();
let seafowl_table = self.try_get_seafowl_table(&table_name)?;
let seafowl_table = Arc::new(self.try_get_seafowl_table(&table_name)?);
let table_schema = seafowl_table.schema.arrow_schema.clone().to_dfschema()?;

let selection_expr = match selection {
None => None,
Some(expr) => Some(query_planner.sql_to_rex(expr, &table_schema, &mut HashMap::new())?),
};

Ok(LogicalPlan::Extension(Extension {
let logical_plan = LogicalPlan::Extension(Extension {
node: Arc::new(SeafowlExtensionNode::Delete(Delete {
table: Arc::new(seafowl_table),
table: seafowl_table.clone(),
table_plan: Arc::new(LogicalPlanBuilder::scan(table_name,
provider_as_source(seafowl_table),
None,
)?
.build()?),
selection: selection_expr,
output_schema: Arc::new(DFSchema::empty())
})),
}))
});

// Run the optimizer in order to apply required transformations to the query plan
// (e.g. type coercions for the WHERE clause)
self.inner.optimize(&logical_plan)
},
Statement::CreateFunction {
temporary: false,
@@ -1375,10 +1402,15 @@ impl SeafowlContext for DefaultSeafowlContext {

let mut final_partition_ids =
Vec::with_capacity(partitions.len());

// Deduplicate assignments (we have to keep them as a vector in order
// to keep the order of column name -> expression mapping)
let assignment_map = HashMap::from_iter(assignments.clone());

let mut update_plan: Arc<dyn ExecutionPlan>;
let project_expressions = project_expressions(
&schema,
assignments,
&assignment_map,
selection_expr,
)?;

@@ -1936,12 +1968,12 @@ mod tests {
use super::test_utils::mock_context;

const PARTITION_1_FILE_NAME: &str =
"f48028e0f51f9447a90c407e9b0caa0f2af13f421db4939dc9b60825e0a26079.parquet";
"4c643a98a232ba10452165d3673af89c09999b8f747efb2f4fec163fbcd325df.parquet";
const PARTITION_2_FILE_NAME: &str =
"6a4d8c5721ab70411ad52b807cdde40ade15641b6168513ebd3e1e8e4eb4505b.parquet";
"7b1aaeaed9cf57509b2ecb31e9c298880e26cd269c93cc2fdb4973f2a6649f90.parquet";

const EXPECTED_INSERT_FILE_NAME: &str =
"3d552f85de97027297b42f6ffa644f5f1555b6b18131a4537b208b51ee4ef39f.parquet";
"bacf07bd78884b01c3d6d80c6799e6b9bd9281fa0224a2c20b6474745376b208.parquet";

fn to_min_max_value(value: ScalarValue) -> Arc<Option<Vec<u8>>> {
Arc::from(scalar_value_to_bytes(&value))
@@ -2220,7 +2252,7 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
"Insert: some_table\
\n Projection: CAST(#column1 AS Date64) AS date, CAST(#column2 AS Float64) AS value\
\n Projection: CAST(column1 AS Date64) AS date, CAST(column2 AS Float64) AS value\
\n Values: (Utf8(\"2022-01-01T12:00:00\"), Int64(42))"
);
}
@@ -2239,8 +2271,8 @@ mod tests {
.unwrap();

assert_eq!(format!("{:?}", plan), "Insert: some_table\
\n Projection: CAST(#my_date AS Date64) AS date, CAST(#my_value AS Float64) AS value\
\n Projection: #testdb.testcol.some_table.date AS my_date, #testdb.testcol.some_table.value AS my_value\
\n Projection: CAST(my_date AS Date64) AS date, CAST(my_value AS Float64) AS value\
\n Projection: testdb.testcol.some_table.date AS my_date, testdb.testcol.some_table.value AS my_value\
\n TableScan: testdb.testcol.some_table");
}

@@ -2258,7 +2290,7 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
"Insert: some_table\
\n Projection: CAST(#column1 AS Date64) AS date, CAST(#column2 AS Float64) AS value\
\n Projection: CAST(column1 AS Date64) AS date, CAST(column2 AS Float64) AS value\
\n Values: (Utf8(\"2022-01-01T12:00:00\"), Int64(42))"
);
}