
Commit

fix partition record batch
jiacai2050 committed Nov 8, 2022
1 parent 9badf7e commit 851f98e
Showing 1 changed file with 126 additions and 14 deletions.
140 changes: 126 additions & 14 deletions analytic_engine/src/sst/parquet/builder.rs
@@ -2,9 +2,12 @@
 
 //! Sst builder implementation based on parquet.
 
-use std::sync::{
-    atomic::{AtomicUsize, Ordering},
-    Arc,
+use std::{
+    collections::VecDeque,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
 };
 
 use async_trait::async_trait;
@@ -55,17 +58,19 @@ struct RecordBytesReader {
     compression: Compression,
     meta_data: SstMetaData,
     total_row_num: Arc<AtomicUsize>,
-    // Record batch partitioned by given `num_rows_per_row_group`
+    // Record batch partitioned by exactly the given `num_rows_per_row_group`.
+    // There may be more than one `RecordBatchWithKey` inside each partition.
    partitioned_record_batch: Vec<Vec<RecordBatchWithKey>>,
 }
 
 impl RecordBytesReader {
-    // Partition record batch stream into batch vector with given
+    // Partition the record batch stream into batch vectors of exactly the given
     // `num_rows_per_row_group`
     async fn partition_record_batch(&mut self) -> Result<()> {
         let mut fetched_row_num = 0;
-        let mut pending_record_batch: Vec<RecordBatchWithKey> = Default::default();
+        let mut pending_record_batch: VecDeque<RecordBatchWithKey> = Default::default();
+        let mut current_batch = Vec::new();
+        let mut remaining = self.num_rows_per_row_group; // rows still needed to fill `current_batch`
 
         while let Some(record_batch) = self.record_stream.next().await {
             let record_batch = record_batch.context(PollRecordBatch)?;
@@ -77,18 +82,50 @@ impl RecordBytesReader {
             );
 
             fetched_row_num += record_batch.num_rows();
-            pending_record_batch.push(record_batch);
+            pending_record_batch.push_back(record_batch);
 
-            // reach batch limit, append to self and reset counter and pending batch
-            if fetched_row_num >= self.num_rows_per_row_group {
-                fetched_row_num = 0;
-                self.partitioned_record_batch
-                    .push(std::mem::take(&mut pending_record_batch));
+            // Note: `pending_record_batch` may contain multiple batches.
+            while fetched_row_num >= self.num_rows_per_row_group {
+                match pending_record_batch.pop_front() {
+                    // The accumulated rows are enough to fill one row group.
+                    Some(next) if next.num_rows() >= remaining => {
+                        current_batch.push(next.slice(0, remaining));
+                        pending_record_batch
+                            .push_front(next.slice(remaining, next.num_rows() - remaining));
+
+                        self.partitioned_record_batch
+                            .push(std::mem::take(&mut current_batch));
+                        fetched_row_num -= remaining;
+                        remaining = self.num_rows_per_row_group;
+                    }
+                    // Not enough rows to fill one row group yet.
+                    Some(next) => {
+                        remaining -= next.num_rows();
+                        fetched_row_num -= next.num_rows();
+
+                        current_batch.push(next);
+                    }
+                    // Nothing left in the queue; put the partial batch back into
+                    // `pending_record_batch`.
+                    _ => {
+                        for records in std::mem::take(&mut current_batch) {
+                            fetched_row_num += records.num_rows();
+                            pending_record_batch.push_front(records);
+                        }
+
+                        break;
+                    }
+                }
             }
         }
 
-        if !pending_record_batch.is_empty() {
-            self.partitioned_record_batch.push(pending_record_batch);
+        // Collect the remaining records into one final, smaller batch.
+        let mut remaining = Vec::with_capacity(pending_record_batch.len());
+        while let Some(batch) = pending_record_batch.pop_front() {
+            remaining.push(batch);
+        }
+        if !remaining.is_empty() {
+            self.partitioned_record_batch.push(remaining);
         }
 
         Ok(())
@@ -206,7 +243,10 @@ mod tests {
         tests::{build_row, build_schema},
         time::{TimeRange, Timestamp},
     };
-    use common_util::runtime::{self, Runtime};
+    use common_util::{
+        runtime::{self, Runtime},
+        tests::init_log_for_test,
+    };
     use futures::stream;
     use object_store::LocalFileSystem;
     use table_engine::predicate::Predicate;
@@ -387,4 +427,76 @@ mod tests {
             check_stream(&mut stream, expect_rows).await;
         });
     }
+
+    #[tokio::test]
+    async fn test_partition_record_batch() {
+        // row group size: 10
+        let testcases = vec![
+            // input, expected
+            (vec![10, 10], vec![10, 10]),
+            (vec![10, 10, 1], vec![10, 10, 1]),
+            (vec![10, 10, 21], vec![10, 10, 10, 10, 1]),
+            (vec![5, 6, 10], vec![10, 10, 1]),
+            (vec![5, 4, 4, 30], vec![10, 10, 10, 10, 3]),
+        ];
+
+        for (input, expected) in testcases {
+            test_partition_record_batch_inner(input, expected).await;
+        }
+    }
+
+    async fn test_partition_record_batch_inner(
+        input_row_nums: Vec<usize>,
+        expected_row_nums: Vec<usize>,
+    ) {
+        init_log_for_test();
+        let schema = build_schema();
+        let mut poll_cnt = 0;
+        let schema_clone = schema.clone();
+        let record_batch_stream = Box::new(stream::poll_fn(move |_ctx| -> Poll<Option<_>> {
+            if poll_cnt == input_row_nums.len() {
+                return Poll::Ready(None);
+            }
+
+            let rows = (0..input_row_nums[poll_cnt])
+                .map(|_| build_row(b"a", 100, 10.0, "v4"))
+                .collect::<Vec<_>>();
+
+            let batch = build_record_batch_with_key(schema_clone.clone(), rows);
+            let ret = Poll::Ready(Some(Ok(batch)));
+            poll_cnt += 1;
+
+            ret
+        }));
+
+        let mut reader = RecordBytesReader {
+            request_id: RequestId::next_id(),
+            record_stream: record_batch_stream,
+            num_rows_per_row_group: 10,
+            compression: Compression::UNCOMPRESSED,
+            meta_data: SstMetaData {
+                min_key: Default::default(),
+                max_key: Default::default(),
+                time_range: Default::default(),
+                max_sequence: 1,
+                schema,
+                size: 0,
+                row_num: 0,
+                storage_format_opts: Default::default(),
+                bloom_filter: Default::default(),
+            },
+            total_row_num: Arc::new(AtomicUsize::new(0)),
+            partitioned_record_batch: Vec::new(),
+        };
+
+        reader.partition_record_batch().await.unwrap();
+
+        for (i, expected_row_num) in expected_row_nums.into_iter().enumerate() {
+            let actual: usize = reader.partitioned_record_batch[i]
+                .iter()
+                .map(|b| b.num_rows())
+                .sum();
+            assert_eq!(expected_row_num, actual);
+        }
+    }
 }
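
For readers who want the algorithm in isolation: the fix buffers incoming batches in a VecDeque and, whenever at least `num_rows_per_row_group` rows have accumulated, slices off exactly one row group, pushing any leftover tail of the front batch back onto the queue. Below is a minimal standalone sketch of that buffer-and-slice technique. It is not taken from the commit: the loop is restructured slightly for brevity, and `Batch` is a hypothetical stand-in for `RecordBatchWithKey`.

use std::collections::VecDeque;

// Hypothetical stand-in for `RecordBatchWithKey`: just a run of row ids.
struct Batch(Vec<u32>);

impl Batch {
    fn num_rows(&self) -> usize {
        self.0.len()
    }

    // `RecordBatchWithKey::slice` is zero-copy in Arrow; a plain copy is
    // fine for this sketch.
    fn slice(&self, offset: usize, len: usize) -> Batch {
        Batch(self.0[offset..offset + len].to_vec())
    }
}

// Split variable-sized batches into groups holding exactly `rows_per_group`
// rows each; only the final group may be smaller.
fn partition(input: Vec<Batch>, rows_per_group: usize) -> Vec<Vec<Batch>> {
    let mut partitioned = Vec::new();
    let mut pending: VecDeque<Batch> = VecDeque::new();
    let mut fetched = 0; // rows currently buffered in `pending`

    for batch in input {
        fetched += batch.num_rows();
        pending.push_back(batch);

        // Drain one full row group at a time while enough rows are buffered.
        while fetched >= rows_per_group {
            let mut group = Vec::new();
            let mut needed = rows_per_group;
            while needed > 0 {
                let next = pending.pop_front().expect("fetched >= rows_per_group");
                if next.num_rows() > needed {
                    // Take the front `needed` rows and push the tail back.
                    group.push(next.slice(0, needed));
                    let tail = next.slice(needed, next.num_rows() - needed);
                    pending.push_front(tail);
                    needed = 0;
                } else {
                    needed -= next.num_rows();
                    group.push(next);
                }
            }
            fetched -= rows_per_group;
            partitioned.push(group);
        }
    }

    // Whatever is left becomes one final, smaller group.
    if !pending.is_empty() {
        partitioned.push(pending.into_iter().collect());
    }
    partitioned
}

fn main() {
    // Mirrors one case from the new test: inputs of 5, 4, 4 and 30 rows with
    // row group size 10 partition into groups of 10, 10, 10, 10 and 3 rows.
    let input = [5u32, 4, 4, 30]
        .iter()
        .map(|&n| Batch((0..n).collect()))
        .collect();
    let sizes: Vec<usize> = partition(input, 10)
        .iter()
        .map(|group| group.iter().map(Batch::num_rows).sum())
        .collect();
    assert_eq!(sizes, vec![10, 10, 10, 10, 3]);
}

With a row group size of 10, input batches of 5, 4, 4 and 30 rows come out as groups of 10, 10, 10, 10 and 3 rows, matching the expectations in the new test.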
