From 64669e997bce2f90b400614e97a87a60c5a25f3c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 5 Oct 2022 14:28:57 -0400 Subject: [PATCH] Fix aggregate type coercion bug (#3710) * Do not change output expr name in `UnwrapCastInComparison` * Update * Update test * Fix regression * Update tests * clippy --- datafusion/optimizer/src/optimizer.rs | 16 +++++--- .../src/unwrap_cast_in_comparison.rs | 37 ++++++++++++++++++- .../optimizer/tests/integration-test.rs | 21 ++++++++++- 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 5ef5cfdd5975..aa10cd8a7dc2 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -178,16 +178,15 @@ impl Optimizer { F: FnMut(&LogicalPlan, &dyn OptimizerRule), { let mut new_plan = plan.clone(); - debug!("Input logical plan:\n{}\n", plan.display_indent()); - trace!("Full input logical plan:\n{:?}", plan); + log_plan("Optimizer input", plan); + for rule in &self.rules { let result = rule.optimize(&new_plan, optimizer_config); match result { Ok(plan) => { new_plan = plan; observer(&new_plan, rule.as_ref()); - debug!("After apply {} rule:\n", rule.name()); - debug!("Optimized logical plan:\n{}\n", new_plan.display_indent()); + log_plan(rule.name(), &new_plan); } Err(ref e) => { if optimizer_config.skip_failing_rules { @@ -209,12 +208,17 @@ impl Optimizer { } } } - debug!("Optimized logical plan:\n{}\n", new_plan.display_indent()); - trace!("Full Optimized logical plan:\n {:?}", new_plan); + log_plan("Optimized plan", &new_plan); Ok(new_plan) } } +/// Log the plan in debug/tracing mode after some part of the optimizer runs +fn log_plan(description: &str, plan: &LogicalPlan) { + debug!("{description}:\n{}\n", plan.display_indent()); + trace!("{description}::\n{}\n", plan.display_indent_schema()); +} + #[cfg(test)] mod tests { use crate::optimizer::Optimizer; diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 7d6858362cad..542c29bd7767 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -97,12 +97,47 @@ fn optimize(plan: &LogicalPlan) -> Result { let new_exprs = plan .expressions() .into_iter() - .map(|expr| expr.rewrite(&mut expr_rewriter)) + .map(|expr| { + let original_name = name_for_alias(&expr)?; + let expr = expr.rewrite(&mut expr_rewriter)?; + add_alias_if_changed(&original_name, expr) + }) .collect::>>()?; from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) } +fn name_for_alias(expr: &Expr) -> Result { + match expr { + Expr::Sort { expr, .. } => name_for_alias(expr), + expr => expr.name(), + } +} + +fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result { + let new_name = name_for_alias(&expr)?; + + if new_name == original_name { + return Ok(expr); + } + + Ok(match expr { + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let expr = add_alias_if_changed(original_name, *expr)?; + Expr::Sort { + expr: Box::new(expr), + asc, + nulls_first, + } + } + expr => expr.alias(original_name), + }) +} + struct UnwrapCastExprRewriter { schema: DFSchemaRef, } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 2d9546f13e51..dc452af3be0a 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -29,13 +29,19 @@ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +#[cfg(test)] +#[ctor::ctor] +fn init() { + let _ = env_logger::try_init(); +} + #[test] fn case_when() -> Result<()> { let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test"; let plan = test_sql(sql)?; let expected = - "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\ - \n TableScan: test projection=[col_int32]"; + "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{:?}", plan)); let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test"; @@ -46,6 +52,17 @@ fn case_when() -> Result<()> { Ok(()) } +#[test] +fn case_when_aggregate() -> Result<()> { + let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; + let plan = test_sql(sql)?; + let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ + \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn unsigned_target_type() -> Result<()> { let sql = "SELECT * FROM test WHERE col_uint32 > 0";