Skip to content

Commit

Permalink
Push down limit to sort (apache#3530)
Browse files Browse the repository at this point in the history
* Push down limit to sort

Support skip, fix test

Fmt

Add limit directly after sort

Update comment

Simplify parallel sort by using new pushdown

Clippy

* Update datafusion/core/src/physical_plan/sorts/sort.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
Dandandan and alamb committed Sep 20, 2022
1 parent c7f3a70 commit 81b5794
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 30 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe.rs
Expand Up @@ -1296,7 +1296,7 @@ mod tests {
assert_eq!("\
Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\
\n Limit: skip=0, fetch=1\
\n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\
\n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST, fetch=1\
\n Inner Join: #t1.c1 = #t2.c1\
\n TableScan: t1 projection=[c1, c2, c3]\
\n TableScan: t2 projection=[c1, c2, c3]",
Expand Down
23 changes: 12 additions & 11 deletions datafusion/core/src/physical_optimizer/parallel_sort.rs
Expand Up @@ -20,7 +20,6 @@ use crate::{
error::Result,
physical_optimizer::PhysicalOptimizerRule,
physical_plan::{
limit::GlobalLimitExec,
sorts::{sort::SortExec, sort_preserving_merge::SortPreservingMergeExec},
with_new_children_if_necessary,
},
Expand Down Expand Up @@ -55,31 +54,33 @@ impl PhysicalOptimizerRule for ParallelSort {
.map(|child| self.optimize(child.clone(), config))
.collect::<Result<Vec<_>>>()?;
let plan = with_new_children_if_necessary(plan, children)?;
let children = plan.children();
let plan_any = plan.as_any();
// GlobalLimitExec (SortExec preserve_partitioning=False)
// -> GlobalLimitExec (SortExec preserve_partitioning=True)
let parallel_sort = plan_any.downcast_ref::<GlobalLimitExec>().is_some()
&& children.len() == 1
&& children[0].as_any().downcast_ref::<SortExec>().is_some()
&& !children[0]
.as_any()
// SortExec preserve_partitioning=False, fetch=Some(n)
// -> SortPreservingMergeExec (SortExec preserve_partitioning=True, fetch=Some(n))
let parallel_sort = plan_any.downcast_ref::<SortExec>().is_some()
&& plan_any
.downcast_ref::<SortExec>()
.unwrap()
.fetch()
.is_some()
&& !plan_any
.downcast_ref::<SortExec>()
.unwrap()
.preserve_partitioning();

Ok(if parallel_sort {
let sort = children[0].as_any().downcast_ref::<SortExec>().unwrap();
let sort = plan_any.downcast_ref::<SortExec>().unwrap();
let new_sort = SortExec::new_with_partitioning(
sort.expr().to_vec(),
sort.input().clone(),
true,
sort.fetch(),
);
let merge = SortPreservingMergeExec::new(
sort.expr().to_vec(),
Arc::new(new_sort),
);
with_new_children_if_necessary(plan, vec![Arc::new(merge)])?
Arc::new(merge)
} else {
plan.clone()
})
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/repartition.rs
Expand Up @@ -295,7 +295,7 @@ mod tests {
expr: col("c1", &schema()).unwrap(),
options: SortOptions::default(),
}];
Arc::new(SortExec::try_new(sort_exprs, input).unwrap())
Arc::new(SortExec::try_new(sort_exprs, input, None).unwrap())
}

fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/physical_plan/planner.rs
Expand Up @@ -590,9 +590,9 @@ impl DefaultPhysicalPlanner {
})
.collect::<Result<Vec<_>>>()?;
Arc::new(if can_repartition {
SortExec::new_with_partitioning(sort_keys, input_exec, true)
SortExec::new_with_partitioning(sort_keys, input_exec, true, None)
} else {
SortExec::try_new(sort_keys, input_exec)?
SortExec::try_new(sort_keys, input_exec, None)?
})
};

Expand Down Expand Up @@ -815,7 +815,7 @@ impl DefaultPhysicalPlanner {
physical_partitioning,
)?) )
}
LogicalPlan::Sort(Sort { expr, input, .. }) => {
LogicalPlan::Sort(Sort { expr, input, fetch, .. }) => {
let physical_input = self.create_initial_plan(input, session_state).await?;
let input_schema = physical_input.as_ref().schema();
let input_dfschema = input.as_ref().schema();
Expand All @@ -841,7 +841,7 @@ impl DefaultPhysicalPlanner {
)),
})
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(SortExec::try_new(sort_expr, physical_input)?))
Ok(Arc::new(SortExec::try_new(sort_expr, physical_input, *fetch)?))
}
LogicalPlan::Join(Join {
left,
Expand Down
30 changes: 26 additions & 4 deletions datafusion/core/src/physical_plan/sorts/sort.rs
Expand Up @@ -82,6 +82,7 @@ struct ExternalSorter {
runtime: Arc<RuntimeEnv>,
metrics_set: CompositeMetricsSet,
metrics: BaselineMetrics,
fetch: Option<usize>,
}

impl ExternalSorter {
Expand All @@ -92,6 +93,7 @@ impl ExternalSorter {
metrics_set: CompositeMetricsSet,
session_config: Arc<SessionConfig>,
runtime: Arc<RuntimeEnv>,
fetch: Option<usize>,
) -> Self {
let metrics = metrics_set.new_intermediate_baseline(partition_id);
Self {
Expand All @@ -104,6 +106,7 @@ impl ExternalSorter {
runtime,
metrics_set,
metrics,
fetch,
}
}

Expand All @@ -120,7 +123,7 @@ impl ExternalSorter {
// NB timer records time taken on drop, so there are no
// calls to `timer.done()` below.
let _timer = tracking_metrics.elapsed_compute().timer();
let partial = sort_batch(input, self.schema.clone(), &self.expr)?;
let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?;
in_mem_batches.push(partial);
}
Ok(())
Expand Down Expand Up @@ -657,15 +660,18 @@ pub struct SortExec {
metrics_set: CompositeMetricsSet,
/// Preserve partitions of input plan
preserve_partitioning: bool,
/// Fetch highest/lowest n results
fetch: Option<usize>,
}

impl SortExec {
/// Create a new sort execution plan
pub fn try_new(
expr: Vec<PhysicalSortExpr>,
input: Arc<dyn ExecutionPlan>,
fetch: Option<usize>,
) -> Result<Self> {
Ok(Self::new_with_partitioning(expr, input, false))
Ok(Self::new_with_partitioning(expr, input, false, fetch))
}

/// Whether this `SortExec` preserves partitioning of the children
Expand All @@ -679,12 +685,14 @@ impl SortExec {
expr: Vec<PhysicalSortExpr>,
input: Arc<dyn ExecutionPlan>,
preserve_partitioning: bool,
fetch: Option<usize>,
) -> Self {
Self {
expr,
input,
metrics_set: CompositeMetricsSet::new(),
preserve_partitioning,
fetch,
}
}

Expand All @@ -697,6 +705,11 @@ impl SortExec {
pub fn expr(&self) -> &[PhysicalSortExpr] {
&self.expr
}

/// If `Some(fetch)`, limits output to only the first "fetch" items
pub fn fetch(&self) -> Option<usize> {
self.fetch
}
}

impl ExecutionPlan for SortExec {
Expand Down Expand Up @@ -750,6 +763,7 @@ impl ExecutionPlan for SortExec {
self.expr.clone(),
children[0].clone(),
self.preserve_partitioning,
self.fetch,
)))
}

Expand Down Expand Up @@ -778,6 +792,7 @@ impl ExecutionPlan for SortExec {
self.expr.clone(),
self.metrics_set.clone(),
context,
self.fetch(),
)
.map_err(|e| ArrowError::ExternalError(Box::new(e))),
)
Expand Down Expand Up @@ -816,14 +831,14 @@ fn sort_batch(
batch: RecordBatch,
schema: SchemaRef,
expr: &[PhysicalSortExpr],
fetch: Option<usize>,
) -> ArrowResult<BatchWithSortArray> {
// TODO: pushup the limit expression to sort
let sort_columns = expr
.iter()
.map(|e| e.evaluate_to_sort_column(&batch))
.collect::<Result<Vec<SortColumn>>>()?;

let indices = lexsort_to_indices(&sort_columns, None)?;
let indices = lexsort_to_indices(&sort_columns, fetch)?;

// reorder all rows based on sorted indices
let sorted_batch = RecordBatch::try_new(
Expand Down Expand Up @@ -870,6 +885,7 @@ async fn do_sort(
expr: Vec<PhysicalSortExpr>,
metrics_set: CompositeMetricsSet,
context: Arc<TaskContext>,
fetch: Option<usize>,
) -> Result<SendableRecordBatchStream> {
debug!(
"Start do_sort for partition {} of context session_id {} and task_id {:?}",
Expand All @@ -887,6 +903,7 @@ async fn do_sort(
metrics_set,
Arc::new(context.session_config()),
context.runtime_env(),
fetch,
);
context.runtime_env().register_requester(sorter.id());
while let Some(batch) = input.next().await {
Expand Down Expand Up @@ -949,6 +966,7 @@ mod tests {
},
],
Arc::new(CoalescePartitionsExec::new(csv)),
None,
)?);

let result = collect(sort_exec, task_ctx).await?;
Expand Down Expand Up @@ -1011,6 +1029,7 @@ mod tests {
},
],
Arc::new(CoalescePartitionsExec::new(csv)),
None,
)?);

let task_ctx = session_ctx.task_ctx();
Expand Down Expand Up @@ -1083,6 +1102,7 @@ mod tests {
options: SortOptions::default(),
}],
input,
None,
)?);

