diff --git a/analytic_engine/src/sst/parquet/builder.rs b/analytic_engine/src/sst/parquet/builder.rs
index bf12064b0e..ece3d1ee66 100644
--- a/analytic_engine/src/sst/parquet/builder.rs
+++ b/analytic_engine/src/sst/parquet/builder.rs
@@ -2,9 +2,12 @@
 
 //! Sst builder implementation based on parquet.
 
-use std::sync::{
-    atomic::{AtomicUsize, Ordering},
-    Arc,
+use std::{
+    collections::VecDeque,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
 };
 
 use async_trait::async_trait;
@@ -55,17 +58,19 @@ struct RecordBytesReader {
     compression: Compression,
     meta_data: SstMetaData,
     total_row_num: Arc<AtomicUsize>,
-    // Record batch partitioned by given `num_rows_per_row_group`
+    // Record batch partitioned by exactly the given `num_rows_per_row_group`
     // There may be more than one `RecordBatchWithKey` inside each partition
     partitioned_record_batch: Vec<Vec<RecordBatchWithKey>>,
 }
 
 impl RecordBytesReader {
-    // Partition record batch stream into batch vector with given
+    // Partition record batch stream into batch vector with exactly the given
     // `num_rows_per_row_group`
     async fn partition_record_batch(&mut self) -> Result<()> {
         let mut fetched_row_num = 0;
-        let mut pending_record_batch: Vec<RecordBatchWithKey> = Default::default();
+        let mut pending_record_batch: VecDeque<RecordBatchWithKey> = Default::default();
+        let mut current_batch = Vec::new();
+        let mut remaining = self.num_rows_per_row_group; // how many records are left for current_batch
 
         while let Some(record_batch) = self.record_stream.next().await {
             let record_batch = record_batch.context(PollRecordBatch)?;
@@ -77,18 +82,50 @@ impl RecordBytesReader {
             );
 
             fetched_row_num += record_batch.num_rows();
-            pending_record_batch.push(record_batch);
+            pending_record_batch.push_back(record_batch);
 
             // reach batch limit, append to self and reset counter and pending batch
-            if fetched_row_num >= self.num_rows_per_row_group {
-                fetched_row_num = 0;
-                self.partitioned_record_batch
-                    .push(std::mem::take(&mut pending_record_batch));
-            }
+            // Note: pending_record_batch may contain multiple batches
+            while fetched_row_num >= self.num_rows_per_row_group {
+                match pending_record_batch.pop_front() {
+                    // accumulated records are enough for one batch
+                    Some(next) if next.num_rows() >= remaining => {
+                        current_batch.push(next.slice(0, remaining));
+                        pending_record_batch
+                            .push_front(next.slice(remaining, next.num_rows() - remaining));
+
+                        self.partitioned_record_batch
+                            .push(std::mem::take(&mut current_batch));
+                        fetched_row_num -= remaining;
+                        remaining = self.num_rows_per_row_group;
+                    }
+                    // not enough for one batch
+                    Some(next) => {
+                        remaining -= next.num_rows();
+                        fetched_row_num -= next.num_rows();
+
+                        current_batch.push(next);
+                    }
+                    // nothing left, put back to pending_record_batch
+                    _ => {
+                        for records in std::mem::take(&mut current_batch) {
+                            fetched_row_num += records.num_rows();
+                            pending_record_batch.push_front(records);
+                        }
+
+                        break;
+                    }
+                }
+            }
         }
 
-        if !pending_record_batch.is_empty() {
-            self.partitioned_record_batch.push(pending_record_batch);
+        // collect remaining records into one batch
+        let mut remaining = Vec::with_capacity(pending_record_batch.len());
+        while let Some(batch) = pending_record_batch.pop_front() {
+            remaining.push(batch);
+        }
+        if !remaining.is_empty() {
+            self.partitioned_record_batch.push(remaining);
         }
 
         Ok(())
@@ -206,7 +243,10 @@ mod tests {
         tests::{build_row, build_schema},
         time::{TimeRange, Timestamp},
     };
-    use common_util::runtime::{self, Runtime};
+    use common_util::{
+        runtime::{self, Runtime},
+        tests::init_log_for_test,
+    };
     use futures::stream;
     use object_store::LocalFileSystem;
     use table_engine::predicate::Predicate;
@@ -387,4 +427,76 @@ mod tests {
             check_stream(&mut stream, expect_rows).await;
         });
     }
+
+    #[tokio::test]
+    async fn test_partition_record_batch() {
+        // row group size: 10
+        let testcases = vec![
+            // input, expected
+            (vec![10, 10], vec![10, 10]),
+            (vec![10, 10, 1], vec![10, 10, 1]),
+            (vec![10, 10, 21], vec![10, 10, 10, 10, 1]),
+            (vec![5, 6, 10], vec![10, 10, 1]),
+            (vec![5, 4, 4, 30], vec![10, 10, 10, 10, 3]),
+        ];
+
+        for (input, expected) in testcases {
+            test_partition_record_batch_inner(input, expected).await;
+        }
+    }
+
+    async fn test_partition_record_batch_inner(
+        input_row_nums: Vec<usize>,
+        expected_row_nums: Vec<usize>,
+    ) {
+        init_log_for_test();
+
+        let schema = build_schema();
+        let mut poll_cnt = 0;
+        let schema_clone = schema.clone();
+        let record_batch_stream = Box::new(stream::poll_fn(move |_ctx| -> Poll<Option<Result<RecordBatchWithKey>>> {
+            if poll_cnt == input_row_nums.len() {
+                return Poll::Ready(None);
+            }
+
+            let rows = (0..input_row_nums[poll_cnt])
+                .map(|_| build_row(b"a", 100, 10.0, "v4"))
+                .collect::<Vec<_>>();
+
+            let batch = build_record_batch_with_key(schema_clone.clone(), rows);
+            let ret = Poll::Ready(Some(Ok(batch)));
+            poll_cnt += 1;
+
+            ret
+        }));
+
+        let mut reader = RecordBytesReader {
+            request_id: RequestId::next_id(),
+            record_stream: record_batch_stream,
+            num_rows_per_row_group: 10,
+            compression: Compression::UNCOMPRESSED,
+            meta_data: SstMetaData {
+                min_key: Default::default(),
+                max_key: Default::default(),
+                time_range: Default::default(),
+                max_sequence: 1,
+                schema,
+                size: 0,
+                row_num: 0,
+                storage_format_opts: Default::default(),
+                bloom_filter: Default::default(),
+            },
+            total_row_num: Arc::new(AtomicUsize::new(0)),
+            partitioned_record_batch: Vec::new(),
+        };
+
+        reader.partition_record_batch().await.unwrap();
+
+        for (i, expected_row_num) in expected_row_nums.into_iter().enumerate() {
+            let actual: usize = reader.partitioned_record_batch[i]
+                .iter()
+                .map(|b| b.num_rows())
+                .sum();
+            assert_eq!(expected_row_num, actual);
+        }
+    }
 }