Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix aggregate type coercion bug #3710

Merged
merged 7 commits into from Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/subqueries.rs
Expand Up @@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
Inner Join: #part.p_partkey = #partsupp.ps_partkey
Filter: #part.p_size = Int32(15) AND #part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int32(15), #part.p_type LIKE Utf8("%BRASS")]
Filter: CAST(#part.p_size AS Int64) = Int64(15) AND #part.p_type LIKE Utf8("%BRASS")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be honest, I am not sure why this has changed (aka the filters are no longer simplified). I will look into that in the morning

TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[CAST(#part.p_size AS Int64) = Int64(15), #part.p_type LIKE Utf8("%BRASS")]
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Expand Down
16 changes: 10 additions & 6 deletions datafusion/optimizer/src/optimizer.rs
Expand Up @@ -178,16 +178,15 @@ impl Optimizer {
F: FnMut(&LogicalPlan, &dyn OptimizerRule),
{
let mut new_plan = plan.clone();
debug!("Input logical plan:\n{}\n", plan.display_indent());
trace!("Full input logical plan:\n{:?}", plan);
log_plan("Optimizer input", plan);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a driveby cleanup to improve logging (specifically, also add trace! to log schema)


for rule in &self.rules {
let result = rule.optimize(&new_plan, optimizer_config);
match result {
Ok(plan) => {
new_plan = plan;
observer(&new_plan, rule.as_ref());
debug!("After apply {} rule:\n", rule.name());
debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
log_plan(rule.name(), &new_plan);
}
Err(ref e) => {
if optimizer_config.skip_failing_rules {
Expand All @@ -209,12 +208,17 @@ impl Optimizer {
}
}
}
debug!("Optimized logical plan:\n{}\n", new_plan.display_indent());
trace!("Full Optimized logical plan:\n {:?}", new_plan);
log_plan("Optimized plan", &new_plan);
Ok(new_plan)
}
}

/// Log the plan in debug/tracing mode after some part of the optimizer runs
fn log_plan(description: &str, plan: &LogicalPlan) {
debug!("{description}:\n{}\n", plan.display_indent());
trace!("{description}::\n{}\n", plan.display_indent_schema());
}

#[cfg(test)]
mod tests {
use crate::optimizer::Optimizer;
Expand Down
13 changes: 12 additions & 1 deletion datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Expand Up @@ -97,7 +97,18 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_exprs = plan
.expressions()
.into_iter()
.map(|expr| expr.rewrite(&mut expr_rewriter))
.map(|expr| {
let original_name = expr.name()?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the root cause issue is that the UnwrapCastInComparison can add potentially change the expression but it doesn't add an alias so the output name changes

Bigger picture there are at least three places that we have rediscovered this same problem when rewriting expressions --#3555 and https://github.com/apache/arrow-datafusion/blob/master/datafusion/optimizer/src/simplify_expressions.rs#L316 I will try and make a follow on PR to clean them all up. In particular, I think this is something from_plan could potentially handle

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expr.name() here hides the casts that are added by this optimization rule, so expr.name() is the same as the original name (even though the expression is now different), and the alias does not get added.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this optimizer rule may actually change the name (e.g. from Int64(0) to Int32(0)) which i think is the root cause of the issue in this bug

let expr = expr.rewrite(&mut expr_rewriter)?;

// Ensure this rewrite doesn't change the name
// https://github.com/apache/arrow-datafusion/issues/3704
if expr.name()? != original_name {
Ok(expr.alias(&original_name))
} else {
Ok(expr)
}
})
.collect::<Result<Vec<_>>>()?;

from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
Expand Down
17 changes: 17 additions & 0 deletions datafusion/optimizer/tests/integration-test.rs
Expand Up @@ -29,6 +29,12 @@ use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(test)]
#[ctor::ctor]
fn init() {
let _ = env_logger::try_init();
}

#[test]
fn case_when() -> Result<()> {
let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test";
Expand All @@ -45,6 +51,17 @@ fn case_when() -> Result<()> {
Ok(())
}

#[test]
fn case_when_aggregate() -> Result<()> {
let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";
let plan = test_sql(sql)?;
let expected = "Projection: #test.col_utf8, #SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you merge latest from master - we should not include # before column names now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in e3830a7

\n Aggregate: groupBy=[[#test.col_utf8]], aggr=[[SUM(CASE WHEN #test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\
\n TableScan: test projection=[col_int32, col_utf8]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

#[test]
fn unsigned_target_type() -> Result<()> {
let sql = "SELECT * FROM test WHERE col_uint32 > 0";
Expand Down