From 117df4d4fa2881435a428cff3ab5880c3b8632e1 Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Fri, 22 Jul 2022 22:38:58 +0800 Subject: [PATCH 1/2] chore: update jit-related dependencies (#2956) --- datafusion/common/Cargo.toml | 2 +- datafusion/jit/Cargo.toml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 28a079ed0c54..0e1a79445146 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -40,7 +40,7 @@ pyarrow = ["pyo3"] [dependencies] arrow = { version = "18.0.0", features = ["prettyprint"] } avro-rs = { version = "0.13", features = ["snappy"], optional = true } -cranelift-module = { version = "0.85.0", optional = true } +cranelift-module = { version = "0.86.1", optional = true } object_store = { version = "0.3", optional = true } ordered-float = "3.0" parquet = { version = "18.0.0", features = ["arrow"], optional = true } diff --git a/datafusion/jit/Cargo.toml b/datafusion/jit/Cargo.toml index 69327e1d1413..eb539cbfcb16 100644 --- a/datafusion/jit/Cargo.toml +++ b/datafusion/jit/Cargo.toml @@ -37,10 +37,10 @@ jit = [] [dependencies] arrow = { version = "18.0.0" } -cranelift = "0.85.0" -cranelift-jit = "0.85.0" -cranelift-module = "0.85.0" -cranelift-native = "0.85.0" +cranelift = "0.86.1" +cranelift-jit = "0.86.1" +cranelift-module = "0.86.1" +cranelift-native = "0.86.1" datafusion-common = { path = "../common", version = "10.0.0", features = ["jit"] } datafusion-expr = { path = "../expr", version = "10.0.0" } From 7b0f2f846a7c8c2ffee2a4f29772cf3527a8d92c Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Fri, 22 Jul 2022 10:06:14 -0600 Subject: [PATCH 2/2] Add support for correlated subqueries & fix all related TPC-H benchmark issues (#2885) * Failing test case for TPC-H query 20 * Fix name * Broken test for adding intervals to dates * Tests pass * Fix rebase * Fix query * Additional tests * Reduce to minimum failing (and passing) cases * Adjust 
so data _should_ be returned, but see none * Fixed data, decorrelated test passes * Check in plans * Put real assertion in place * Add test for already working subquery optimizer * Add decorellator * Check in broken test * Add some passing and failing tests to see scope of problem * Have almost all inputs needed for optimization, but need to catch 1 level earlier in tree * Collected all inputs, now we just need to optimize * Successfully decorrelated query 4 * refactor * Pass test 4 * Ready for PR? * Only operate on equality expressions * Lint error * Tests still pass because we are losing remaining predicate * Don't lose remaining expressions * Update test to expect remaining filter clause * Debugging * Can run query 4 * Remove debugging code * Clippy * Refactor where exists, add scalar subquery * Login qty < () and 0.2 times, predicate pushdown is killing our plan * Query plan looks good * Fudge data to make test output nicer * Fix syntax error * [WIP] where in * Working recursively, q20 plan looks good, but execution failing * Fix CSV for execution error, remove silly variables in favor of --nocapture * Silence verbose logs * Query 21 test * [WIP] refactoring, query 4 looking good * [WIP] 4 & 17 look good * 22 good? 
* Check in "Test" for query 11 * query 11 works * Don't throw away plans when multiple subqueries in one filter * Manually decorellate query 21 * [WIP] add data for query 21, anti join failing for some reason * Does appear to be problem with anti-join * Minimum failing test * Verify anti join fix * Repeatable tests * cargo fmt * Restore some optimizers and update test expectations * Restore some optimizers and update test expectations * Restore some optimizers and update test expectations * Restore some optimizers and update test expectations * Cleanup * Cleanup scalar subquery, de-duplicate some code * Cleanup * Refactor * Refactor * Refactor * Refactor * Handle recursive where in * Update assertions * Support recursion in where exists queries * Unit tests on where in * Add correlated where in test * Nasty code to make where in work for both correlated and uncorrelated queries * Cleanup * Refactoring * Refactoring * Add correlated unit test * Add correlated where exists unit test * [WIP] Failing scalar subquery unit test * Refactor * tuple mixup * Scalar subquery unit test * ASF header * PR feedback * PR feedback * PR feedback * PR feedback * Fix build again * Formatting * Testing * multiple where in * Unit tests for where in * where exists tests * scalar subquery tests * add aggregates to scalar subqueries * Remove tests that only existed to get logical plans as input to unit tests * Check in assertions for valid tests * 1/33 passing unit tests :/ * Down to one failing test * All the unit tests pass * into methods * Where exists unit tests passing * Try from methods * Fix tests * Fix tests * Refactor * Fix test * Refactor * Fix test * Fix error message * Fix tests * Fix tests * Refactor * Refactor and fix tests * Improved recursive subquery test * Recursive subquery fix * Update tests * Update tests * Update tests * Doc * Clippy * Linter & clippy * Add doc, move test methods into test modules * PR cleanup * Inline test data * Remove shared test data * Remove 
shared test data * Update tests * Fix toml * Update expectation * PR feedback * PR feedback Co-authored-by: Andrew Lamb * Fix test to reveal logic error * Simplify test * Fix stuff, break other stuff * I've writen scala in rust because I'm in a hurry :( * Clean the API up a little * PR feedback * PR feedback * PR feedback * PR feedback Co-authored-by: Andrew Lamb --- benchmarks/queries/q20.sql | 2 +- datafusion/common/src/error.rs | 27 + datafusion/core/Cargo.toml | 2 + datafusion/core/src/execution/context.rs | 6 + .../src/physical_plan/coalesce_batches.rs | 4 +- .../core/src/physical_plan/file_format/mod.rs | 4 +- datafusion/core/tests/sql/mod.rs | 111 ++- datafusion/core/tests/sql/subqueries.rs | 474 ++++++++++++ datafusion/core/tests/tpch-csv/lineitem.csv | 4 +- datafusion/core/tests/tpch-csv/nation.csv | 2 +- datafusion/core/tests/tpch-csv/part.csv | 2 + datafusion/core/tests/tpch-csv/partsupp.csv | 2 + datafusion/core/tests/tpch-csv/region.csv | 2 + datafusion/core/tests/tpch-csv/supplier.csv | 3 + datafusion/expr/src/expr.rs | 9 +- datafusion/expr/src/logical_plan/plan.rs | 36 +- datafusion/optimizer/Cargo.toml | 5 + .../src/decorrelate_scalar_subquery.rs | 705 ++++++++++++++++++ .../optimizer/src/decorrelate_where_exists.rs | 557 ++++++++++++++ .../optimizer/src/decorrelate_where_in.rs | 693 +++++++++++++++++ datafusion/optimizer/src/lib.rs | 3 + datafusion/optimizer/src/test/mod.rs | 77 +- datafusion/optimizer/src/utils.rs | 235 +++++- 23 files changed, 2952 insertions(+), 13 deletions(-) create mode 100644 datafusion/core/tests/sql/subqueries.rs create mode 100644 datafusion/core/tests/tpch-csv/part.csv create mode 100644 datafusion/core/tests/tpch-csv/partsupp.csv create mode 100644 datafusion/core/tests/tpch-csv/region.csv create mode 100644 datafusion/core/tests/tpch-csv/supplier.csv create mode 100644 datafusion/optimizer/src/decorrelate_scalar_subquery.rs create mode 100644 datafusion/optimizer/src/decorrelate_where_exists.rs create mode 100644 
datafusion/optimizer/src/decorrelate_where_in.rs diff --git a/benchmarks/queries/q20.sql b/benchmarks/queries/q20.sql index f0339a6013c2..dd61a7d8e6ea 100644 --- a/benchmarks/queries/q20.sql +++ b/benchmarks/queries/q20.sql @@ -28,7 +28,7 @@ where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01' - and l_shipdate < 'date 1994-01-01' + interval '1' year + and l_shipdate < date '1994-01-01' + interval '1' year ) ) and s_nationkey = n_nationkey diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index c1d0f29b11eb..de5bbe8e004c 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -83,6 +83,30 @@ pub enum DataFusionError { #[cfg(feature = "jit")] /// Error occurs during code generation JITError(ModuleError), + /// Error with additional context + Context(String, Box), +} + +#[macro_export] +macro_rules! context { + ($desc:expr, $err:expr) => { + datafusion_common::DataFusionError::Context( + format!("{} at {}:{}", $desc, file!(), line!()), + Box::new($err), + ) + }; +} + +#[macro_export] +macro_rules! 
plan_err { + ($desc:expr) => { + Err(datafusion_common::DataFusionError::Plan(format!( + "{} at {}:{}", + $desc, + file!(), + line!() + ))) + }; } /// Schema-related errors @@ -285,6 +309,9 @@ impl Display for DataFusionError { DataFusionError::ObjectStore(ref desc) => { write!(f, "Object Store error: {}", desc) } + DataFusionError::Context(ref desc, ref err) => { + write!(f, "{}\ncaused by\n{}", desc, *err) + } } } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index ac21d7f90097..351e72fc2a5d 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -94,6 +94,8 @@ uuid = { version = "1.0", features = ["v4"] } [dev-dependencies] criterion = "0.3" +csv = "1.1.6" +ctor = "0.1.22" doc-comment = "0.3" env_logger = "0.9" fuzz-utils = { path = "fuzz-utils" } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index e37d3b0bafc1..41964e33ac96 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -102,6 +102,9 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use datafusion_common::ScalarValue; use datafusion_expr::TableSource; +use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery; +use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; +use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; use datafusion_sql::{ parser::DFParser, @@ -1356,6 +1359,9 @@ impl SessionState { // Simplify expressions first to maximize the chance // of applying other optimizations Arc::new(SimplifyExpressions::new()), + Arc::new(DecorrelateWhereExists::new()), + Arc::new(DecorrelateWhereIn::new()), + Arc::new(DecorrelateScalarSubquery::new()), Arc::new(SubqueryFilterToJoin::new()), Arc::new(EliminateFilter::new()), Arc::new(CommonSubexprEliminate::new()), diff --git 
a/datafusion/core/src/physical_plan/coalesce_batches.rs b/datafusion/core/src/physical_plan/coalesce_batches.rs index 3f39caaefba8..a257ccf09994 100644 --- a/datafusion/core/src/physical_plan/coalesce_batches.rs +++ b/datafusion/core/src/physical_plan/coalesce_batches.rs @@ -35,7 +35,7 @@ use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use futures::stream::{Stream, StreamExt}; -use log::debug; +use log::trace; use super::expressions::PhysicalSortExpr; use super::metrics::{BaselineMetrics, MetricsSet}; @@ -286,7 +286,7 @@ pub fn concat_batches( )?; arrays.push(array); } - debug!( + trace!( "Combined {} batches containing {} rows", batches.len(), row_count diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index 3ea520b2cc94..c26b2d760a77 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -26,6 +26,8 @@ mod file_stream; mod json; mod parquet; +pub(crate) use self::csv::plan_to_csv; +pub use self::csv::CsvExec; pub(crate) use self::parquet::plan_to_parquet; pub use self::parquet::ParquetExec; use arrow::{ @@ -36,8 +38,6 @@ use arrow::{ record_batch::RecordBatch, }; pub use avro::AvroExec; -pub(crate) use csv::plan_to_csv; -pub use csv::CsvExec; pub(crate) use json::plan_to_json; pub use json::NdJsonExec; diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index a7f4cabe9d7b..186584aebd1e 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -49,6 +49,7 @@ use datafusion_expr::Volatility; use object_store::path::Path; use std::fs::File; use std::io::Write; +use std::ops::Sub; use std::path::PathBuf; use tempfile::TempDir; @@ -108,6 +109,7 @@ mod explain; mod idenfifers; pub mod information_schema; mod partitioned_csv; +mod subqueries; #[cfg(feature = "unicode_expressions")] pub mod unicode; @@ -483,7 
+485,43 @@ fn get_tpch_table_schema(table: &str) -> Schema { Field::new("n_comment", DataType::Utf8, false), ]), - _ => unimplemented!(), + "supplier" => Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_name", DataType::Utf8, false), + Field::new("s_address", DataType::Utf8, false), + Field::new("s_nationkey", DataType::Int64, false), + Field::new("s_phone", DataType::Utf8, false), + Field::new("s_acctbal", DataType::Float64, false), + Field::new("s_comment", DataType::Utf8, false), + ]), + + "partsupp" => Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, false), + Field::new("ps_availqty", DataType::Int32, false), + Field::new("ps_supplycost", DataType::Float64, false), + Field::new("ps_comment", DataType::Utf8, false), + ]), + + "part" => Schema::new(vec![ + Field::new("p_partkey", DataType::Int64, false), + Field::new("p_name", DataType::Utf8, false), + Field::new("p_mfgr", DataType::Utf8, false), + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("p_container", DataType::Utf8, false), + Field::new("p_retailprice", DataType::Float64, false), + Field::new("p_comment", DataType::Utf8, false), + ]), + + "region" => Schema::new(vec![ + Field::new("r_regionkey", DataType::Int64, false), + Field::new("r_name", DataType::Utf8, false), + Field::new("r_comment", DataType::Utf8, false), + ]), + + _ => unimplemented!("Table: {}", table), } } @@ -499,6 +537,77 @@ async fn register_tpch_csv(ctx: &SessionContext, table: &str) -> Result<()> { Ok(()) } +async fn register_tpch_csv_data( + ctx: &SessionContext, + table_name: &str, + data: &str, +) -> Result<()> { + let schema = Arc::new(get_tpch_table_schema(table_name)); + + let mut reader = ::csv::ReaderBuilder::new() + .has_headers(false) + .from_reader(data.as_bytes()); + let records: Vec<_> = reader.records().map(|it| 
it.unwrap()).collect(); + + let mut cols: Vec> = vec![]; + for field in schema.fields().iter() { + match field.data_type() { + DataType::Utf8 => cols.push(Box::new(StringBuilder::new(records.len()))), + DataType::Date32 => cols.push(Box::new(Date32Builder::new(records.len()))), + DataType::Int32 => cols.push(Box::new(Int32Builder::new(records.len()))), + DataType::Int64 => cols.push(Box::new(Int64Builder::new(records.len()))), + DataType::Float64 => cols.push(Box::new(Float64Builder::new(records.len()))), + _ => { + let msg = format!("Not implemented: {}", field.data_type()); + Err(DataFusionError::Plan(msg))? + } + } + } + + for record in records.iter() { + for (idx, val) in record.iter().enumerate() { + let col = cols.get_mut(idx).unwrap(); + let field = schema.field(idx); + match field.data_type() { + DataType::Utf8 => { + let sb = col.as_any_mut().downcast_mut::().unwrap(); + sb.append_value(val)?; + } + DataType::Date32 => { + let sb = col.as_any_mut().downcast_mut::().unwrap(); + let dt = NaiveDate::parse_from_str(val.trim(), "%Y-%m-%d").unwrap(); + let dt = dt.sub(NaiveDate::from_ymd(1970, 1, 1)).num_days() as i32; + sb.append_value(dt)?; + } + DataType::Int32 => { + let sb = col.as_any_mut().downcast_mut::().unwrap(); + sb.append_value(val.trim().parse().unwrap())?; + } + DataType::Int64 => { + let sb = col.as_any_mut().downcast_mut::().unwrap(); + sb.append_value(val.trim().parse().unwrap())?; + } + DataType::Float64 => { + let sb = col.as_any_mut().downcast_mut::().unwrap(); + sb.append_value(val.trim().parse().unwrap())?; + } + _ => Err(DataFusionError::Plan(format!( + "Not implemented: {}", + field.data_type() + )))?, + } + } + } + let cols: Vec = cols.iter_mut().map(|it| it.finish()).collect(); + + let batch = RecordBatch::try_new(Arc::clone(&schema), cols)?; + + let table = Arc::new(MemTable::try_new(Arc::clone(&schema), vec![vec![batch]])?); + let _ = ctx.register_table(table_name, table).unwrap(); + + Ok(()) +} + async fn 
register_aggregate_csv_by_sql(ctx: &SessionContext) { let testdata = datafusion::test_util::arrow_test_data(); diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs new file mode 100644 index 000000000000..4eaf921f6937 --- /dev/null +++ b/datafusion/core/tests/sql/subqueries.rs @@ -0,0 +1,474 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; +use crate::sql::execute_to_batches; +use datafusion::assert_batches_eq; +use datafusion::prelude::SessionContext; +use log::debug; + +#[cfg(test)] +#[ctor::ctor] +fn init() { + let _ = env_logger::try_init(); +} + +#[tokio::test] +async fn correlated_recursive_scalar_subquery() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "customer").await?; + register_tpch_csv(&ctx, "orders").await?; + register_tpch_csv(&ctx, "lineitem").await?; + + let sql = r#" +select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice < ( + select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey + ) +) order by c_custkey;"#; + + // assert plan + let plan = ctx.create_logical_plan(sql).unwrap(); + debug!("input:\n{}", plan.display_indent()); + + let plan = ctx.optimize(&plan).unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST + Projection: #customer.c_custkey + Filter: #customer.c_acctbal < #__sq_2.__value + Inner Join: #customer.c_custkey = #__sq_2.o_custkey + TableScan: customer projection=[c_custkey, c_acctbal] + Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, alias=__sq_2 + Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[SUM(#orders.o_totalprice)]] + Filter: #orders.o_totalprice < #__sq_1.__value + Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey + TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] + Projection: #lineitem.l_orderkey, #SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1 + Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[SUM(#lineitem.l_extendedprice)]] + TableScan: lineitem projection=[l_orderkey, l_extendedprice]"# + .to_string(); + assert_eq!(actual, expected); + + Ok(()) +} + +#[tokio::test] +async fn correlated_where_in() -> Result<()> { + let orders = 
r#"1,3691,O,194029.55,1996-01-02,5-LOW,Clerk#000000951,0, +65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0, +"#; + let lineitems = r#"1,15519,785,1,17,24386.67,0.04,0.02,N,O,1996-03-13,1996-02-12,1996-03-22,DELIVER IN PERSON,TRUCK, +1,6731,732,2,36,58958.28,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL, +65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK, +65,7382,897,2,22,28366.36,0,0.05,N,O,1995-07-17,1995-06-04,1995-07-19,COLLECT COD,FOB, +"#; + + let ctx = SessionContext::new(); + register_tpch_csv_data(&ctx, "orders", orders).await?; + register_tpch_csv_data(&ctx, "lineitem", lineitems).await?; + + let sql = r#"select o_orderkey from orders +where o_orderstatus in ( + select l_linestatus from lineitem where l_orderkey = orders.o_orderkey +);"#; + + // assert plan + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = r#"Projection: #orders.o_orderkey + Semi Join: #orders.o_orderstatus = #__sq_1.l_linestatus, #orders.o_orderkey = #__sq_1.l_orderkey + TableScan: orders projection=[o_orderkey, o_orderstatus] + Projection: #lineitem.l_linestatus AS l_linestatus, #lineitem.l_orderkey AS l_orderkey, alias=__sq_1 + TableScan: lineitem projection=[l_orderkey, l_linestatus]"# + .to_string(); + assert_eq!(actual, expected); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+", + "| o_orderkey |", + "+------------+", + "| 1 |", + "+------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q2_correlated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "part").await?; + register_tpch_csv(&ctx, "supplier").await?; + register_tpch_csv(&ctx, "partsupp").await?; + register_tpch_csv(&ctx, "nation").await?; + register_tpch_csv(&ctx, "region").await?; + + let sql = 
r#"select s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment +from part, supplier, partsupp, nation, region +where p_partkey = ps_partkey and s_suppkey = ps_suppkey and p_size = 15 and p_type like '%BRASS' + and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 'EUROPE' + and ps_supplycost = ( + select min(ps_supplycost) from partsupp, supplier, nation, region + where p_partkey = ps_partkey and s_suppkey = ps_suppkey and s_nationkey = n_nationkey + and n_regionkey = r_regionkey and r_name = 'EUROPE' + ) +order by s_acctbal desc, n_name, s_name, p_partkey;"#; + + // assert plan + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = r#"Sort: #supplier.s_acctbal DESC NULLS FIRST, #nation.n_name ASC NULLS LAST, #supplier.s_name ASC NULLS LAST, #part.p_partkey ASC NULLS LAST + Projection: #supplier.s_acctbal, #supplier.s_name, #nation.n_name, #part.p_partkey, #part.p_mfgr, #supplier.s_address, #supplier.s_phone, #supplier.s_comment + Filter: #partsupp.ps_supplycost = #__sq_1.__value + Inner Join: #part.p_partkey = #__sq_1.ps_partkey + Inner Join: #nation.n_regionkey = #region.r_regionkey + Inner Join: #supplier.s_nationkey = #nation.n_nationkey + Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey + Inner Join: #part.p_partkey = #partsupp.ps_partkey + Filter: #part.p_size = Int64(15) AND #part.p_type LIKE Utf8("%BRASS") + TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE Utf8("%BRASS")] + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] + TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] + TableScan: nation projection=[n_nationkey, n_name, n_regionkey] + Filter: #region.r_name = Utf8("EUROPE") + TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")] + Projection: #partsupp.ps_partkey, #MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1 + Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[MIN(#partsupp.ps_supplycost)]] + Inner Join: #nation.n_regionkey = #region.r_regionkey + Inner Join: #supplier.s_nationkey = #nation.n_nationkey + Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] + TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] + TableScan: nation projection=[n_nationkey, n_name, n_regionkey] + Filter: #region.r_name = Utf8("EUROPE") + TableScan: region projection=[r_regionkey, r_name], partial_filters=[#region.r_name = Utf8("EUROPE")]"# + .to_string(); + assert_eq!(actual, expected); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec!["++", "++"]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q4_correlated() -> Result<()> { + let orders = r#"4,13678,O,53829.87,1995-10-11,5-LOW,Clerk#000000124,0, +35,12760,O,192885.43,1995-10-23,4-NOT SPECIFIED,Clerk#000000259,0, +65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0, +"#; + let lineitems = r#"4,8804,579,1,30,51384,0.03,0.08,N,O,1996-01-10,1995-12-14,1996-01-18,DELIVER IN PERSON,REG AIR, +35,45,296,1,24,22680.96,0.02,0,N,O,1996-02-21,1996-01-03,1996-03-18,TAKE BACK RETURN,FOB, +65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK, +"#; + + let ctx = SessionContext::new(); + register_tpch_csv_data(&ctx, "orders", orders).await?; + register_tpch_csv_data(&ctx, "lineitem", lineitems).await?; + + let sql = r#" + select o_orderpriority, count(*) as order_count + from orders + where exists ( + select * from lineitem where l_orderkey = o_orderkey and l_commitdate < l_receiptdate) + group by o_orderpriority + order by o_orderpriority; + "#; + + // assert plan + 
let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = r#"Sort: #orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Semi Join: #orders.o_orderkey = #lineitem.l_orderkey + TableScan: orders projection=[o_orderkey, o_orderpriority] + Filter: #lineitem.l_commitdate < #lineitem.l_receiptdate + TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate]"# + .to_string(); + assert_eq!(actual, expected); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------------+-------------+", + "| o_orderpriority | order_count |", + "+-----------------+-------------+", + "| 1-URGENT | 1 |", + "| 4-NOT SPECIFIED | 1 |", + "| 5-LOW | 1 |", + "+-----------------+-------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q17_correlated() -> Result<()> { + let parts = r#"63700,goldenrod lavender spring chocolate lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly ironi +"#; + let lineitems = r#"1,63700,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL,ly final dependencies: slyly bold +1,63700,3701,3,1.0,13309.6,0.1,0.02,N,O,1996-01-29,1996-03-05,1996-01-31,TAKE BACK RETURN,REG AIR,"riously. 
regular, express dep" +"#; + + let ctx = SessionContext::new(); + register_tpch_csv_data(&ctx, "part", parts).await?; + register_tpch_csv_data(&ctx, "lineitem", lineitems).await?; + + let sql = r#"select sum(l_extendedprice) / 7.0 as avg_yearly + from lineitem, part + where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 'MED BOX' + and l_quantity < ( + select 0.2 * avg(l_quantity) + from lineitem where l_partkey = p_partkey + );"#; + + // assert plan + let plan = ctx + .create_logical_plan(sql) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + println!("before:\n{}", plan.display_indent()); + let plan = ctx + .optimize(&plan) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = r#"Projection: #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly + Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]] + Filter: #lineitem.l_quantity < #__sq_1.__value + Inner Join: #part.p_partkey = #__sq_1.l_partkey + Inner Join: #lineitem.l_partkey = #part.p_partkey + TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] + Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX") + TableScan: part projection=[p_partkey, p_brand, p_container] + Projection: #lineitem.l_partkey, Float64(0.2) * #AVG(lineitem.l_quantity) AS __value, alias=__sq_1 + Aggregate: groupBy=[[#lineitem.l_partkey]], aggr=[[AVG(#lineitem.l_quantity)]] + TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice]"# + .to_string(); + assert_eq!(actual, expected); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+--------------------+", + "| avg_yearly |", + "+--------------------+", + "| 1901.3714285714286 |", + "+--------------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q20_correlated() -> Result<()> { + let ctx = 
SessionContext::new(); + register_tpch_csv(&ctx, "supplier").await?; + register_tpch_csv(&ctx, "nation").await?; + register_tpch_csv(&ctx, "partsupp").await?; + register_tpch_csv(&ctx, "part").await?; + register_tpch_csv(&ctx, "lineitem").await?; + + let sql = r#"select s_name, s_address +from supplier, nation +where s_suppkey in ( + select ps_suppkey from partsupp + where ps_partkey in ( select p_partkey from part where p_name like 'forest%' ) + and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem + where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01' + ) +) +and s_nationkey = n_nationkey and n_name = 'CANADA' +order by s_name; +"#; + + // assert plan + let plan = ctx + .create_logical_plan(sql) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let plan = ctx + .optimize(&plan) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = r#"Sort: #supplier.s_name ASC NULLS LAST + Projection: #supplier.s_name, #supplier.s_address + Semi Join: #supplier.s_suppkey = #__sq_2.ps_suppkey + Inner Join: #supplier.s_nationkey = #nation.n_nationkey + TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] + Filter: #nation.n_name = Utf8("CANADA") + TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("CANADA")] + Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2 + Filter: #partsupp.ps_availqty > #__sq_3.__value + Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey, #partsupp.ps_suppkey = #__sq_3.l_suppkey + Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] + Projection: #part.p_partkey AS p_partkey, alias=__sq_1 + Filter: #part.p_name LIKE Utf8("forest%") + TableScan: part projection=[p_partkey, p_name], partial_filters=[#part.p_name LIKE Utf8("forest%")] + Projection: #lineitem.l_partkey, 
#lineitem.l_suppkey, Float64(0.5) * #SUM(lineitem.l_quantity) AS __value, alias=__sq_3 + Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]] + Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32) + TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"# + .to_string(); + assert_eq!(actual, expected); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec!["++", "++"]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q22_correlated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "customer").await?; + register_tpch_csv(&ctx, "orders").await?; + + let sql = r#"select cntrycode, count(*) as numcust, sum(c_acctbal) as totacctbal +from ( + select substring(c_phone from 1 for 2) as cntrycode, c_acctbal from customer + where substring(c_phone from 1 for 2) in ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select avg(c_acctbal) from customer where c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( select * from orders where o_custkey = c_custkey ) + ) as custsale +group by cntrycode +order by cntrycode;"#; + + // assert plan + let plan = ctx + .create_logical_plan(sql) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let plan = ctx + .optimize(&plan) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = r#"Sort: #custsale.cntrycode ASC NULLS LAST + Projection: #custsale.cntrycode, #COUNT(UInt8(1)) AS numcust, #SUM(custsale.c_acctbal) AS totacctbal + Aggregate: groupBy=[[#custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), SUM(#custsale.c_acctbal)]] + Projection: #custsale.cntrycode, #custsale.c_acctbal, alias=custsale + 
Projection: substr(#customer.c_phone, Int64(1), Int64(2)) AS cntrycode, #customer.c_acctbal, alias=custsale + Filter: #customer.c_acctbal > #__sq_1.__value + CrossJoin: + Anti Join: #customer.c_custkey = #orders.o_custkey + Filter: substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) + TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])] + TableScan: orders projection=[o_custkey] + Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1 + Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]] + Filter: #customer.c_acctbal > Float64(0) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) + TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"# + .to_string(); + assert_eq!(actual, expected); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+---------+------------+", + "| cntrycode | numcust | totacctbal |", + "+-----------+---------+------------+", + "| 18 | 1 | 8324.07 |", + "| 30 | 1 | 7638.57 |", + "+-----------+---------+------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn tpch_q11_correlated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "partsupp").await?; + register_tpch_csv(&ctx, "supplier").await?; + register_tpch_csv(&ctx, "nation").await?; + + let sql = r#"select ps_partkey, sum(ps_supplycost * ps_availqty) as value +from partsupp, supplier, nation +where ps_suppkey = s_suppkey and 
s_nationkey = n_nationkey and n_name = 'GERMANY' +group by ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select sum(ps_supplycost * ps_availqty) * 0.0001 + from partsupp, supplier, nation + where ps_suppkey = s_suppkey and s_nationkey = n_nationkey and n_name = 'GERMANY' + ) +order by value desc; +"#; + + // assert plan + let plan = ctx + .create_logical_plan(sql) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + println!("before:\n{}", plan.display_indent()); + let plan = ctx + .optimize(&plan) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let actual = format!("{}", plan.display_indent()); + println!("after:\n{}", actual); + let expected = r#"Sort: #value DESC NULLS FIRST + Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value + Filter: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) > #__sq_1.__value + CrossJoin: + Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[SUM(#partsupp.ps_supplycost * #partsupp.ps_availqty)]] + Inner Join: #supplier.s_nationkey = #nation.n_nationkey + Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] + TableScan: supplier projection=[s_suppkey, s_nationkey] + Filter: #nation.n_name = Utf8("GERMANY") + TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")] + Projection: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) AS __value, alias=__sq_1 + Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost * #partsupp.ps_availqty)]] + Inner Join: #supplier.s_nationkey = #nation.n_nationkey + Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] + TableScan: supplier projection=[s_suppkey, s_nationkey] + Filter: #nation.n_name = Utf8("GERMANY") + TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("GERMANY")]"# + .to_string(); + assert_eq!(actual, expected); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec!["++", "++"]; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/tpch-csv/lineitem.csv b/datafusion/core/tests/tpch-csv/lineitem.csv index 47f08711da07..797a891805df 100644 --- a/datafusion/core/tests/tpch-csv/lineitem.csv +++ b/datafusion/core/tests/tpch-csv/lineitem.csv @@ -1,5 +1,5 @@ l_orderkey,l_partkey,l_suppkey,l_linenumber,l_quantity,l_extendedprice,l_discount,l_tax,l_returnflag,l_linestatus,l_shipdate,l_commitdate,l_receiptdate,l_shipinstruct,l_shipmode,l_comment -1,67310,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL,ly final dependencies: slyly bold +1,67310,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL,ly final dependencies: slyly bold 1,63700,3701,3,8.0,13309.6,0.1,0.02,N,O,1996-01-29,1996-03-05,1996-01-31,TAKE BACK RETURN,REG AIR,"riously. regular, express dep" 1,2132,4633,4,28.0,28955.64,0.09,0.06,N,O,1996-04-21,1996-03-30,1996-05-16,NONE,AIR,lites. fluffily even de 1,24027,1534,5,24.0,22824.48,0.1,0.04,N,O,1996-03-30,1996-03-14,1996-04-01,NONE,FOB, pending foxes. slyly re @@ -7,4 +7,4 @@ l_orderkey,l_partkey,l_suppkey,l_linenumber,l_quantity,l_extendedprice,l_discoun 2,106170,1191,1,38.0,44694.46,0.0,0.05,N,O,1997-01-28,1997-01-14,1997-02-02,TAKE BACK RETURN,RAIL,ven requests. deposits breach a 3,4297,1798,1,45.0,54058.05,0.06,0.0,R,F,1994-02-02,1994-01-04,1994-02-23,NONE,AIR,ongside of the furiously brave acco 3,19036,6540,2,49.0,46796.47,0.1,0.0,R,F,1993-11-09,1993-12-20,1993-11-24,TAKE BACK RETURN,RAIL, unusual accounts. eve -3,128449,3474,3,27.0,39890.88,0.06,0.07,A,F,1994-01-16,1993-11-22,1994-01-23,DELIVER IN PERSON,SHIP,nal foxes wake. 
+3,128449,3474,3,27.0,39890.88,0.06,0.07,A,F,1994-01-16,1993-11-22,1994-01-23,DELIVER IN PERSON,SHIP,nal foxes wake. diff --git a/datafusion/core/tests/tpch-csv/nation.csv b/datafusion/core/tests/tpch-csv/nation.csv index e37130f4a57a..4b301059631e 100644 --- a/datafusion/core/tests/tpch-csv/nation.csv +++ b/datafusion/core/tests/tpch-csv/nation.csv @@ -8,4 +8,4 @@ n_nationkey,n_name,n_regionkey,n_comment 7,GERMANY,3,"l platelets. regular accounts x-ray: unusual, regular acco" 8,INDIA,2,ss excuses cajole slyly across the packages. deposits print aroun 9,INDONESIA,2, slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull -10,IRAN,4,efully alongside of the slyly final dependencies. +10,IRAN,4,efully alongside of the slyly final dependencies. diff --git a/datafusion/core/tests/tpch-csv/part.csv b/datafusion/core/tests/tpch-csv/part.csv new file mode 100644 index 000000000000..b505100ff160 --- /dev/null +++ b/datafusion/core/tests/tpch-csv/part.csv @@ -0,0 +1,2 @@ +p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,p_retailprice,p_comment +63700,goldenrod lavender spring chocolate lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly ironi diff --git a/datafusion/core/tests/tpch-csv/partsupp.csv b/datafusion/core/tests/tpch-csv/partsupp.csv new file mode 100644 index 000000000000..d7db83d03042 --- /dev/null +++ b/datafusion/core/tests/tpch-csv/partsupp.csv @@ -0,0 +1,2 @@ +ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment +67310,7311,100,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff diff --git a/datafusion/core/tests/tpch-csv/region.csv b/datafusion/core/tests/tpch-csv/region.csv new file mode 100644 index 000000000000..269c0915648b --- /dev/null +++ b/datafusion/core/tests/tpch-csv/region.csv @@ -0,0 +1,2 @@ +r_regionkey,r_name,r_comment +4,MIDDLE EAST,uickly special accounts cajole carefully blithely close requests. 
carefully final asymptotes haggle furiousl diff --git a/datafusion/core/tests/tpch-csv/supplier.csv b/datafusion/core/tests/tpch-csv/supplier.csv new file mode 100644 index 000000000000..85f9aaefbedb --- /dev/null +++ b/datafusion/core/tests/tpch-csv/supplier.csv @@ -0,0 +1,3 @@ +s_suppkey,s_name,s_address,s_nationkey,s_phone,s_acctbal,s_comment +1,Supplier#000000001," N kD4on9OM Ipw3,gf0JBoQDd7tgrzrddZ",17,27-918-335-1736,5755.94,each slyly above the careful +8136,Supplier#000008136,kXATyaEZOWdQC7fE43IquuR1HkKV8qx,20,30-268-895-2611,8383.6,er the carefully regular depths. pinto beans detect quickly p diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ad0b58fac582..ba6f7a96c29d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,7 +27,7 @@ use crate::AggregateUDF; use crate::Operator; use crate::ScalarUDF; use arrow::datatypes::DataType; -use datafusion_common::Column; +use datafusion_common::{plan_err, Column}; use datafusion_common::{DFSchema, Result}; use datafusion_common::{DataFusionError, ScalarValue}; use std::fmt; @@ -452,6 +452,13 @@ impl Expr { nulls_first, } } + + pub fn try_into_col(&self) -> Result { + match self { + Expr::Column(it) => Ok(it.clone()), + _ => plan_err!(format!("Could not coerce '{}' into Column!", self)), + } + } } impl Not for Expr { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 93c18f4b96e3..d42109788f9a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -20,7 +20,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::utils::exprlist_to_fields; use crate::{Expr, TableProviderFilterPushDown, TableSource}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::{Column, DFSchema, DFSchemaRef, DataFusionError}; +use datafusion_common::{plan_err, Column, DFSchema, DFSchemaRef, DataFusionError}; use std::collections::HashSet; 
///! Logical plan types use std::fmt::{self, Debug, Display, Formatter}; @@ -1074,6 +1074,13 @@ impl Projection { alias, }) } + + pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Projection> { + match plan { + LogicalPlan::Projection(it) => Ok(it), + _ => plan_err!("Could not coerce into Projection!"), + } + } } /// Aliased subquery @@ -1103,6 +1110,15 @@ pub struct Filter { pub input: Arc, } +impl Filter { + pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Filter> { + match plan { + LogicalPlan::Filter(it) => Ok(it), + _ => plan_err!("Could not coerce into Filter!"), + } + } +} + /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) #[derive(Clone)] pub struct Window { @@ -1287,6 +1303,15 @@ pub struct Aggregate { pub schema: DFSchemaRef, } +impl Aggregate { + pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Aggregate> { + match plan { + LogicalPlan::Aggregate(it) => Ok(it), + _ => plan_err!("Could not coerce into Aggregate!"), + } + } +} + /// Sorts its input according to a list of sort expressions. 
#[derive(Clone)] pub struct Sort { @@ -1324,6 +1349,15 @@ pub struct Subquery { pub subquery: Arc, } +impl Subquery { + pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> { + match plan { + Expr::ScalarSubquery(it) => Ok(it), + _ => plan_err!("Could not coerce into ScalarSubquery!"), + } + } +} + impl Debug for Subquery { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "") diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index ae493b2b01d3..24d2f1812d5b 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,3 +45,8 @@ datafusion-expr = { path = "../expr", version = "10.0.0" } datafusion-physical-expr = { path = "../physical-expr", version = "10.0.0" } hashbrown = { version = "0.12", features = ["raw"] } log = "^0.4" + +[dev-dependencies] +ctor = "0.1.22" +env_logger = "0.9.0" + diff --git a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs new file mode 100644 index 000000000000..d4f8372bd326 --- /dev/null +++ b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs @@ -0,0 +1,705 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::utils::{ + exprs_to_join_cols, find_join_exprs, only_or_err, split_conjunction, + verify_not_disjunction, +}; +use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_common::{context, plan_err, Column, Result}; +use datafusion_expr::logical_plan::{Aggregate, Filter, JoinType, Projection, Subquery}; +use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder, Operator}; +use log::debug; +use std::sync::Arc; + +/// Optimizer rule for rewriting subquery filters to joins +#[derive(Default)] +pub struct DecorrelateScalarSubquery {} + +impl DecorrelateScalarSubquery { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } + + /// Finds expressions that have a scalar subquery in them (and recurses when found) + /// + /// # Arguments + /// * `predicate` - A conjunction to split and search + /// * `optimizer_config` - For generating unique subquery aliases + /// + /// Returns a tuple (subqueries, non-subquery expressions) + fn extract_subquery_exprs( + &self, + predicate: &Expr, + optimizer_config: &mut OptimizerConfig, + ) -> Result<(Vec, Vec)> { + let mut filters = vec![]; + split_conjunction(predicate, &mut filters); // TODO: disjunctions + + let mut subqueries = vec![]; + let mut others = vec![]; + for it in filters.iter() { + match it { + Expr::BinaryExpr { left, op, right } => { + let l_query = Subquery::try_from_expr(left); + let r_query = Subquery::try_from_expr(right); + if l_query.is_err() && r_query.is_err() { + others.push((*it).clone()); + continue; + } + let mut recurse = + |q: Result<&Subquery>, expr: Expr, lhs: bool| -> Result<()> { + let subquery = match q { + Ok(subquery) => subquery, + _ => return Ok(()), + }; + let subquery = + self.optimize(&*subquery.subquery, optimizer_config)?; + let subquery = Arc::new(subquery); + let subquery = Subquery { subquery }; + let res = SubqueryInfo::new(subquery, expr, *op, lhs); + subqueries.push(res); + Ok(()) + }; + recurse(l_query, (**right).clone(), false)?; + 
recurse(r_query, (**left).clone(), true)?; + // TODO: if subquery doesn't get optimized, optimized children are lost + } + _ => others.push((*it).clone()), + } + } + + Ok((subqueries, others)) + } +} + +impl OptimizerRule for DecorrelateScalarSubquery { + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &mut OptimizerConfig, + ) -> Result { + match plan { + LogicalPlan::Filter(Filter { predicate, input }) => { + // Apply optimizer rule to current input + let optimized_input = self.optimize(input, optimizer_config)?; + + let (subqueries, other_exprs) = + self.extract_subquery_exprs(predicate, optimizer_config)?; + let optimized_plan = LogicalPlan::Filter(Filter { + predicate: predicate.clone(), + input: Arc::new(optimized_input), + }); + if subqueries.is_empty() { + // regular filter, no subquery exists clause here + return Ok(optimized_plan); + } + + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = (**input).clone(); + for subquery in subqueries { + cur_input = optimize_scalar( + &subquery, + &cur_input, + &other_exprs, + optimizer_config, + )?; + } + Ok(cur_input) + } + _ => { + // Apply the optimization to all inputs of the plan + utils::optimize_children(self, plan, optimizer_config) + } + } + } + + fn name(&self) -> &str { + "decorrelate_scalar_subquery" + } +} + +/// Takes a query like: +/// +/// ```select id from customers where balance > +/// (select avg(total) from orders where orders.c_id = customers.id) +/// ``` +/// +/// and optimizes it into: +/// +/// ```select c.id from customers c +/// inner join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id +/// where c.balance > o.val``` +/// +/// # Arguments +/// +/// * `subqry` - The subquery portion of the `where exists` (select * from orders) +/// * `negated` - True if the subquery is a `where not exists` +/// * `filter_input` - The non-subquery portion (from customers) +/// * `other_filter_exprs` - Any additional 
parts to the `where` expression (and c.x = y) +/// * `optimizer_config` - Used to generate unique subquery aliases +fn optimize_scalar( + query_info: &SubqueryInfo, + filter_input: &LogicalPlan, + outer_others: &[Expr], + optimizer_config: &mut OptimizerConfig, +) -> Result { + debug!( + "optimizing:\n{}", + query_info.query.subquery.display_indent() + ); + let proj = Projection::try_from_plan(&*query_info.query.subquery) + .map_err(|e| context!("scalar subqueries must have a projection", e))?; + let proj = only_or_err(proj.expr.as_slice()) + .map_err(|e| context!("exactly one expression should be projected", e))?; + let proj = Expr::Alias(Box::new(proj.clone()), "__value".to_string()); + let sub_inputs = query_info.query.subquery.inputs(); + let sub_input = only_or_err(sub_inputs.as_slice()) + .map_err(|e| context!("Exactly one input is expected. Is this a join?", e))?; + let aggr = Aggregate::try_from_plan(sub_input) + .map_err(|e| context!("scalar subqueries must aggregate a value", e))?; + let filter = Filter::try_from_plan(&*aggr.input).map_err(|e| { + context!("scalar subqueries must have a filter to be correlated", e) + })?; + + // split into filters + let mut subqry_filter_exprs = vec![]; + split_conjunction(&filter.predicate, &mut subqry_filter_exprs); + verify_not_disjunction(&subqry_filter_exprs)?; + + // Grab column names to join on + let (col_exprs, other_subqry_exprs) = + find_join_exprs(subqry_filter_exprs, filter.input.schema())?; + let (outer_cols, subqry_cols, join_filters) = + exprs_to_join_cols(&col_exprs, filter.input.schema(), false)?; + if join_filters.is_some() { + plan_err!("only joins on column equality are presently supported")?; + } + + // Only operate if one column is present and the other closed upon from outside scope + let subqry_alias = format!("__sq_{}", optimizer_config.next_id()); + let group_by: Vec<_> = subqry_cols + .iter() + .map(|it| Expr::Column(it.clone())) + .collect(); + + // build subquery side of join - the thing the 
subquery was querying + let mut subqry_plan = LogicalPlanBuilder::from((*filter.input).clone()); + if let Some(expr) = combine_filters(&other_subqry_exprs) { + subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them + } + + // project the prior projection + any correlated (and now grouped) columns + let proj: Vec<_> = group_by + .iter() + .cloned() + .chain(vec![proj].iter().cloned()) + .collect(); + let subqry_plan = subqry_plan + .aggregate(group_by, aggr.aggr_expr.clone())? + .project_with_alias(proj, Some(subqry_alias.clone()))? + .build()?; + + // qualify the join columns for outside the subquery + let subqry_cols: Vec<_> = subqry_cols + .iter() + .map(|it| Column { + relation: Some(subqry_alias.clone()), + name: it.name.clone(), + }) + .collect(); + let join_keys = (outer_cols, subqry_cols); + + // join our sub query into the main plan + let new_plan = LogicalPlanBuilder::from(filter_input.clone()); + let mut new_plan = if join_keys.0.is_empty() { + // if not correlated, group down to 1 row and cross join on that (preserving row count) + new_plan.cross_join(&subqry_plan)? + } else { + // inner join if correlated, grouping by the join keys so we don't change row count + new_plan.join(&subqry_plan, JoinType::Inner, join_keys, None)? + }; + + // restore where in condition + let qry_expr = Box::new(Expr::Column(Column { + relation: Some(subqry_alias), + name: "__value".to_string(), + })); + let filter_expr = if query_info.expr_on_left { + Expr::BinaryExpr { + left: Box::new(query_info.expr.clone()), + op: query_info.op, + right: qry_expr, + } + } else { + Expr::BinaryExpr { + left: qry_expr, + op: query_info.op, + right: Box::new(query_info.expr.clone()), + } + }; + new_plan = new_plan.filter(filter_expr)?; + + // if the main query had additional expressions, restore them + if let Some(expr) = combine_filters(outer_others) { + new_plan = new_plan.filter(expr)? 
+ } + let new_plan = new_plan.build()?; + + Ok(new_plan) +} + +struct SubqueryInfo { + query: Subquery, + expr: Expr, + op: Operator, + expr_on_left: bool, +} + +impl SubqueryInfo { + pub fn new(query: Subquery, expr: Expr, op: Operator, expr_on_left: bool) -> Self { + Self { + query, + expr, + op, + expr_on_left, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use datafusion_common::Result; + use datafusion_expr::{ + col, lit, logical_plan::LogicalPlanBuilder, max, min, scalar_subquery, sum, + }; + use std::ops::Add; + + #[cfg(test)] + #[ctor::ctor] + fn init() { + let _ = env_logger::try_init(); + } + + /// Test multiple correlated subqueries + #[test] + fn multiple_subqueries() -> Result<()> { + let orders = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("orders.o_custkey").eq(col("customer.c_custkey")))? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + lit(1) + .lt(scalar_subquery(orders.clone())) + .and(lit(1).lt(scalar_subquery(orders))), + )? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: Int32(1) < #__sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N, o_custkey:Int64, __value:Int64;N] + Inner Join: #customer.c_custkey = #__sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N, o_custkey:Int64, __value:Int64;N] + Filter: Int32(1) < #__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N] + Inner Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey, #MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N] + Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[MAX(#orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Projection: #orders.o_custkey, #MAX(orders.o_custkey) AS __value, alias=__sq_2 [o_custkey:Int64, __value:Int64;N] + Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[MAX(#orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test recursive correlated subqueries + #[test] + fn recursive_subqueries() -> Result<()> { + let lineitem = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("lineitem")) + .filter(col("lineitem.l_orderkey").eq(col("orders.o_orderkey")))? + .aggregate( + Vec::::new(), + vec![sum(col("lineitem.l_extendedprice"))], + )? + .project(vec![sum(col("lineitem.l_extendedprice"))])? 
+ .build()?, + ); + + let orders = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("orders.o_custkey") + .eq(col("customer.c_custkey")) + .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))), + )? + .aggregate(Vec::::new(), vec![sum(col("orders.o_totalprice"))])? + .project(vec![sum(col("orders.o_totalprice"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_acctbal < #__sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N] + Inner Join: #customer.c_custkey = #__sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, alias=__sq_2 [o_custkey:Int64, __value:Float64;N] + Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[SUM(#orders.o_totalprice)]] [o_custkey:Int64, SUM(orders.o_totalprice):Float64;N] + Filter: #orders.o_totalprice < #__sq_1.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N] + Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Projection: #lineitem.l_orderkey, #SUM(lineitem.l_extendedprice) AS __value, alias=__sq_1 [l_orderkey:Int64, __value:Float64;N] + Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[SUM(#lineitem.l_extendedprice)]] [l_orderkey:Int64, SUM(lineitem.l_extendedprice):Float64;N] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, 
l_quantity:Float64, l_extendedprice:Float64]"#; + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery filter with additional subquery filters + #[test] + fn scalar_subquery_with_subquery_filters() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("customer.c_custkey") + .eq(col("orders.o_custkey")) + .and(col("o_orderkey").eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N] + Inner Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey, #MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N] + Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[MAX(#orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N] + Filter: #orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery with no columns in schema + #[test] + fn scalar_subquery_no_cols() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("customer.c_custkey")))? 
+ .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + // it will optimize, but fail for the same reason the unoptimized query would + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N] + CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N] + Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N] + Filter: #customer.c_custkey = #customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for scalar subquery with both columns in schema + #[test] + fn scalar_subquery_with_no_correlated_cols() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N] + CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N] + Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N] + Filter: #orders.o_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery not equal + #[test] + fn scalar_subquery_where_not_eq() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").not_eq(col("orders.o_custkey")))? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"only joins on column equality are presently supported"#; + + assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery less than + #[test] + fn scalar_subquery_where_less_than() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").lt(col("orders.o_custkey")))? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? 
+ .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"can't optimize < column comparison"#; + assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery filter with subquery disjunction + #[test] + fn scalar_subquery_with_subquery_disjunction() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("customer.c_custkey") + .eq(col("orders.o_custkey")) + .or(col("o_orderkey").eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Optimizing disjunctions not supported!"#; + assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar without projection + #[test] + fn scalar_subquery_no_projection() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"scalar subqueries must have a projection"#; + assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar expressions + #[test] + #[ignore] + fn scalar_subquery_project_expr() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey")).add(lit(1))])? + .build()?, + ); + /* + Error: SchemaError(FieldNotFound { qualifier: Some("orders"), name: "o_custkey", valid_fields: Some(["MAX(orders.o_custkey)"]) }) + */ + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#""#; + + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery multiple projected columns + #[test] + fn scalar_subquery_multi_col() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + col("customer.c_custkey") + .eq(scalar_subquery(sq)) + .and(col("c_custkey").eq(lit(1))), + )? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"exactly one expression should be projected"#; + assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery filter with additional filters + #[test] + fn scalar_subquery_additional_filters() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + col("customer.c_custkey") + .eq(scalar_subquery(sq)) + .and(col("c_custkey").eq(lit(1))), + )? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N] + Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N] + Inner Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey, #MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N] + Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[MAX(#orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery filter with disjustions + #[test] + fn scalar_subquery_disjunction() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? 
+ .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + col("customer.c_custkey") + .eq(scalar_subquery(sq)) + .or(col("customer.c_custkey").eq(lit(1))), + )? + .project(vec![col("customer.c_custkey")])? + .build()?; + + // unoptimized plan because we don't support disjunctions yet + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = () OR #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + Subquery: [MAX(orders.o_custkey):Int64;N] + Projection: #MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N] + Filter: #customer.c_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated scalar subquery filter + #[test] + fn exists_subquery_correlated() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(test_table_scan_with_name("sq")?) + .filter(col("test.a").eq(col("sq.a")))? + .aggregate(Vec::::new(), vec![min(col("c"))])? + .project(vec![min(col("c"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?) + .filter(col("test.c").lt(scalar_subquery(sq)))? + .project(vec![col("test.c")])? 
+ .build()?; + + let expected = r#"Projection: #test.c [c:UInt32] + Filter: #test.c < #__sq_1.__value [a:UInt32, b:UInt32, c:UInt32, a:UInt32, __value:UInt32;N] + Inner Join: #test.a = #__sq_1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, __value:UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: #sq.a, #MIN(sq.c) AS __value, alias=__sq_1 [a:UInt32, __value:UInt32;N] + Aggregate: groupBy=[[#sq.a]], aggr=[[MIN(#sq.c)]] [a:UInt32, MIN(sq.c):UInt32;N] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#; + + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } +} diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs new file mode 100644 index 000000000000..2c25bcbb28e7 --- /dev/null +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -0,0 +1,557 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::utils::{ + exprs_to_join_cols, find_join_exprs, only_or_err, split_conjunction, + verify_not_disjunction, +}; +use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_common::{context, plan_err}; +use datafusion_expr::logical_plan::{Filter, JoinType, Subquery}; +use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder}; +use std::sync::Arc; + +/// Optimizer rule for rewriting subquery filters to joins +#[derive(Default)] +pub struct DecorrelateWhereExists {} + +impl DecorrelateWhereExists { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } + + /// Finds expressions that have a where in subquery (and recurses when found) + /// + /// # Arguments + /// + /// * `predicate` - A conjunction to split and search + /// * `optimizer_config` - For generating unique subquery aliases + /// + /// Returns a tuple (subqueries, non-subquery expressions) + fn extract_subquery_exprs( + &self, + predicate: &Expr, + optimizer_config: &mut OptimizerConfig, + ) -> datafusion_common::Result<(Vec, Vec)> { + let mut filters = vec![]; + split_conjunction(predicate, &mut filters); + + let mut subqueries = vec![]; + let mut others = vec![]; + for it in filters.iter() { + match it { + Expr::Exists { subquery, negated } => { + let subquery = + self.optimize(&*subquery.subquery, optimizer_config)?; + let subquery = Arc::new(subquery); + let subquery = Subquery { subquery }; + let subquery = SubqueryInfo::new(subquery.clone(), *negated); + subqueries.push(subquery); + } + _ => others.push((*it).clone()), + } + } + + Ok((subqueries, others)) + } +} + +impl OptimizerRule for DecorrelateWhereExists { + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &mut OptimizerConfig, + ) -> datafusion_common::Result { + match plan { + LogicalPlan::Filter(Filter { + predicate, + input: filter_input, + }) => { + // Apply optimizer rule to current input + let optimized_input = self.optimize(filter_input, optimizer_config)?; + + let 
(subqueries, other_exprs) = + self.extract_subquery_exprs(predicate, optimizer_config)?; + let optimized_plan = LogicalPlan::Filter(Filter { + predicate: predicate.clone(), + input: Arc::new(optimized_input), + }); + if subqueries.is_empty() { + // regular filter, no subquery exists clause here + return Ok(optimized_plan); + } + + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = (**filter_input).clone(); + for subquery in subqueries { + cur_input = optimize_exists(&subquery, &cur_input, &other_exprs)?; + } + Ok(cur_input) + } + _ => { + // Apply the optimization to all inputs of the plan + utils::optimize_children(self, plan, optimizer_config) + } + } + } + + fn name(&self) -> &str { + "decorrelate_where_exists" + } +} + +/// Takes a query like: +/// +/// ```select c.id from customers c where exists (select * from orders o where o.c_id = c.id)``` +/// +/// and optimizes it into: +/// +/// ```select c.id from customers c +/// inner join (select o.c_id from orders o group by o.c_id) o on o.c_id = c.c_id``` +/// +/// # Arguments +/// +/// * subqry - The subquery portion of the `where exists` (select * from orders) +/// * negated - True if the subquery is a `where not exists` +/// * filter_input - The non-subquery portion (from customers) +/// * outer_exprs - Any additional parts to the `where` expression (and c.x = y) +fn optimize_exists( + query_info: &SubqueryInfo, + outer_input: &LogicalPlan, + outer_other_exprs: &[Expr], +) -> datafusion_common::Result { + let subqry_inputs = query_info.query.subquery.inputs(); + let subqry_input = only_or_err(subqry_inputs.as_slice()) + .map_err(|e| context!("single expression projection required", e))?; + let subqry_filter = Filter::try_from_plan(subqry_input) + .map_err(|e| context!("cannot optimize non-correlated subquery", e))?; + + // split into filters + let mut subqry_filter_exprs = vec![]; + split_conjunction(&subqry_filter.predicate, &mut subqry_filter_exprs); + 
verify_not_disjunction(&subqry_filter_exprs)?; + + // Grab column names to join on + let (col_exprs, other_subqry_exprs) = + find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())?; + let (outer_cols, subqry_cols, join_filters) = + exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)?; + if subqry_cols.is_empty() || outer_cols.is_empty() { + plan_err!("cannot optimize non-correlated subquery")?; + } + + // build subquery side of join - the thing the subquery was querying + let mut subqry_plan = LogicalPlanBuilder::from((*subqry_filter.input).clone()); + if let Some(expr) = combine_filters(&other_subqry_exprs) { + subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them + } + let subqry_plan = subqry_plan.build()?; + + let join_keys = (subqry_cols, outer_cols); + + // join our sub query into the main plan + let join_type = match query_info.negated { + true => JoinType::Anti, + false => JoinType::Semi, + }; + let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join( + &subqry_plan, + join_type, + join_keys, + join_filters, + )?; + if let Some(expr) = combine_filters(outer_other_exprs) { + new_plan = new_plan.filter(expr)? 
// if the main query had additional expressions, restore them + } + + let result = new_plan.build()?; + Ok(result) +} + +struct SubqueryInfo { + query: Subquery, + negated: bool, +} + +impl SubqueryInfo { + pub fn new(query: Subquery, negated: bool) -> Self { + Self { query, negated } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use datafusion_common::Result; + use datafusion_expr::{ + col, exists, lit, logical_plan::LogicalPlanBuilder, not_exists, + }; + use std::ops::Add; + + /// Test for multiple exists subqueries in the same filter expression + #[test] + fn multiple_subqueries() -> Result<()> { + let orders = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("orders.o_custkey").eq(col("customer.c_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(orders.clone()).and(exists(orders)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8] + Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test recursive correlated subqueries + #[test] + fn recursive_subqueries() -> Result<()> { + let lineitem = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("lineitem")) + .filter(col("lineitem.l_orderkey").eq(col("orders.o_orderkey")))? + .project(vec![col("lineitem.l_orderkey")])? 
+ .build()?, + ); + + let orders = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + exists(lineitem) + .and(col("orders.o_custkey").eq(col("customer.c_custkey"))), + )? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(orders))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Semi Join: #orders.o_orderkey = #lineitem.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#; + + assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists subquery filter with additional subquery filters + #[test] + fn exists_subquery_with_subquery_filters() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("customer.c_custkey") + .eq(col("orders.o_custkey")) + .and(col("o_orderkey").eq(lit(1))), + )? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Filter: #orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists subquery with no columns in schema + #[test] + fn exists_subquery_no_cols() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("customer.c_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"cannot optimize non-correlated subquery"#; + + assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for exists subquery with both columns in schema + #[test] + fn exists_subquery_with_no_correlated_cols() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"cannot optimize non-correlated subquery"#; + + assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists subquery not equal + #[test] + fn exists_subquery_where_not_eq() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").not_eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"cannot optimize non-correlated subquery"#; + + assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists subquery less than + #[test] + fn exists_subquery_where_less_than() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").lt(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"can't optimize < column comparison"#; + + assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists subquery filter with subquery disjunction + #[test] + fn exists_subquery_with_subquery_disjunction() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("customer.c_custkey") + .eq(col("orders.o_custkey")) + .or(col("o_orderkey").eq(lit(1))), + )? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"Optimizing disjunctions not supported!"#; + + assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists without projection + #[test] + fn exists_subquery_no_projection() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"cannot optimize non-correlated subquery"#; + + assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists expressions + #[test] + fn exists_subquery_project_expr() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey").add(lit(1))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + // Doesn't matter we projected an expression, just that we returned a result + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists subquery filter with additional filters + #[test] + fn should_support_additional_filters() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? 
+ .project(vec![col("orders.o_custkey")])? + .build()?, + ); + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq).and(col("c_custkey").eq(lit(1))))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated exists subquery filter with disjustions + #[test] + fn exists_subquery_disjunction() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(exists(sq).or(col("customer.c_custkey").eq(lit(1))))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + // not optimized + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: EXISTS () OR #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + Subquery: [o_custkey:Int64] + Projection: #orders.o_custkey [o_custkey:Int64] + Filter: #customer.c_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; + + assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated EXISTS subquery filter + #[test] + fn exists_subquery_correlated() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(test_table_scan_with_name("sq")?) + .filter(col("test.a").eq(col("sq.a")))? + .project(vec![col("c")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?) + .filter(exists(sq))? + .project(vec![col("test.c")])? + .build()?; + + let expected = r#"Projection: #test.c [c:UInt32] + Semi Join: #test.a = #sq.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#; + + assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for single exists subquery filter + #[test] + fn exists_subquery_simple() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(test_subquery_with_name("sq")?))? + .project(vec![col("test.b")])? 
+ .build()?; + + let expected = "cannot optimize non-correlated subquery"; + + assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } + + /// Test for single NOT exists subquery filter + #[test] + fn not_exists_subquery_simple() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(not_exists(test_subquery_with_name("sq")?))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "cannot optimize non-correlated subquery"; + + assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + Ok(()) + } +} diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs new file mode 100644 index 000000000000..f90d94d8c16f --- /dev/null +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -0,0 +1,693 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::utils::{ + alias_cols, exprs_to_join_cols, find_join_exprs, merge_cols, only_or_err, + split_conjunction, swap_table, verify_not_disjunction, +}; +use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_common::context; +use datafusion_expr::logical_plan::{Filter, JoinType, Projection, Subquery}; +use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder}; +use log::debug; +use std::sync::Arc; + +#[derive(Default)] +pub struct DecorrelateWhereIn {} + +impl DecorrelateWhereIn { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } + + /// Finds expressions that have a where in subquery (and recurses when found) + /// + /// # Arguments + /// + /// * `predicate` - A conjunction to split and search + /// * `optimizer_config` - For generating unique subquery aliases + /// + /// Returns a tuple (subqueries, non-subquery expressions) + fn extract_subquery_exprs( + &self, + predicate: &Expr, + optimizer_config: &mut OptimizerConfig, + ) -> datafusion_common::Result<(Vec, Vec)> { + let mut filters = vec![]; + split_conjunction(predicate, &mut filters); // TODO: disjunctions + + let mut subqueries = vec![]; + let mut others = vec![]; + for it in filters.iter() { + match it { + Expr::InSubquery { + expr, + subquery, + negated, + } => { + let subquery = + self.optimize(&*subquery.subquery, optimizer_config)?; + let subquery = Arc::new(subquery); + let subquery = Subquery { subquery }; + let subquery = + SubqueryInfo::new(subquery.clone(), (**expr).clone(), *negated); + subqueries.push(subquery); + // TODO: if subquery doesn't get optimized, optimized children are lost + } + _ => others.push((*it).clone()), + } + } + + Ok((subqueries, others)) + } +} + +impl OptimizerRule for DecorrelateWhereIn { + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &mut OptimizerConfig, + ) -> datafusion_common::Result { + match plan { + LogicalPlan::Filter(Filter { + predicate, + input: filter_input, + }) => { + // Apply 
optimizer rule to current input + let optimized_input = self.optimize(filter_input, optimizer_config)?; + + let (subqueries, other_exprs) = + self.extract_subquery_exprs(predicate, optimizer_config)?; + let optimized_plan = LogicalPlan::Filter(Filter { + predicate: predicate.clone(), + input: Arc::new(optimized_input), + }); + if subqueries.is_empty() { + // regular filter, no subquery exists clause here + return Ok(optimized_plan); + } + + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = (**filter_input).clone(); + for subquery in subqueries { + cur_input = optimize_where_in( + &subquery, + &cur_input, + &other_exprs, + optimizer_config, + )?; + } + Ok(cur_input) + } + _ => { + // Apply the optimization to all inputs of the plan + utils::optimize_children(self, plan, optimizer_config) + } + } + } + + fn name(&self) -> &str { + "decorrelate_where_in" + } +} + +fn optimize_where_in( + query_info: &SubqueryInfo, + outer_input: &LogicalPlan, + outer_other_exprs: &[Expr], + optimizer_config: &mut OptimizerConfig, +) -> datafusion_common::Result { + let proj = Projection::try_from_plan(&*query_info.query.subquery) + .map_err(|e| context!("a projection is required", e))?; + let mut subqry_input = proj.input.clone(); + let proj = only_or_err(proj.expr.as_slice()) + .map_err(|e| context!("single expression projection required", e))?; + let subquery_col = proj + .try_into_col() + .map_err(|e| context!("single column projection required", e))?; + let outer_col = query_info + .where_in_expr + .try_into_col() + .map_err(|e| context!("column comparison required", e))?; + + // If subquery is correlated, grab necessary information + let mut subqry_cols = vec![]; + let mut outer_cols = vec![]; + let mut join_filters = None; + let mut other_subqry_exprs = vec![]; + if let LogicalPlan::Filter(subqry_filter) = (*subqry_input).clone() { + // split into filters + let mut subqry_filter_exprs = vec![]; + 
split_conjunction(&subqry_filter.predicate, &mut subqry_filter_exprs); + verify_not_disjunction(&subqry_filter_exprs)?; + + // Grab column names to join on + let (col_exprs, other_exprs) = + find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema()) + .map_err(|e| context!("column correlation not found", e))?; + if !col_exprs.is_empty() { + // it's correlated + subqry_input = subqry_filter.input.clone(); + (outer_cols, subqry_cols, join_filters) = + exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false) + .map_err(|e| context!("column correlation not found", e))?; + other_subqry_exprs = other_exprs; + } + } + + let (subqry_cols, outer_cols) = + merge_cols((&[subquery_col], &subqry_cols), (&[outer_col], &outer_cols)); + + // build subquery side of join - the thing the subquery was querying + let subqry_alias = format!("__sq_{}", optimizer_config.next_id()); + let mut subqry_plan = LogicalPlanBuilder::from((*subqry_input).clone()); + if let Some(expr) = combine_filters(&other_subqry_exprs) { + // if the subquery had additional expressions, restore them + subqry_plan = subqry_plan.filter(expr)? + } + let projection = alias_cols(&subqry_cols); + let subqry_plan = subqry_plan + .project_with_alias(projection, Some(subqry_alias.clone()))? + .build()?; + debug!("subquery plan:\n{}", subqry_plan.display_indent()); + + // qualify the join columns for outside the subquery + let subqry_cols = swap_table(&subqry_alias, &subqry_cols); + let join_keys = (outer_cols, subqry_cols); + + // join our sub query into the main plan + let join_type = match query_info.negated { + true => JoinType::Anti, + false => JoinType::Semi, + }; + let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join( + &subqry_plan, + join_type, + join_keys, + join_filters, + )?; + if let Some(expr) = combine_filters(outer_other_exprs) { + new_plan = new_plan.filter(expr)? 
// if the main query had additional expressions, restore them + } + let new_plan = new_plan.build()?; + + debug!("where in optimized:\n{}", new_plan.display_indent()); + Ok(new_plan) +} + +struct SubqueryInfo { + query: Subquery, + where_in_expr: Expr, + negated: bool, +} + +impl SubqueryInfo { + pub fn new(query: Subquery, expr: Expr, negated: bool) -> Self { + Self { + query, + where_in_expr: expr, + negated, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use datafusion_common::Result; + use datafusion_expr::{ + col, in_subquery, lit, logical_plan::LogicalPlanBuilder, not_in_subquery, + }; + use std::ops::Add; + + #[cfg(test)] + #[ctor::ctor] + fn init() { + let _ = env_logger::try_init(); + } + + /// Test multiple correlated subqueries + /// See subqueries.rs where_in_multiple() + #[test] + fn multiple_subqueries() -> Result<()> { + let orders = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("orders.o_custkey").eq(col("customer.c_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + in_subquery(col("customer.c_custkey"), orders.clone()) + .and(in_subquery(col("customer.c_custkey"), orders)), + )? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + debug!("plan to optimize:\n{}", plan.display_indent()); + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #__sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Projection: #orders.o_custkey AS o_custkey, alias=__sq_2 [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test recursive correlated subqueries + /// See subqueries.rs where_in_recursive() + #[test] + fn recursive_subqueries() -> Result<()> { + let lineitem = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("lineitem")) + .filter(col("lineitem.l_orderkey").eq(col("orders.o_orderkey")))? + .project(vec![col("lineitem.l_orderkey")])? + .build()?, + ); + + let orders = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + in_subquery(col("orders.o_orderkey"), lineitem) + .and(col("orders.o_custkey").eq(col("customer.c_custkey"))), + )? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), orders))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #__sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey AS o_custkey, alias=__sq_2 [o_custkey:Int64] + Semi Join: #orders.o_orderkey = #__sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Projection: #lineitem.l_orderkey AS l_orderkey, alias=__sq_1 [l_orderkey:Int64] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated IN subquery filter with additional subquery filters + #[test] + fn in_subquery_with_subquery_filters() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("customer.c_custkey") + .eq(col("orders.o_custkey")) + .and(col("o_orderkey").eq(lit(1))), + )? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), sq))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64] + Filter: #orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated IN subquery with no columns in schema + #[test] + fn in_subquery_no_cols() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("customer.c_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), sq))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + // Query will fail, but we can still transform the plan + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64] + Filter: #customer.c_custkey = #customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for IN subquery with both columns in schema + #[test] + fn in_subquery_with_no_correlated_cols() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), sq))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64] + Filter: #orders.o_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated IN subquery not equal + #[test] + fn in_subquery_where_not_eq() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").not_eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Semi Join: #customer.c_custkey = #__sq_1.o_custkey Filter: #customer.c_custkey != #orders.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated IN subquery less than + #[test] + fn in_subquery_where_less_than() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").lt(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? 
+ .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + // can't optimize on arbitrary expressions (yet) + assert_optimizer_err( + &DecorrelateWhereIn::new(), + &plan, + "column correlation not found", + ); + Ok(()) + } + + /// Test for correlated IN subquery filter with subquery disjunction + #[test] + fn in_subquery_with_subquery_disjunction() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("customer.c_custkey") + .eq(col("orders.o_custkey")) + .or(col("o_orderkey").eq(lit(1))), + )? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + assert_optimizer_err( + &DecorrelateWhereIn::new(), + &plan, + "Optimizing disjunctions not supported!", + ); + Ok(()) + } + + /// Test for correlated IN without projection + #[test] + fn in_subquery_no_projection() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + // Maybe okay if the table only has a single column? + assert_optimizer_err( + &DecorrelateWhereIn::new(), + &plan, + "a projection is required", + ); + Ok(()) + } + + /// Test for correlated IN subquery join on expression + #[test] + fn in_subquery_join_expr() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? 
+ .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey").add(lit(1)), sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + // TODO: support join on expression + assert_optimizer_err( + &DecorrelateWhereIn::new(), + &plan, + "column comparison required", + ); + Ok(()) + } + + /// Test for correlated IN expressions + #[test] + fn in_subquery_project_expr() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey").add(lit(1))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(in_subquery(col("customer.c_custkey"), sq))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + // TODO: support join on expressions? + assert_optimizer_err( + &DecorrelateWhereIn::new(), + &plan, + "single column projection required", + ); + Ok(()) + } + + /// Test for correlated IN subquery multiple projected columns + #[test] + fn in_subquery_multi_col() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + in_subquery(col("customer.c_custkey"), sq) + .and(col("c_custkey").eq(lit(1))), + )? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + assert_optimizer_err( + &DecorrelateWhereIn::new(), + &plan, + "single expression projection required", + ); + Ok(()) + } + + /// Test for correlated IN subquery filter with additional filters + #[test] + fn should_support_additional_filters() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + in_subquery(col("customer.c_custkey"), sq) + .and(col("c_custkey").eq(lit(1))), + )? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated IN subquery filter with disjunctions + #[test] + fn in_subquery_disjunction() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + in_subquery(col("customer.c_custkey"), sq) + .or(col("customer.c_custkey").eq(lit(1))), + )? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + // TODO: support disjunction - for now expect unaltered plan + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey IN () OR #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + Subquery: [o_custkey:Int64] + Projection: #orders.o_custkey [o_custkey:Int64] + Filter: #customer.c_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for correlated IN subquery filter + #[test] + fn in_subquery_correlated() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(test_table_scan_with_name("sq")?) + .filter(col("test.a").eq(col("sq.a")))? + .project(vec![col("c")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?) + .filter(in_subquery(col("c"), sq))? + .project(vec![col("test.b")])? + .build()?; + + let expected = r#"Projection: #test.b [b:UInt32] + Semi Join: #test.c = #__sq_1.c, #test.a = #__sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: #sq.c AS c, #sq.a AS a, alias=__sq_1 [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for single IN subquery filter + #[test] + fn in_subquery_simple() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("c"), test_subquery_with_name("sq")?))? + .project(vec![col("test.b")])? 
+ .build()?; + + let expected = r#"Projection: #test.b [b:UInt32] + Semi Join: #test.c = #__sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: #sq.c AS c, alias=__sq_1 [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } + + /// Test for single NOT IN subquery filter + #[test] + fn not_in_subquery_simple() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(not_in_subquery(col("c"), test_subquery_with_name("sq")?))? + .project(vec![col("test.b")])? + .build()?; + + let expected = r#"Projection: #test.b [b:UInt32] + Anti Join: #test.c = #__sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: #sq.c AS c, alias=__sq_1 [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#; + + assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + Ok(()) + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index a6b7cfcbb8fb..588903ad08e2 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -16,6 +16,9 @@ // under the License. pub mod common_subexpr_eliminate; +pub mod decorrelate_scalar_subquery; +pub mod decorrelate_where_exists; +pub mod decorrelate_where_in; pub mod eliminate_filter; pub mod eliminate_limit; pub mod expr_simplifier; diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 86e12bc30c09..fc7d0bd8a1d5 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. 
+use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; -use datafusion_expr::{logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; +use std::sync::Arc; pub mod user_defined; @@ -54,3 +56,76 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { .collect(); assert_eq!(actual, expected); } + +pub fn test_subquery_with_name(name: &str) -> Result> { + let table_scan = test_table_scan_with_name(name)?; + Ok(Arc::new( + LogicalPlanBuilder::from(table_scan) + .project(vec![col("c")])? + .build()?, + )) +} + +pub fn scan_tpch_table(table: &str) -> LogicalPlan { + let schema = Arc::new(get_tpch_table_schema(table)); + table_scan(Some(table), &schema, None) + .unwrap() + .build() + .unwrap() +} + +pub fn get_tpch_table_schema(table: &str) -> Schema { + match table { + "customer" => Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, false), + Field::new("c_name", DataType::Utf8, false), + ]), + + "orders" => Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Float64, true), + ]), + + "lineitem" => Schema::new(vec![ + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_quantity", DataType::Float64, false), + Field::new("l_extendedprice", DataType::Float64, false), + ]), + + _ => unimplemented!("Table: {}", table), + } +} + +pub fn assert_optimized_plan_eq( + rule: &dyn OptimizerRule, + plan: &LogicalPlan, + expected: &str, +) { + let optimized_plan = rule + .optimize(plan, &mut OptimizerConfig::new()) + .expect("failed to optimize plan"); + let formatted_plan = 
format!("{}", optimized_plan.display_indent_schema()); + assert_eq!(formatted_plan, expected); +} + +pub fn assert_optimizer_err( + rule: &dyn OptimizerRule, + plan: &LogicalPlan, + expected: &str, +) { + let res = rule.optimize(plan, &mut OptimizerConfig::new()); + match res { + Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), + Err(ref e) => { + let actual = format!("{}", e); + if expected.is_empty() || !actual.contains(expected) { + assert_eq!(actual, expected) + } + } + } +} diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index cd70c50913a4..41c75d689f5d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -19,12 +19,15 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; +use datafusion_common::{plan_err, Column, DFSchemaRef}; +use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use datafusion_expr::{ - and, + and, col, combine_filters, logical_plan::{Filter, LogicalPlan}, utils::from_plan, Expr, Operator, }; +use std::collections::HashSet; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -65,6 +68,40 @@ pub fn split_conjunction<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr> } } +/// Recursively scans a slice of expressions for any `Or` operators +/// +/// # Arguments +/// +/// * `predicates` - the expressions to scan +/// +/// # Return value +/// +/// A PlanError if a disjunction is found +pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> { + struct DisjunctionVisitor {} + + impl ExpressionVisitor for DisjunctionVisitor { + fn pre_visit(self, expr: &Expr) -> Result> { + match expr { + Expr::BinaryExpr { + left: _, + op: Operator::Or, + right: _, + } => { + plan_err!("Optimizing disjunctions not supported!") + } + _ => Ok(Recursion::Continue(self)), + } + } + } + + for predicate in predicates.iter() { + predicate.accept(DisjunctionVisitor {})?; + } 
+ + Ok(()) +} + /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with /// its predicate be all `predicates` ANDed. pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { @@ -82,6 +119,202 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { }) } +/// Looks for correlating expressions: equality expressions with one field from the subquery, and +/// one not in the subquery (closed upon from outer scope) +/// +/// # Arguments +/// +/// * `exprs` - List of expressions that may or may not be joins +/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema +/// +/// # Return value +/// +/// Tuple of (expressions containing joins, remaining non-join expressions) +pub fn find_join_exprs( + exprs: Vec<&Expr>, + schema: &DFSchemaRef, +) -> Result<(Vec, Vec)> { + let fields: HashSet<_> = schema + .fields() + .iter() + .map(|it| it.qualified_name()) + .collect(); + + let mut joins = vec![]; + let mut others = vec![]; + for filter in exprs.iter() { + let (left, op, right) = match filter { + Expr::BinaryExpr { left, op, right } => (*left.clone(), *op, *right.clone()), + _ => { + others.push((*filter).clone()); + continue; + } + }; + let left = match left { + Expr::Column(c) => c, + _ => { + others.push((*filter).clone()); + continue; + } + }; + let right = match right { + Expr::Column(c) => c, + _ => { + others.push((*filter).clone()); + continue; + } + }; + if fields.contains(&left.flat_name()) && fields.contains(&right.flat_name()) { + others.push((*filter).clone()); + continue; // both columns present (none closed-upon) + } + if !fields.contains(&left.flat_name()) && !fields.contains(&right.flat_name()) { + others.push((*filter).clone()); + continue; // neither column present (syntax error?) 
+ } + match op { + Operator::Eq => {} + Operator::NotEq => {} + _ => { + plan_err!(format!("can't optimize {} column comparison", op))?; + } + } + + joins.push((*filter).clone()) + } + + Ok((joins, others)) +} + +/// Extracts correlating columns from expressions +/// +/// # Arguments +/// +/// * `exprs` - List of expressions that correlate a subquery to an outer scope +/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema +/// * `include_negated` - true if `NotEq` counts as a join operator +/// +/// # Return value +/// +/// Tuple of (outer-scope cols, subquery cols, non-correlation expressions) +pub fn exprs_to_join_cols( + exprs: &[Expr], + schema: &DFSchemaRef, + include_negated: bool, +) -> Result<(Vec, Vec, Option)> { + let fields: HashSet<_> = schema + .fields() + .iter() + .map(|it| it.qualified_name()) + .collect(); + + let mut joins: Vec<(String, String)> = vec![]; + let mut others: Vec = vec![]; + for filter in exprs.iter() { + let (left, op, right) = match filter { + Expr::BinaryExpr { left, op, right } => (*left.clone(), *op, *right.clone()), + _ => plan_err!("Invalid correlation expression!")?, + }; + match op { + Operator::Eq => {} + Operator::NotEq => { + if !include_negated { + others.push((*filter).clone()); + continue; + } + } + _ => plan_err!(format!("Correlation operator unsupported: {}", op))?, + } + let left = left.try_into_col()?; + let right = right.try_into_col()?; + let sorted = if fields.contains(&left.flat_name()) { + (right.flat_name(), left.flat_name()) + } else { + (left.flat_name(), right.flat_name()) + }; + joins.push(sorted); + } + + let (left_cols, right_cols): (Vec<_>, Vec<_>) = joins + .into_iter() + .map(|(l, r)| (Column::from(l.as_str()), Column::from(r.as_str()))) + .unzip(); + let pred = combine_filters(&others); + + Ok((left_cols, right_cols, pred)) +} + +/// Returns the first (and only) element in a slice, or an error +/// +/// # Arguments +/// +/// * `slice` - The slice to extract from +/// +/// 
# Return value +/// +/// The first element, or an error +pub fn only_or_err(slice: &[T]) -> Result<&T> { + match slice { + [it] => Ok(it), + [] => plan_err!("No items found!"), + _ => plan_err!("More than one item found!"), + } +} + +/// Merge and deduplicate two sets Column slices +/// +/// # Arguments +/// +/// * `a` - A tuple of slices of Columns +/// * `b` - A tuple of slices of Columns +/// +/// # Return value +/// +/// The deduplicated union of the two slices +pub fn merge_cols( + a: (&[Column], &[Column]), + b: (&[Column], &[Column]), +) -> (Vec, Vec) { + let e = + a.0.iter() + .map(|it| it.flat_name()) + .chain(a.1.iter().map(|it| it.flat_name())) + .map(|it| Column::from(it.as_str())); + let f = + b.0.iter() + .map(|it| it.flat_name()) + .chain(b.1.iter().map(|it| it.flat_name())) + .map(|it| Column::from(it.as_str())); + let mut g = e.zip(f).collect::>(); + g.dedup(); + g.into_iter().unzip() +} + +/// Change the relation on a slice of Columns +/// +/// # Arguments +/// +/// * `new_table` - The table/relation for the new columns +/// * `cols` - A slice of Columns +/// +/// # Return value +/// +/// A new slice of columns, now belonging to the new table +pub fn swap_table(new_table: &str, cols: &[Column]) -> Vec { + cols.iter() + .map(|it| Column { + relation: Some(new_table.to_string()), + name: it.name.clone(), + }) + .collect() +} + +pub fn alias_cols(cols: &[Column]) -> Vec { + cols.iter() + .map(|it| col(it.flat_name().as_str()).alias(it.name.as_str())) + .collect() +} + #[cfg(test)] mod tests { use super::*;