Skip to content

Commit

Permalink
Relax join keys constraint from Column to any physical expression for…
Browse files Browse the repository at this point in the history
… physical join operators (#8991)

* Relax SortMergeJoin join keys

* More

* More

* More

* More

* Fix clippy

* Fix more clippy

* More

* More

* Fix

* Fix

* Use collect_columns

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
viirya and alamb committed Jan 29, 2024
1 parent 92104a5 commit d594e62
Show file tree
Hide file tree
Showing 18 changed files with 691 additions and 511 deletions.
291 changes: 170 additions & 121 deletions datafusion/core/src/physical_optimizer/enforce_distribution.rs

Large diffs are not rendered by default.

19 changes: 11 additions & 8 deletions datafusion/core/src/physical_optimizer/enforce_sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -985,8 +985,8 @@ mod tests {
let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs);

let on = vec![(
Column::new_with_schema("col_a", &left_schema)?,
Column::new_with_schema("c", &right_schema)?,
Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _,
Arc::new(Column::new_with_schema("c", &right_schema)?) as _,
)];
let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?;
let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join);
Expand Down Expand Up @@ -1639,8 +1639,9 @@ mod tests {

// Join on (nullable_col == col_a)
let join_on = vec![(
Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
Column::new_with_schema("col_a", &right.schema()).unwrap(),
Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
as _,
Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
)];

let join_types = vec![
Expand Down Expand Up @@ -1711,8 +1712,9 @@ mod tests {

// Join on (nullable_col == col_a)
let join_on = vec![(
Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
Column::new_with_schema("col_a", &right.schema()).unwrap(),
Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
as _,
Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
)];

let join_types = vec![
Expand Down Expand Up @@ -1785,8 +1787,9 @@ mod tests {

// Join on (nullable_col == col_a)
let join_on = vec![(
Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
Column::new_with_schema("col_a", &right.schema()).unwrap(),
Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
as _,
Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
)];

let join = sort_merge_join_exec(left, right, &join_on, &JoinType::Inner);
Expand Down
79 changes: 47 additions & 32 deletions datafusion/core/src/physical_optimizer/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ mod tests_statistical {
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{stats::Precision, JoinType, ScalarValue};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

/// Return statistics for empty table
fn empty_statistics() -> Statistics {
Expand Down Expand Up @@ -860,8 +860,10 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
Arc::new(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
),
)],
None,
&JoinType::Left,
Expand Down Expand Up @@ -914,8 +916,10 @@ mod tests_statistical {
Arc::clone(&small),
Arc::clone(&big),
vec![(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Arc::new(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
),
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
)],
None,
&JoinType::Left,
Expand Down Expand Up @@ -970,8 +974,13 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Arc::new(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
),
Arc::new(
Column::new_with_schema("small_col", &small.schema())
.unwrap(),
),
)],
None,
&join_type,
Expand Down Expand Up @@ -1040,8 +1049,8 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()),
)],
None,
&JoinType::Inner,
Expand All @@ -1056,8 +1065,10 @@ mod tests_statistical {
Arc::clone(&medium),
Arc::new(child_join),
vec![(
Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
Column::new_with_schema("small_col", &child_schema).unwrap(),
Arc::new(
Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
),
Arc::new(Column::new_with_schema("small_col", &child_schema).unwrap()),
)],
None,
&JoinType::Left,
Expand Down Expand Up @@ -1094,8 +1105,10 @@ mod tests_statistical {
Arc::clone(&small),
Arc::clone(&big),
vec![(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Arc::new(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
),
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
)],
None,
&JoinType::Inner,
Expand Down Expand Up @@ -1178,8 +1191,8 @@ mod tests_statistical {
));

let join_on = vec![(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
small.clone(),
Expand All @@ -1190,8 +1203,8 @@ mod tests_statistical {
);

let join_on = vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
)];
check_join_partition_mode(
big.clone(),
Expand All @@ -1202,8 +1215,8 @@ mod tests_statistical {
);

let join_on = vec![(
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
)];
check_join_partition_mode(
small.clone(),
Expand All @@ -1214,8 +1227,8 @@ mod tests_statistical {
);

let join_on = vec![(
Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
)];
check_join_partition_mode(
empty.clone(),
Expand Down Expand Up @@ -1244,8 +1257,9 @@ mod tests_statistical {
));

let join_on = vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap())
as _,
)];
check_join_partition_mode(
big.clone(),
Expand All @@ -1256,8 +1270,9 @@ mod tests_statistical {
);

let join_on = vec![(
Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap())
as _,
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
bigger.clone(),
Expand All @@ -1268,8 +1283,8 @@ mod tests_statistical {
);

let join_on = vec![(
Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
empty.clone(),
Expand All @@ -1280,16 +1295,16 @@ mod tests_statistical {
);

let join_on = vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
)];
check_join_partition_mode(big, empty, join_on, false, PartitionMode::Partitioned);
}

fn check_join_partition_mode(
left: Arc<StatisticsExec>,
right: Arc<StatisticsExec>,
on: Vec<(Column, Column)>,
on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
is_swapped: bool,
expected_mode: PartitionMode,
) {
Expand Down Expand Up @@ -1748,8 +1763,8 @@ mod hash_join_tests {
Arc::clone(&left_exec),
Arc::clone(&right_exec),
vec![(
Column::new_with_schema("a", &left_exec.schema())?,
Column::new_with_schema("b", &right_exec.schema())?,
Arc::new(Column::new_with_schema("a", &left_exec.schema())?),
Arc::new(Column::new_with_schema("b", &right_exec.schema())?),
)],
None,
&t.initial_join_type,
Expand Down
49 changes: 37 additions & 12 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ use crate::physical_plan::{Distribution, ExecutionPlan};
use arrow_schema::SchemaRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::JoinSide;
use datafusion_common::{DataFusionError, JoinSide};
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr::{
Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
Partitioning, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
PhysicalSortRequirement,
};
use datafusion_physical_plan::streaming::StreamingTableExec;
use datafusion_physical_plan::union::UnionExec;
Expand Down Expand Up @@ -1000,8 +1001,8 @@ fn join_table_borders(
fn update_join_on(
proj_left_exprs: &[(Column, String)],
proj_right_exprs: &[(Column, String)],
hash_join_on: &[(Column, Column)],
) -> Option<Vec<(Column, Column)>> {
hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
// TODO: Clippy wants the "map" call removed, but doing so generates
// a compilation error. Remove the clippy directive once this
// issue is fixed.
Expand All @@ -1024,17 +1025,41 @@ fn update_join_on(
/// operation based on a set of equi-join conditions (`hash_join_on`) and a
/// list of projection expressions (`projection_exprs`).
fn new_columns_for_join_on(
hash_join_on: &[&Column],
hash_join_on: &[&PhysicalExprRef],
projection_exprs: &[(Column, String)],
) -> Option<Vec<Column>> {
) -> Option<Vec<PhysicalExprRef>> {
let new_columns = hash_join_on
.iter()
.filter_map(|on| {
projection_exprs
.iter()
.enumerate()
.find(|(_, (proj_column, _))| on.name() == proj_column.name())
.map(|(index, (_, alias))| Column::new(alias, index))
// Rewrite all columns in `on`
(*on)
.clone()
.transform(&|expr| {
if let Some(column) = expr.as_any().downcast_ref::<Column>() {
// Find the column in the projection expressions
let new_column = projection_exprs
.iter()
.enumerate()
.find(|(_, (proj_column, _))| {
column.name() == proj_column.name()
})
.map(|(index, (_, alias))| Column::new(alias, index));
if let Some(new_column) = new_column {
Ok(Transformed::Yes(Arc::new(new_column)))
} else {
// If the column is not found in the projection expressions,
// it means that the column is not projected. In this case,
// we cannot push the projection down.
Err(DataFusionError::Internal(format!(
"Column {:?} not found in projection expressions",
column
)))
}
} else {
Ok(Transformed::No(expr))
}
})
.ok()
})
.collect::<Vec<_>>();
(new_columns.len() == hash_join_on.len()).then_some(new_columns)
Expand Down Expand Up @@ -2018,7 +2043,7 @@ mod tests {
let join: Arc<dyn ExecutionPlan> = Arc::new(SymmetricHashJoinExec::try_new(
left_csv,
right_csv,
vec![(Column::new("b", 1), Column::new("c", 2))],
vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))],
// b_left-(1+a_right)<=a_right+c_left
Some(JoinFilter::new(
Arc::new(BinaryExpr::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,7 @@ mod tests {
HashJoinExec::try_new(
left,
right,
vec![(left_col.clone(), right_col.clone())],
vec![(Arc::new(left_col.clone()), Arc::new(right_col.clone()))],
None,
&JoinType::Inner,
PartitionMode::Partitioned,
Expand Down
18 changes: 12 additions & 6 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1036,15 +1036,21 @@ impl DefaultPhysicalPlanner {
let [physical_left, physical_right]: [Arc<dyn ExecutionPlan>; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?;
let left_df_schema = left.schema();
let right_df_schema = right.schema();
let execution_props = session_state.execution_props();
let join_on = keys
.iter()
.map(|(l, r)| {
let l = l.try_into_col()?;
let r = r.try_into_col()?;
Ok((
Column::new(&l.name, left_df_schema.index_of_column(&l)?),
Column::new(&r.name, right_df_schema.index_of_column(&r)?),
))
let l = create_physical_expr(
l,
left_df_schema,
execution_props
)?;
let r = create_physical_expr(
r,
right_df_schema,
execution_props
)?;
Ok((l, r))
})
.collect::<Result<join_utils::JoinOn>>()?;

Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ async fn run_join_test(
let schema2 = input2[0].schema();
let on_columns = vec![
(
Column::new_with_schema("a", &schema1).unwrap(),
Column::new_with_schema("a", &schema2).unwrap(),
Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
),
(
Column::new_with_schema("b", &schema1).unwrap(),
Column::new_with_schema("b", &schema2).unwrap(),
Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
),
];

Expand Down

0 comments on commit d594e62

Please sign in to comment.