let result: Vec<RecordBatch> = collect(sort_exec, task_ctx).await?;
Expand Down Expand Up @@ -1159,6 +1179,7 @@ mod tests {
},
],
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?),
None,
)?);

assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type());
Expand Down Expand Up @@ -1226,6 +1247,7 @@ mod tests {
options: SortOptions::default(),
}],
blocking_exec,
None,
)?);

let fut = collect(sort_exec, task_ctx);
Expand Down
10 changes: 7 additions & 3 deletions datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
Expand Up @@ -874,8 +874,12 @@ mod tests {
sort: Vec<PhysicalSortExpr>,
context: Arc<TaskContext>,
) -> RecordBatch {
let sort_exec =
Arc::new(SortExec::new_with_partitioning(sort.clone(), input, true));
let sort_exec = Arc::new(SortExec::new_with_partitioning(
sort.clone(),
input,
true,
None,
));
sorted_merge(sort_exec, sort, context).await
}

Expand All @@ -885,7 +889,7 @@ mod tests {
context: Arc<TaskContext>,
) -> RecordBatch {
let merge = Arc::new(CoalescePartitionsExec::new(src));
let sort_exec = Arc::new(SortExec::try_new(sort, merge).unwrap());
let sort_exec = Arc::new(SortExec::try_new(sort, merge, None).unwrap());
let mut result = collect(sort_exec, context).await.unwrap();
assert_eq!(result.len(), 1);
result.remove(0)
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/order_spill_fuzz.rs
Expand Up @@ -75,7 +75,7 @@ async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) {
}];

