diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 69779ac905b..f48beb34b61 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -34,28 +34,30 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStrea
 use crate::physical_plan::sorts::SortedStream;
 use crate::physical_plan::stream::RecordBatchReceiverStream;
 use crate::physical_plan::{
-    common, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan,
-    Partitioning, SendableRecordBatchStream, Statistics,
+    DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, Partitioning,
+    RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
 use crate::prelude::SessionConfig;
-use arrow::array::ArrayRef;
+use arrow::array::{make_array, Array, ArrayRef, MutableArrayData, UInt32Array};
 pub use arrow::compute::SortOptions;
-use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions};
+use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions};
 use arrow::datatypes::SchemaRef;
 use arrow::error::Result as ArrowResult;
 use arrow::ipc::reader::FileReader;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
 use futures::lock::Mutex;
-use futures::StreamExt;
+use futures::{Stream, StreamExt};
 use log::{debug, error};
 use std::any::Any;
+use std::cmp::min;
 use std::fmt;
 use std::fmt::{Debug, Formatter};
 use std::fs::File;
 use std::io::BufReader;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
+use std::task::{Context, Poll};
 use tempfile::NamedTempFile;
 use tokio::sync::mpsc::{Receiver, Sender};
 use tokio::task;
@@ -72,7 +74,7 @@ use tokio::task;
 struct ExternalSorter {
     id: MemoryConsumerId,
     schema: SchemaRef,
-    in_mem_batches: Mutex<Vec<RecordBatch>>,
+    in_mem_batches: Mutex<Vec<BatchWithSortArray>>,
     spills: Mutex<Vec<NamedTempFile>>,
     /// Sort expressions
     expr: Vec<PhysicalSortExpr>,
@@ -105,13 +107,21 @@ impl ExternalSorter {
         }
     }
 
-    async fn insert_batch(&self, input: RecordBatch) -> Result<()> {
+    async fn insert_batch(
+        &self,
+        input: RecordBatch,
+        tracking_metrics: &MemTrackingMetrics,
+    ) -> Result<()> {
         if input.num_rows() > 0 {
             let size = batch_byte_size(&input);
             self.try_grow(size).await?;
             self.metrics.mem_used().add(size);
             let mut in_mem_batches = self.in_mem_batches.lock().await;
-            in_mem_batches.push(input);
+            // NB timer records time taken on drop, so there are no
+            // calls to `timer.done()` below.
+            let _timer = tracking_metrics.elapsed_compute().timer();
+            let partial = sort_batch(input, self.schema.clone(), &self.expr)?;
+            in_mem_batches.push(partial);
         }
         Ok(())
     }
@@ -124,6 +134,7 @@ impl ExternalSorter {
     /// MergeSort in mem batches as well as spills into total order with `SortPreservingMergeStream`.
     async fn sort(&self) -> Result<SendableRecordBatchStream> {
         let partition = self.partition_id();
+        let batch_size = self.session_config.batch_size;
         let mut in_mem_batches = self.in_mem_batches.lock().await;
 
         if self.spilled_before().await {
@@ -136,6 +147,7 @@ impl ExternalSorter {
                     &mut *in_mem_batches,
                     self.schema.clone(),
                     &self.expr,
+                    batch_size,
                     tracking_metrics,
                 )?;
                 let prev_used = self.free_all_memory();
@@ -166,6 +178,7 @@ impl ExternalSorter {
                 &mut *in_mem_batches,
                 self.schema.clone(),
                 &self.expr,
+                batch_size,
                 tracking_metrics,
             );
             // Report to the memory manager we are no longer using memory
@@ -255,6 +268,7 @@ impl MemoryConsumer for ExternalSorter {
             &mut *in_mem_batches,
             self.schema.clone(),
             &*self.expr,
+            self.session_config.batch_size,
             tracking_metrics,
         );
 
@@ -274,36 +288,268 @@ impl MemoryConsumer for ExternalSorter {
 
 /// consume the non-empty `sorted_batches` and do in_mem_sort
 fn in_mem_partial_sort(
-    buffered_batches: &mut Vec<RecordBatch>,
+    buffered_batches: &mut Vec<BatchWithSortArray>,
     schema: SchemaRef,
     expressions: &[PhysicalSortExpr],
+    batch_size: usize,
     tracking_metrics: MemTrackingMetrics,
 ) -> Result<SendableRecordBatchStream> {
     assert_ne!(buffered_batches.len(), 0);
+    if buffered_batches.len() == 1 {
+        let result = buffered_batches.pop();
+        Ok(Box::pin(SizedRecordBatchStream::new(
+            schema,
+            vec![Arc::new(result.unwrap().sorted_batch)],
+            tracking_metrics,
+        )))
+    } else {
+        let (sorted_arrays, batches): (Vec<Vec<ArrayRef>>, Vec<RecordBatch>) =
+            buffered_batches
+                .drain(..)
+                .into_iter()
+                .map(|b| {
+                    let BatchWithSortArray {
+                        sort_arrays,
+                        sorted_batch: batch,
+                    } = b;
+                    (sort_arrays, batch)
+                })
+                .unzip();
+
+        let sorted_iter = {
+            // NB timer records time taken on drop, so there are no
+            // calls to `timer.done()` below.
+            let _timer = tracking_metrics.elapsed_compute().timer();
+            get_sorted_iter(&sorted_arrays, expressions, batch_size)?
+        };
+        Ok(Box::pin(SortedSizedRecordBatchStream::new(
+            schema,
+            batches,
+            sorted_iter,
+            tracking_metrics,
+        )))
+    }
+}
 
-    let result = {
-        // NB timer records time taken on drop, so there are no
-        // calls to `timer.done()` below.
-        let _timer = tracking_metrics.elapsed_compute().timer();
+#[derive(Debug, Copy, Clone)]
+struct CompositeIndex {
+    batch_idx: u32,
+    row_idx: u32,
+}
 
-        let pre_sort = if buffered_batches.len() == 1 {
-            buffered_batches.pop()
-        } else {
-            let batches = buffered_batches.drain(..).collect::<Vec<_>>();
-            // combine all record batches into one for each column
-            common::combine_batches(&batches, schema.clone())?
-        };
+/// Get a sorted iterator by sorting the concatenated `SortColumn`s
+fn get_sorted_iter(
+    sort_arrays: &[Vec<ArrayRef>],
+    expr: &[PhysicalSortExpr],
+    batch_size: usize,
+) -> Result<SortedIterator> {
+    let row_indices = sort_arrays
+        .iter()
+        .enumerate()
+        .flat_map(|(i, arrays)| {
+            (0..arrays[0].len()).map(move |r| CompositeIndex {
+                // since we use a UInt32Array to index the concatenated mono batch,
+                // indexes into the component record batches cannot overflow either;
+                // use u32 here for space efficiency.
+                batch_idx: i as u32,
+                row_idx: r as u32,
+            })
+        })
+        .collect::<Vec<CompositeIndex>>();
+
+    let sort_columns = expr
+        .iter()
+        .enumerate()
+        .map(|(i, expr)| {
+            let columns_i = sort_arrays
+                .iter()
+                .map(|cs| cs[i].as_ref())
+                .collect::<Vec<&dyn Array>>();
+            Ok(SortColumn {
+                values: concat(columns_i.as_slice())?,
+                options: Some(expr.options),
+            })
+        })
+        .collect::<Result<Vec<SortColumn>>>()?;
+    let indices = lexsort_to_indices(&sort_columns, None)?;
 
-        pre_sort
-            .map(|batch| sort_batch(batch, schema.clone(), expressions))
-            .transpose()?
-    };
+    Ok(SortedIterator::new(indices, row_indices, batch_size))
+}
 
-    Ok(Box::pin(SizedRecordBatchStream::new(
-        schema,
-        vec![Arc::new(result.unwrap())],
-        tracking_metrics,
-    )))
+struct SortedIterator {
+    /// Current logical position in the iterator
+    pos: usize,
+    /// Indexes into the input representing the correctly sorted total output
+    indices: UInt32Array,
+    /// Map each logical input index to where it can be found in the sorted input batches
+    composite: Vec<CompositeIndex>,
+    /// Maximum batch size to produce
+    batch_size: usize,
+    /// Total length of the iterator
+    length: usize,
+}
+
+impl SortedIterator {
+    fn new(
+        indices: UInt32Array,
+        composite: Vec<CompositeIndex>,
+        batch_size: usize,
+    ) -> Self {
+        let length = composite.len();
+        Self {
+            pos: 0,
+            indices,
+            composite,
+            batch_size,
+            length,
+        }
+    }
+
+    fn memory_size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + self.indices.get_array_memory_size()
+            + std::mem::size_of_val(&self.composite[..])
+    }
+}
+
+impl Iterator for SortedIterator {
+    type Item = Vec<CompositeSlice>;
+
+    /// Emit a max of `batch_size` positions each time
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.pos >= self.length {
+            return None;
+        }
+
+        let current_size = min(self.batch_size, self.length - self.pos);
+
+        // Combine adjacent indexes from the same batch to make a slice,
+        // for more efficient `extend` later.
+        let mut last_batch_idx = 0;
+        let mut start_row_idx = 0;
+        let mut len = 0;
+
+        let mut slices = vec![];
+        for i in 0..current_size {
+            let p = self.pos + i;
+            let c_index = self.indices.value(p) as usize;
+            let ci = self.composite[c_index];
+
+            if len == 0 {
+                last_batch_idx = ci.batch_idx;
+                start_row_idx = ci.row_idx;
+                len = 1;
+            } else if ci.batch_idx == last_batch_idx {
+                len += 1;
+                // since each incoming batch is pre-sorted, if we witness an
+                // out-of-order index within the same batch, its row must share
+                // the sort key with the row pointed to by start_row_idx.
+                start_row_idx = min(start_row_idx, ci.row_idx);
+            } else {
+                slices.push(CompositeSlice {
+                    batch_idx: last_batch_idx,
+                    start_row_idx,
+                    len,
+                });
+                last_batch_idx = ci.batch_idx;
+                start_row_idx = ci.row_idx;
+                len = 1;
+            }
+        }
+
+        assert!(
+            len > 0,
+            "There should be at least one record in a sort output slice."
+        );
+        slices.push(CompositeSlice {
+            batch_idx: last_batch_idx,
+            start_row_idx,
+            len,
+        });
+
+        self.pos += current_size;
+        Some(slices)
+    }
+}
+
+/// Stream of sorted record batches
+struct SortedSizedRecordBatchStream {
+    schema: SchemaRef,
+    batches: Vec<RecordBatch>,
+    sorted_iter: SortedIterator,
+    num_cols: usize,
+    metrics: MemTrackingMetrics,
+}
+
+impl SortedSizedRecordBatchStream {
+    /// Create a new `SortedSizedRecordBatchStream`
+    pub fn new(
+        schema: SchemaRef,
+        batches: Vec<RecordBatch>,
+        sorted_iter: SortedIterator,
+        metrics: MemTrackingMetrics,
+    ) -> Self {
+        let size = batches.iter().map(batch_byte_size).sum::<usize>()
+            + sorted_iter.memory_size();
+        metrics.init_mem_used(size);
+        let num_cols = batches[0].num_columns();
+        SortedSizedRecordBatchStream {
+            schema,
+            batches,
+            sorted_iter,
+            num_cols,
+            metrics,
+        }
+    }
+}
+
+impl Stream for SortedSizedRecordBatchStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        _: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        match self.sorted_iter.next() {
+            None => Poll::Ready(None),
+            Some(slices) => {
+                let num_rows = slices.iter().map(|s| s.len).sum();
+                let output = (0..self.num_cols)
+                    .map(|i| {
+                        let arrays = self
+                            .batches
+                            .iter()
+                            .map(|b| b.column(i).data())
+                            .collect::<Vec<_>>();
+                        let mut mutable = MutableArrayData::new(arrays, false, num_rows);
+                        for x in slices.iter() {
+                            mutable.extend(
+                                x.batch_idx as usize,
+                                x.start_row_idx as usize,
+                                x.start_row_idx as usize + x.len,
+                            );
+                        }
+                        make_array(mutable.freeze())
+                    })
+                    .collect::<Vec<_>>();
+                let batch = RecordBatch::try_new(self.schema.clone(), output);
+                let poll = Poll::Ready(Some(batch));
+                self.metrics.record_poll(poll)
+            }
+        }
+    }
+}
+
+struct CompositeSlice {
+    batch_idx: u32,
+    start_row_idx: u32,
+    len: usize,
+}
+
+impl RecordBatchStream for SortedSizedRecordBatchStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
 }
 
 async fn spill_partial_sorted_stream(
@@ -539,22 +785,26 @@ impl ExecutionPlan for SortExec {
     }
 }
 
+struct BatchWithSortArray {
+    sort_arrays: Vec<ArrayRef>,
+    sorted_batch: RecordBatch,
+}
+
 fn sort_batch(
     batch: RecordBatch,
     schema: SchemaRef,
     expr: &[PhysicalSortExpr],
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<BatchWithSortArray> {
     // TODO: pushup the limit expression to sort
-    let indices = lexsort_to_indices(
-        &expr
-            .iter()
-            .map(|e| e.evaluate_to_sort_column(&batch))
-            .collect::<Result<Vec<SortColumn>>>()?,
-        None,
-    )?;
+    let sort_columns = expr
+        .iter()
+        .map(|e| e.evaluate_to_sort_column(&batch))
+        .collect::<Result<Vec<SortColumn>>>()?;
+
+    let indices = lexsort_to_indices(&sort_columns, None)?;
 
     // reorder all rows based on sorted indices
-    RecordBatch::try_new(
+    let sorted_batch = RecordBatch::try_new(
         schema,
         batch
            .columns()
@@ -571,7 +821,25 @@ fn sort_batch(
                 )
             })
             .collect::<ArrowResult<Vec<ArrayRef>>>()?,
-    )
+    )?;
+
+    let sort_arrays = sort_columns
+        .into_iter()
+        .map(|sc| {
+            Ok(take(
+                sc.values.as_ref(),
+                &indices,
+                Some(TakeOptions {
+                    check_bounds: false,
+                }),
+            )?)
+        })
+        .collect::<Result<Vec<ArrayRef>>>()?;
+
+    Ok(BatchWithSortArray {
+        sort_arrays,
+        sorted_batch,
+    })
 }
 
 async fn do_sort(
@@ -582,6 +850,8 @@ async fn do_sort(
     context: Arc<TaskContext>,
 ) -> Result<SendableRecordBatchStream> {
     let schema = input.schema();
+    let tracking_metrics =
+        metrics_set.new_intermediate_tracking(partition_id, context.runtime_env());
     let sorter = ExternalSorter::new(
         partition_id,
         schema.clone(),
@@ -593,7 +863,7 @@ async fn do_sort(
     context.runtime_env().register_requester(sorter.id());
     while let Some(batch) = input.next().await {
         let batch = batch?;
-        sorter.insert_batch(batch).await?;
+        sorter.insert_batch(batch, &tracking_metrics).await?;
     }
     sorter.sort().await
 }
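Reviewer note: the heart of this change is `SortedIterator::next`, which turns the globally sorted sequence of (batch, row) positions into runs of contiguous rows, so `SortedSizedRecordBatchStream` can gather each output column with one `MutableArrayData::extend` per run instead of one copy per row. Below is a minimal, self-contained sketch of that coalescing step; `Pos` and `Slice` are hypothetical stand-ins for the patch's `CompositeIndex` and `CompositeSlice`, and the code is illustrative rather than a copy of the implementation.

```rust
#[derive(Copy, Clone, Debug)]
struct Pos {
    batch: u32, // which pre-sorted input batch the row lives in
    row: u32,   // row offset within that batch
}

#[derive(Debug)]
struct Slice {
    batch: u32,
    start: u32,
    len: usize,
}

/// Collapse globally sorted positions into per-batch contiguous slices.
/// Adjacent positions from the same batch merge into one slice; because each
/// input batch is itself sorted, a "backwards" row index inside a run can only
/// come from rows with equal sort keys, so taking the minimum start is safe
/// (this mirrors the `start_row_idx = min(...)` branch in the patch).
fn coalesce(sorted: &[Pos]) -> Vec<Slice> {
    let mut out = Vec::new();
    let mut it = sorted.iter().copied();
    let first = match it.next() {
        Some(p) => p,
        None => return out,
    };
    let (mut batch, mut start, mut len) = (first.batch, first.row, 1usize);
    for p in it {
        if p.batch == batch {
            start = start.min(p.row);
            len += 1;
        } else {
            out.push(Slice { batch, start, len });
            batch = p.batch;
            start = p.row;
            len = 1;
        }
    }
    out.push(Slice { batch, start, len });
    out
}

fn main() {
    // rows 0-1 of batch 0, rows 3-4 of batch 1, then row 2 of batch 0 again
    let sorted = [
        Pos { batch: 0, row: 0 },
        Pos { batch: 0, row: 1 },
        Pos { batch: 1, row: 3 },
        Pos { batch: 1, row: 4 },
        Pos { batch: 0, row: 2 },
    ];
    // -> [Slice { batch: 0, start: 0, len: 2 },
    //     Slice { batch: 1, start: 3, len: 2 },
    //     Slice { batch: 0, start: 2, len: 1 }]
    println!("{:?}", coalesce(&sorted));
}
```

Pre-sorting each batch in `insert_batch` is what makes the slicing valid: rows of one batch can then leave the merge only in ascending row order (ties aside), so an adjacent same-batch run is always a contiguous slice, and the per-column gather approaches a plain memcpy as the runs get longer.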