Relax join keys constraint from Column to any physical expression for physical join operators #8991

Merged · 14 commits · Jan 29, 2024
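In short: physical join operators previously constrained their equi-join keys to `Column` pairs; this PR relaxes them to `PhysicalExprRef` (`Arc<dyn PhysicalExpr>`) pairs. A minimal sketch of what the relaxed `on` type admits; the `BinaryExpr`/`Literal` key below is illustrative, not code from this PR:

```rust
use std::sync::Arc;

use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::PhysicalExprRef;

/// Join keys for `left.a + 1 = right.b`. With `on` typed as
/// `Vec<(Column, Column)>`, an expression key like `a + 1` had to be
/// materialized by a projection first; as a `PhysicalExprRef` pair it
/// can be handed to the join operator directly.
fn example_on() -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
    let left_key: PhysicalExprRef = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let right_key: PhysicalExprRef = Arc::new(Column::new("b", 0));
    vec![(left_key, right_key)]
}
```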
291 changes: 170 additions & 121 deletions datafusion/core/src/physical_optimizer/enforce_distribution.rs

(Large diff not rendered.)

19 changes: 11 additions & 8 deletions datafusion/core/src/physical_optimizer/enforce_sorting.rs
@@ -985,8 +985,8 @@ mod tests {
let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs);

let on = vec![(
-Column::new_with_schema("col_a", &left_schema)?,
-Column::new_with_schema("c", &right_schema)?,
+Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _,
+Arc::new(Column::new_with_schema("c", &right_schema)?) as _,
)];
let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?;
let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join);
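The recurring `Arc::new(...) as _` in these test updates is an unsizing coercion: the expected element type of the `on` vector, now `PhysicalExprRef` (`Arc<dyn PhysicalExpr>`), lets the compiler infer the cast from `Arc<Column>`. A standalone sketch of the two equivalent spellings, with hypothetical column names:

```rust
use std::sync::Arc;

use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

fn coercion_demo() -> (PhysicalExprRef, PhysicalExprRef) {
    // Explicit form: spell out the trait-object type.
    let lhs: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col_a", 0));
    // Inferred form: `as _` lets the surrounding expected type drive the
    // Arc<Column> -> Arc<dyn PhysicalExpr> coercion, as in the hunks here.
    let rhs = Arc::new(Column::new("c", 1)) as _;
    (lhs, rhs)
}
```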
@@ -1639,8 +1639,9 @@ mod tests {

// Join on (nullable_col == col_a)
let join_on = vec![(
-Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-Column::new_with_schema("col_a", &right.schema()).unwrap(),
+Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
+as _,
+Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
)];

let join_types = vec![
@@ -1711,8 +1712,9 @@ mod tests {

// Join on (nullable_col == col_a)
let join_on = vec![(
-Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-Column::new_with_schema("col_a", &right.schema()).unwrap(),
+Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
+as _,
+Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
)];

let join_types = vec![
@@ -1785,8 +1787,9 @@ mod tests {

// Join on (nullable_col == col_a)
let join_on = vec![(
-Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-Column::new_with_schema("col_a", &right.schema()).unwrap(),
+Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap())
+as _,
+Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _,
)];

let join = sort_merge_join_exec(left, right, &join_on, &JoinType::Inner);
79 changes: 47 additions & 32 deletions datafusion/core/src/physical_optimizer/join_selection.rs
@@ -690,7 +690,7 @@ mod tests_statistical {
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{stats::Precision, JoinType, ScalarValue};
use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

/// Return statistics for empty table
fn empty_statistics() -> Statistics {
@@ -860,8 +860,10 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
+Arc::new(
+Column::new_with_schema("small_col", &small.schema()).unwrap(),
+),
)],
None,
&JoinType::Left,
@@ -914,8 +916,10 @@ mod tests_statistical {
Arc::clone(&small),
Arc::clone(&big),
vec![(
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
+Arc::new(
+Column::new_with_schema("small_col", &small.schema()).unwrap(),
+),
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
)],
None,
&JoinType::Left,
@@ -970,8 +974,13 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
+Arc::new(
+Column::new_with_schema("big_col", &big.schema()).unwrap(),
+),
+Arc::new(
+Column::new_with_schema("small_col", &small.schema())
+.unwrap(),
+),
)],
None,
&join_type,
@@ -1040,8 +1049,8 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
+Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()),
)],
None,
&JoinType::Inner,
@@ -1056,8 +1065,10 @@ mod tests_statistical {
Arc::clone(&medium),
Arc::new(child_join),
vec![(
-Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
-Column::new_with_schema("small_col", &child_schema).unwrap(),
+Arc::new(
+Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
+),
+Arc::new(Column::new_with_schema("small_col", &child_schema).unwrap()),
)],
None,
&JoinType::Left,
@@ -1094,8 +1105,10 @@ mod tests_statistical {
Arc::clone(&small),
Arc::clone(&big),
vec![(
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
+Arc::new(
+Column::new_with_schema("small_col", &small.schema()).unwrap(),
+),
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
)],
None,
&JoinType::Inner,
@@ -1178,8 +1191,8 @@ mod tests_statistical {
));

let join_on = vec![(
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
+Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
small.clone(),
@@ -1190,8 +1203,8 @@
);

let join_on = vec![(
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
+Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
)];
check_join_partition_mode(
big.clone(),
@@ -1202,8 +1215,8 @@
);

let join_on = vec![(
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
-Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
+Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
+Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
)];
check_join_partition_mode(
small.clone(),
@@ -1214,8 +1227,8 @@
);

let join_on = vec![(
-Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
-Column::new_with_schema("small_col", &small.schema()).unwrap(),
+Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
+Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
)];
check_join_partition_mode(
empty.clone(),
@@ -1244,8 +1257,9 @@
));

let join_on = vec![(
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
-Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
+Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap())
+as _,
)];
check_join_partition_mode(
big.clone(),
@@ -1256,8 +1270,9 @@
);

let join_on = vec![(
-Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
+Arc::new(Column::new_with_schema("bigger_col", &bigger.schema()).unwrap())
+as _,
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
bigger.clone(),
@@ -1268,8 +1283,8 @@
);

let join_on = vec![(
-Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
+Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
empty.clone(),
Expand All @@ -1280,16 +1295,16 @@ mod tests_statistical {
);

let join_on = vec![(
-Column::new_with_schema("big_col", &big.schema()).unwrap(),
-Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
+Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
+Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
)];
check_join_partition_mode(big, empty, join_on, false, PartitionMode::Partitioned);
}

fn check_join_partition_mode(
left: Arc<StatisticsExec>,
right: Arc<StatisticsExec>,
-on: Vec<(Column, Column)>,
+on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
is_swapped: bool,
expected_mode: PartitionMode,
) {
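The helper's `on` parameter now carries expression pairs, matching the crate-wide `JoinOn` alias that this PR presumably redefines in terms of `PhysicalExprRef`. A sketch of how the aliases would read (in `datafusion_physical_plan::joins::utils`; included for orientation, not verified against the merged code):

```rust
use std::sync::Arc;

use datafusion_physical_expr::PhysicalExpr;

// Shared handle to a physical expression.
pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;

// One (left, right) expression pair per equi-join condition.
pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>;
pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)];
```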
@@ -1748,8 +1763,8 @@ mod hash_join_tests {
Arc::clone(&left_exec),
Arc::clone(&right_exec),
vec![(
-Column::new_with_schema("a", &left_exec.schema())?,
-Column::new_with_schema("b", &right_exec.schema())?,
+Arc::new(Column::new_with_schema("a", &left_exec.schema())?),
+Arc::new(Column::new_with_schema("b", &right_exec.schema())?),
)],
None,
&t.initial_join_type,
49 changes: 37 additions & 12 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -44,10 +44,11 @@ use crate::physical_plan::{Distribution, ExecutionPlan};
use arrow_schema::SchemaRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
-use datafusion_common::JoinSide;
+use datafusion_common::{DataFusionError, JoinSide};
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr::{
-Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
+Partitioning, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
+PhysicalSortRequirement,
};
use datafusion_physical_plan::streaming::StreamingTableExec;
use datafusion_physical_plan::union::UnionExec;
@@ -1000,8 +1001,8 @@ fn join_table_borders(
fn update_join_on(
proj_left_exprs: &[(Column, String)],
proj_right_exprs: &[(Column, String)],
-hash_join_on: &[(Column, Column)],
-) -> Option<Vec<(Column, Column)>> {
+hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
+) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
// TODO: Clippy wants the "map" call removed, but doing so generates
// a compilation error. Remove the clippy directive once this
// issue is fixed.
@@ -1024,17 +1025,41 @@ fn update_join_on(
/// operation based on a set of equi-join conditions (`hash_join_on`) and a
/// list of projection expressions (`projection_exprs`).
fn new_columns_for_join_on(
-hash_join_on: &[&Column],
+hash_join_on: &[&PhysicalExprRef],
projection_exprs: &[(Column, String)],
-) -> Option<Vec<Column>> {
+) -> Option<Vec<PhysicalExprRef>> {
let new_columns = hash_join_on
.iter()
.filter_map(|on| {
-projection_exprs
-.iter()
-.enumerate()
-.find(|(_, (proj_column, _))| on.name() == proj_column.name())
-.map(|(index, (_, alias))| Column::new(alias, index))
+// Rewrite all columns in `on`
+(*on)
+.clone()
+.transform(&|expr| {
+if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+// Find the column in the projection expressions
+let new_column = projection_exprs
+.iter()
+.enumerate()
+.find(|(_, (proj_column, _))| {
+column.name() == proj_column.name()
+})
+.map(|(index, (_, alias))| Column::new(alias, index));
+if let Some(new_column) = new_column {
+Ok(Transformed::Yes(Arc::new(new_column)))
+} else {
+// If the column is not found in the projection expressions,
+// it means that the column is not projected. In this case,
+// we cannot push the projection down.
+Err(DataFusionError::Internal(format!(
+"Column {:?} not found in projection expressions",
+column
+)))
+}
+} else {
+Ok(Transformed::No(expr))
+}
+})
+.ok()
})
.collect::<Vec<_>>();
(new_columns.len() == hash_join_on.len()).then_some(new_columns)
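The substantive change above: a join key is no longer a bare `Column` matched by name, but an arbitrary expression, so it is walked with `TreeNode::transform` and every embedded `Column` leaf is remapped to its post-projection alias and index. A self-contained sketch of that rewrite (helper name hypothetical; `transform`/`Transformed` follow the TreeNode API of this DataFusion vintage):

```rust
use std::sync::Arc;

use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DataFusionError, Result};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExprRef;

/// Remap every `Column` inside `expr` to its alias and index after a
/// projection. Errors if a referenced column is not projected, in which
/// case the caller cannot push the projection below the join.
fn remap_columns(
    expr: PhysicalExprRef,
    projection_exprs: &[(Column, String)],
) -> Result<PhysicalExprRef> {
    expr.transform(&|e| {
        if let Some(column) = e.as_any().downcast_ref::<Column>() {
            match projection_exprs
                .iter()
                .enumerate()
                .find(|(_, (proj, _))| proj.name() == column.name())
            {
                // The position in the projection becomes the new index.
                Some((index, (_, alias))) => {
                    Ok(Transformed::Yes(Arc::new(Column::new(alias, index))))
                }
                None => Err(DataFusionError::Internal(format!(
                    "Column {column:?} not found in projection expressions"
                ))),
            }
        } else {
            // Non-column nodes are preserved; only the leaves change.
            Ok(Transformed::No(e))
        }
    })
}
```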
@@ -2018,7 +2043,7 @@ mod tests {
let join: Arc<dyn ExecutionPlan> = Arc::new(SymmetricHashJoinExec::try_new(
left_csv,
right_csv,
-vec![(Column::new("b", 1), Column::new("c", 2))],
+vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))],
// b_left-(1+a_right)<=a_right+c_left
Some(JoinFilter::new(
Arc::new(BinaryExpr::new(
@@ -1440,7 +1440,7 @@ mod tests {
HashJoinExec::try_new(
left,
right,
-vec![(left_col.clone(), right_col.clone())],
+vec![(Arc::new(left_col.clone()), Arc::new(right_col.clone()))],
None,
&JoinType::Inner,
PartitionMode::Partitioned,
18 changes: 12 additions & 6 deletions datafusion/core/src/physical_planner.rs
@@ -1036,15 +1036,21 @@ impl DefaultPhysicalPlanner {
let [physical_left, physical_right]: [Arc<dyn ExecutionPlan>; 2] = left_right.try_into().map_err(|_| DataFusionError::Internal("`create_initial_plan_multi` is broken".to_string()))?;
let left_df_schema = left.schema();
let right_df_schema = right.schema();
+let execution_props = session_state.execution_props();
let join_on = keys
.iter()
.map(|(l, r)| {
-let l = l.try_into_col()?;
-let r = r.try_into_col()?;
-Ok((
-Column::new(&l.name, left_df_schema.index_of_column(&l)?),
-Column::new(&r.name, right_df_schema.index_of_column(&r)?),
-))
+let l = create_physical_expr(
+l,
+left_df_schema,
+execution_props
+)?;
+let r = create_physical_expr(
+r,
+right_df_schema,
+execution_props
+)?;
+Ok((l, r))
})
.collect::<Result<join_utils::JoinOn>>()?;

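On the planning side, each logical join key is now lowered with `create_physical_expr` instead of being resolved to a bare `Column`, so expression keys survive into the physical plan. A hedged end-to-end illustration (table and column names hypothetical):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

/// With expression join keys supported, an equi-join such as
/// `t1.a + 1 = t2.b` can be planned with hash-join keys directly,
/// without an auxiliary projection materializing `a + 1` first.
async fn demo(ctx: &SessionContext) -> Result<()> {
    let df = ctx
        .sql("SELECT * FROM t1 JOIN t2 ON t1.a + 1 = t2.b")
        .await?;
    df.show().await?;
    Ok(())
}
```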
8 changes: 4 additions & 4 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -109,12 +109,12 @@ async fn run_join_test(
let schema2 = input2[0].schema();
let on_columns = vec![
(
-Column::new_with_schema("a", &schema1).unwrap(),
-Column::new_with_schema("a", &schema2).unwrap(),
+Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
),
(
-Column::new_with_schema("b", &schema1).unwrap(),
-Column::new_with_schema("b", &schema2).unwrap(),
+Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
),
];