let exec = MemoryExec::try_new(&input, schema, None).unwrap();
let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec)).unwrap());
let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec), None).unwrap());

let runtime_config = RuntimeConfig::new().with_memory_manager(
MemoryManagerConfig::try_new_limit(pool_size, 1.0).unwrap(),
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/tests/user_defined_plan.rs
Expand Up @@ -299,6 +299,7 @@ impl OptimizerRule for TopKOptimizerRule {
if let LogicalPlan::Sort(Sort {
ref expr,
ref input,
..
}) = **input
{
if expr.len() == 1 {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/logical_plan/builder.rs
Expand Up @@ -396,13 +396,15 @@ impl LogicalPlanBuilder {
return Ok(Self::from(LogicalPlan::Sort(Sort {
expr: normalize_cols(exprs, &self.plan)?,
input: Arc::new(self.plan.clone()),
fetch: None,
})));
}

let plan = self.add_missing_columns(self.plan.clone(), &missing_cols)?;
let sort_plan = LogicalPlan::Sort(Sort {
expr: normalize_cols(exprs, &plan)?,
input: Arc::new(plan.clone()),
fetch: None,
});
// remove pushed down sort columns
let new_expr = schema
Expand Down
8 changes: 7 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Expand Up @@ -806,14 +806,18 @@ impl LogicalPlan {
"Aggregate: groupBy=[{:?}], aggr=[{:?}]",
group_expr, aggr_expr
),
LogicalPlan::Sort(Sort { expr, .. }) => {
LogicalPlan::Sort(Sort { expr, fetch, .. }) => {
write!(f, "Sort: ")?;
for (i, expr_item) in expr.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{:?}", expr_item)?;
}
if let Some(a) = fetch {
write!(f, ", fetch={}", a)?;
}

Ok(())
}
LogicalPlan::Join(Join {
Expand Down Expand Up @@ -1373,6 +1377,8 @@ pub struct Sort {
pub expr: Vec<Expr>,
/// The incoming logical plan
pub input: Arc<LogicalPlan>,
/// Optional fetch limit
pub fetch: Option<usize>,
}

/// Join two logical plans on one or more join columns
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/utils.rs
Expand Up @@ -420,9 +420,10 @@ pub fn from_plan(
expr[group_expr.len()..].to_vec(),
schema.clone(),
)?)),
LogicalPlan::Sort(Sort { .. }) => Ok(LogicalPlan::Sort(Sort {
LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort {
expr: expr.to_vec(),
input: Arc::new(inputs[0].clone()),
fetch: *fetch,
})),
LogicalPlan::Join(Join {
join_type,
Expand Down
3 changes: 2 additions & 1 deletion datafusion/optimizer/src/common_subexpr_eliminate.rs
Expand Up @@ -196,7 +196,7 @@ fn optimize(
schema.clone(),
)?))
}
LogicalPlan::Sort(Sort { expr, input }) => {
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
let arrays = to_arrays(expr, input, &mut expr_set)?;

let (mut new_expr, new_input) = rewrite_expr(
Expand All @@ -210,6 +210,7 @@ fn optimize(
Ok(LogicalPlan::Sort(Sort {
expr: pop_expr(&mut new_expr)?,
input: Arc::new(new_input),
fetch: *fetch,
}))
}
LogicalPlan::Join { .. }
Expand Down

0 comments on commit 81b5794

Please sign in to comment.