Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix aggregate type coercion bug #3710

Merged
merged 7 commits into from Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 10 additions & 6 deletions datafusion/optimizer/src/optimizer.rs
Expand Up @@ -178,16 +178,15 @@ impl Optimizer {
F: FnMut(&LogicalPlan, &dyn OptimizerRule),
{
let mut new_plan = plan.clone();
debug!("Input logical plan:\n{}\n", plan.display_indent());
trace!("Full input logical plan:\n{:?}", plan);
log_plan("Optimizer input", plan);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a driveby cleanup to improve logging (specifically, also add trace! to log schema)


for rule in &self.rules {
let result = rule.optimize(&new_plan, optimizer_config);
match result {
Ok(plan) => {
new_plan = plan;
observer(&new_plan, rule.as_ref());
debug!("After apply {} rule:\n", rule.name());
debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
log_plan(rule.name(), &new_plan);
}
Err(ref e) => {
if optimizer_config.skip_failing_rules {
Expand All @@ -209,12 +208,17 @@ impl Optimizer {
}
}
}
debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
trace!("Full Optimized logical plan:\n {:?}", new_plan);
log_plan("Optimized plan", &new_plan);
Ok(new_plan)
}
}

/// Log the plan in debug/tracing mode after some part of the optimizer runs
fn log_plan(description: &str, plan: &LogicalPlan) {
debug!("{description}:\n{}\n", plan.display_indent());
trace!("{description}::\n{}\n", plan.display_indent_schema());
}

#[cfg(test)]
mod tests {
use crate::optimizer::Optimizer;
Expand Down
37 changes: 36 additions & 1 deletion datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Expand Up @@ -97,12 +97,47 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_exprs = plan
.expressions()
.into_iter()
.map(|expr| expr.rewrite(&mut expr_rewriter))
.map(|expr| {
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(&mut expr_rewriter)?;
add_alias_if_changed(&original_name, expr)
})
.collect::<Result<Vec<_>>>()?;

from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}

fn name_for_alias(expr: &Expr) -> Result<String> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to make this easier on the eyes as a follow on PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

follow on #3727

match expr {
Expr::Sort { expr, .. } => name_for_alias(expr),
expr => expr.name(),
}
}

fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
let new_name = name_for_alias(&expr)?;

if new_name == original_name {
return Ok(expr);
}

Ok(match expr {
Expr::Sort {
expr,
asc,
nulls_first,
} => {
let expr = add_alias_if_changed(original_name, *expr)?;
Expr::Sort {
expr: Box::new(expr),
asc,
nulls_first,
}
}
expr => expr.alias(&original_name),
})
}

struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
Expand Down
21 changes: 19 additions & 2 deletions datafusion/optimizer/tests/integration-test.rs
Expand Up @@ -29,13 +29,19 @@ use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(test)]
#[ctor::ctor]
fn init() {
let _ = env_logger::try_init();
}

#[test]
fn case_when() -> Result<()> {
let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test";
let plan = test_sql(sql)?;
let expected =
"Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\
\n TableScan: test projection=[col_int32]";
"Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @andygrove the alias was added to this as well

\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{:?}", plan));

let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test";
Expand All @@ -46,6 +52,17 @@ fn case_when() -> Result<()> {
Ok(())
}

#[test]
fn case_when_aggregate() -> Result<()> {
let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";
let plan = test_sql(sql)?;
let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\
\n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\
\n TableScan: test projection=[col_int32, col_utf8]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

#[test]
fn unsigned_target_type() -> Result<()> {
let sql = "SELECT * FROM test WHERE col_uint32 > 0";
Expand Down