From af3322a14e4e790832d444a6f7f8fe1014130f29 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 4 Oct 2022 14:48:01 -0400 Subject: [PATCH 1/6] Do not change output expr name in `UnwrapCastInComparison` --- datafusion/optimizer/src/optimizer.rs | 16 ++++++++++------ .../optimizer/src/unwrap_cast_in_comparison.rs | 13 ++++++++++++- datafusion/optimizer/tests/integration-test.rs | 17 +++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 5ef5cfdd5975..aa10cd8a7dc2 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -178,16 +178,15 @@ impl Optimizer { F: FnMut(&LogicalPlan, &dyn OptimizerRule), { let mut new_plan = plan.clone(); - debug!("Input logical plan:\n{}\n", plan.display_indent()); - trace!("Full input logical plan:\n{:?}", plan); + log_plan("Optimizer input", plan); + for rule in &self.rules { let result = rule.optimize(&new_plan, optimizer_config); match result { Ok(plan) => { new_plan = plan; observer(&new_plan, rule.as_ref()); - debug!("After apply {} rule:\n", rule.name()); - debug!("Optimized logical plan:\n{}\n", new_plan.display_indent()); + log_plan(rule.name(), &new_plan); } Err(ref e) => { if optimizer_config.skip_failing_rules { @@ -209,12 +208,17 @@ impl Optimizer { } } } - debug!("Optimized logical plan:\n{}\n", new_plan.display_indent()); - trace!("Full Optimized logical plan:\n {:?}", new_plan); + log_plan("Optimized plan", &new_plan); Ok(new_plan) } } +/// Log the plan in debug/tracing mode after some part of the optimizer runs +fn log_plan(description: &str, plan: &LogicalPlan) { + debug!("{description}:\n{}\n", plan.display_indent()); + trace!("{description}::\n{}\n", plan.display_indent_schema()); +} + #[cfg(test)] mod tests { use crate::optimizer::Optimizer; diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 7d6858362cad..e0ce44e928d5 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -97,7 +97,18 @@ fn optimize(plan: &LogicalPlan) -> Result { let new_exprs = plan .expressions() .into_iter() - .map(|expr| expr.rewrite(&mut expr_rewriter)) + .map(|expr| { + let original_name = expr.name()?; + let expr = expr.rewrite(&mut expr_rewriter)?; + + // Ensure this rewrite doesn't change the name + // https://github.com/apache/arrow-datafusion/issues/3704 + if expr.name()? != original_name { + Ok(expr.alias(&original_name)) + } else { + Ok(expr) + } + }) .collect::>>()?; from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index e7245c06c102..83a4edbb11a6 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -29,6 +29,12 @@ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +#[cfg(test)] +#[ctor::ctor] +fn init() { + let _ = env_logger::try_init(); +} + #[test] fn case_when() -> Result<()> { let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test"; @@ -45,6 +51,17 @@ fn case_when() -> Result<()> { Ok(()) } +#[test] +fn case_when_aggregate() -> Result<()> { + let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; + let plan = test_sql(sql)?; + let expected = "Projection: #test.col_utf8, #SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ + \n Aggregate: groupBy=[[#test.col_utf8]], aggr=[[SUM(CASE WHEN #test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn unsigned_target_type() -> Result<()> { let sql = "SELECT * FROM test WHERE col_uint32 > 0"; From b486da750209a311792448734177e000f5d9175a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 4 Oct 2022 16:38:17 -0400 Subject: [PATCH 2/6] Update --- datafusion/core/tests/sql/subqueries.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 4b4f23e13bfa..00a14b36a910 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#; Inner Join: #supplier.s_nationkey = #nation.n_nationkey Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey Inner Join: #part.p_partkey = #partsupp.ps_partkey - Filter: #part.p_size = Int32(15) AND #part.p_type LIKE Utf8("%BRASS") - TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int32(15), #part.p_type LIKE Utf8("%BRASS")] + Filter: CAST(#part.p_size AS Int64) = Int64(15) AND #part.p_type LIKE Utf8("%BRASS") + TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[CAST(#part.p_size AS Int64) = Int64(15), #part.p_type LIKE Utf8("%BRASS")] TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] TableScan: nation projection=[n_nationkey, n_name, n_regionkey] From edc25618b454b2dce34eeb01902a01874a3aa8e2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 5 Oct 2022 12:12:35 -0400 Subject: [PATCH 3/6] Update test --- datafusion/core/tests/sql/subqueries.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index f91018d8bf64..add808ec8751 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#; Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey Inner Join: part.p_partkey = partsupp.ps_partkey - Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") - TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] + Filter: CAST(part.p_size AS Int64) = Int64(15) AND part.p_type LIKE Utf8("%BRASS") + TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[CAST(part.p_size AS Int64) = Int64(15), part.p_type LIKE Utf8("%BRASS")] TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] TableScan: nation projection=[n_nationkey, n_name, n_regionkey] From a9d6f8c204e5d31873a770e0c05d78a770b7af2a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 5 Oct 2022 12:38:13 -0400 Subject: [PATCH 4/6] Fix regression --- datafusion/core/tests/sql/subqueries.rs | 4 +- .../src/unwrap_cast_in_comparison.rs | 42 +++++++++++++++---- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index add808ec8751..f91018d8bf64 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#; Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey Inner Join: part.p_partkey = partsupp.ps_partkey - Filter: CAST(part.p_size AS Int64) = Int64(15) AND part.p_type LIKE Utf8("%BRASS") - TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[CAST(part.p_size AS Int64) = Int64(15), part.p_type LIKE Utf8("%BRASS")] + Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") + TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] TableScan: nation projection=[n_nationkey, n_name, n_regionkey] diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index e0ce44e928d5..ec83cbab061b 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -98,22 +98,46 @@ fn optimize(plan: &LogicalPlan) -> Result { .expressions() .into_iter() .map(|expr| { - let original_name = expr.name()?; + let original_name = name_for_alias(&expr)?; let expr = expr.rewrite(&mut expr_rewriter)?; - - // Ensure this rewrite doesn't change the name - // https://github.com/apache/arrow-datafusion/issues/3704 - if expr.name()? != original_name { - Ok(expr.alias(&original_name)) - } else { - Ok(expr) - } + add_alias_if_changed(&original_name, expr) }) .collect::>>()?; from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) } +fn name_for_alias(expr: &Expr) -> Result { + match expr { + Expr::Sort { expr, .. } => name_for_alias(expr), + expr => expr.name(), + } +} + +fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result { + let new_name = name_for_alias(&expr)?; + + if new_name == original_name { + return Ok(expr); + } + + Ok(match expr { + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let expr = add_alias_if_changed(original_name, *expr)?; + Expr::Sort { + expr: Box::new(expr), + asc, + nulls_first, + } + } + expr => expr.alias(&original_name), + }) +} + struct UnwrapCastExprRewriter { schema: DFSchemaRef, } From e3830a7bdf2c574a601fd9e35ccc1a290c71c035 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 5 Oct 2022 13:03:07 -0400 Subject: [PATCH 5/6] Update tests --- datafusion/optimizer/tests/integration-test.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 1990ebbf6b40..dc452af3be0a 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -40,8 +40,8 @@ fn case_when() -> Result<()> { let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test"; let plan = test_sql(sql)?; let expected = - "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\ - \n TableScan: test projection=[col_int32]"; + "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{:?}", plan)); let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test"; @@ -56,9 +56,9 @@ fn case_when() -> Result<()> { fn case_when_aggregate() -> Result<()> { let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; let plan = test_sql(sql)?; - let expected = "Projection: #test.col_utf8, #SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ - \n Aggregate: groupBy=[[#test.col_utf8]], aggr=[[SUM(CASE WHEN #test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ - \n TableScan: test projection=[col_int32, col_utf8]"; + let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ + \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ + \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{:?}", plan)); Ok(()) } From f6e8ffa74d34b8d97a22a0c4c919d696254211f4 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 5 Oct 2022 13:42:13 -0400 Subject: [PATCH 6/6] clippy --- datafusion/optimizer/src/unwrap_cast_in_comparison.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index ec83cbab061b..542c29bd7767 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -134,7 +134,7 @@ fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result { nulls_first, } } - expr => expr.alias(&original_name), + expr => expr.alias(original_name), }) }