Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(stream,over-window): implement AggregateState to allow aggregate function over window #9248

Merged
merged 16 commits into from
Apr 21, 2023
27 changes: 22 additions & 5 deletions src/expr/src/function/window/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,28 @@ use crate::function::aggregate::AggArgs;

#[derive(Clone)]
pub enum Frame {
/// Frame by row offset, for `lag` and `lead`.
Offset(isize),
// Rows(Bound<usize>, Bound<usize>),
// Groups(Bound<usize>, Bound<usize>),
// Range(Bound<ScalarImpl>, Bound<ScalarImpl>),
Rows(FrameBound<usize>, FrameBound<usize>),
// Groups(FrameBound<usize>, FrameBound<usize>),
// Range(FrameBound<ScalarImpl>, FrameBound<ScalarImpl>),
}

#[derive(Clone)]
pub enum FrameBound<T> {
Unbounded,
CurrentRow,
Preceding(T),
Following(T),
}

impl FrameBound<usize> {
pub fn to_offset(&self) -> Option<isize> {
match self {
FrameBound::Unbounded => None,
FrameBound::CurrentRow => Some(0),
FrameBound::Preceding(n) => Some(-(*n as isize)),
FrameBound::Following(n) => Some(*n as isize),
}
}
}

#[derive(Clone)]
Expand Down
6 changes: 4 additions & 2 deletions src/expr/src/function/window/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

use parse_display::{Display, FromStr};

use crate::function::aggregate::AggKind;

/// Kind of window functions.
#[derive(Debug, Display, FromStr, Copy, Clone, PartialEq, Eq, Hash)]
#[display(style = "snake_case")]
Expand All @@ -26,6 +28,6 @@ pub enum WindowFuncKind {
// NthValue,

// Aggregate functions that are used with `OVER`.
// #[display("{0}")]
// Aggregate(AggKind),
#[display("{0}")]
Aggregate(AggKind),
}
64 changes: 48 additions & 16 deletions src/stream/src/executor/over_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::marker::PhantomData;

use futures::StreamExt;
use futures_async_stream::{for_await, try_stream};
use itertools::Itertools;
use risingwave_common::array::column::Column;
use risingwave_common::array::stream_record::Record;
use risingwave_common::array::{Op, StreamChunk};
Expand Down Expand Up @@ -221,7 +220,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
return Ok(());
}

let mut partition = Partition::new(&this.calls);
let mut partition = Partition::new(&this.calls)?;

// Recover states from state table.
let table_iter = this
Expand All @@ -246,7 +245,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
.input_pk_indices
.iter()
.map(|idx| this.col_mapping.upstream_to_state_table(*idx).unwrap())
.collect_vec(),
.collect::<Vec<_>>(),
),
&vec![OrderType::ascending(); this.input_pk_indices.len()],
)?
Expand All @@ -265,7 +264,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
.val_indices()
.iter()
.map(|idx| this.col_mapping.upstream_to_state_table(*idx).unwrap())
.collect_vec(),
.collect::<Vec<_>>(),
)
.into_owned_row()
.into_inner()
Expand All @@ -279,9 +278,9 @@ impl<S: StateStore> OverWindowExecutor<S> {

// Ignore ready windows (all ready windows were outputted before).
while partition.is_ready() {
partition.states.iter_mut().for_each(|state| {
state.slide();
});
for state in &mut partition.states {
state.output()?;
}
}

cache.put(encoded_partition_key.clone(), partition);
Expand Down Expand Up @@ -363,7 +362,9 @@ impl<S: StateStore> OverWindowExecutor<S> {
let (ret_values, evict_hints): (Vec<_>, Vec<_>) = partition
.states
.iter_mut()
.map(|state| state.slide())
.map(|state| state.output())
.try_collect::<Vec<_>>()?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of returning an error early here when .map(|o| (o.return_value, o.evict_hint)) has no side effects?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because batch aggregator will possibly throw an error

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, but the collect is unnecessary if we have rust-itertools/itertools#402

We can resolve it later :)

.into_iter()
.map(|o| (o.return_value, o.evict_hint))
.unzip();

Expand Down Expand Up @@ -402,10 +403,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
}
}

let columns: Vec<Column> = builders
.into_iter()
.map(|b| b.finish().into())
.collect_vec();
let columns: Vec<Column> = builders.into_iter().map(|b| b.finish().into()).collect();
let chunk_size = columns[0].len();
Ok(if chunk_size > 0 {
Some(StreamChunk::new(
Expand Down Expand Up @@ -501,8 +499,8 @@ mod tests {
use risingwave_common::test_prelude::StreamChunkTestExt;
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::function::aggregate::AggArgs;
use risingwave_expr::function::window::{Frame, WindowFuncCall, WindowFuncKind};
use risingwave_expr::function::aggregate::{AggArgs, AggKind};
use risingwave_expr::function::window::{Frame, FrameBound, WindowFuncCall, WindowFuncKind};
use risingwave_storage::memory::MemoryStateStore;
use risingwave_storage::StateStore;

Expand Down Expand Up @@ -575,13 +573,13 @@ mod tests {
kind: WindowFuncKind::Lag,
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int32,
frame: Frame::Offset(-1),
frame: Frame::Rows(FrameBound::Preceding(1), FrameBound::CurrentRow),
},
WindowFuncCall {
kind: WindowFuncKind::Lead,
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int32,
frame: Frame::Offset(1),
frame: Frame::Rows(FrameBound::CurrentRow, FrameBound::Following(1)),
},
];

Expand Down Expand Up @@ -655,4 +653,38 @@ mod tests {
over_window.expect_barrier().await;
}
}

#[tokio::test]
async fn test_over_window_aggregate() {
let store = MemoryStateStore::new();
let calls = vec![WindowFuncCall {
kind: WindowFuncKind::Aggregate(AggKind::Sum),
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int64,
frame: Frame::Rows(FrameBound::Preceding(1), FrameBound::Following(1)),
}];

let (mut tx, mut over_window) = create_executor(calls.clone(), store.clone()).await;

tx.push_barrier(1, false);
over_window.expect_barrier().await;

tx.push_chunk(StreamChunk::from_pretty(
" I T I i
+ 1 p1 100 10
+ 1 p1 101 16
+ 4 p1 102 20",
));
assert_eq!(1, over_window.expect_watermark().await.val.into_int64());
let chunk = over_window.expect_chunk().await;
println!("{}", chunk.to_pretty_string());
assert_eq!(
chunk,
StreamChunk::from_pretty(
" T I I I
+ p1 1 100 26
+ p1 1 101 46"
)
);
}
}
7 changes: 4 additions & 3 deletions src/stream/src/executor/over_window/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ use itertools::Itertools;
use risingwave_expr::function::window::WindowFuncCall;

use super::state::{create_window_state, StateKey, WindowState};
use crate::executor::StreamExecutorResult;

pub(super) struct Partition {
pub states: Vec<Box<dyn WindowState + Send>>,
}

impl Partition {
pub fn new(calls: &[WindowFuncCall]) -> Self {
let states = calls.iter().map(create_window_state).collect();
Self { states }
pub fn new(calls: &[WindowFuncCall]) -> StreamExecutorResult<Self> {
let states = calls.iter().map(create_window_state).try_collect()?;
Ok(Self { states })
}

pub fn is_aligned(&self) -> bool {
Expand Down
148 changes: 148 additions & 0 deletions src/stream/src/executor/over_window/state/aggregate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeSet;

use futures::FutureExt;
use risingwave_common::array::{DataChunk, Vis};
use risingwave_common::must_match;
use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::function::aggregate::{AggArgs, AggCall};
use risingwave_expr::function::window::{WindowFuncCall, WindowFuncKind};
use risingwave_expr::vector_op::agg::AggStateFactory;
use smallvec::SmallVec;

use super::buffer::WindowBuffer;
use super::{StateEvictHint, StateKey, StateOutput, StatePos, WindowState};
use crate::executor::StreamExecutorResult;

pub(super) struct AggregateState {
factory: AggStateFactory,
arg_data_types: Vec<DataType>,
buffer: WindowBuffer<StateKey, SmallVec<[Datum; 2]>>,
}

impl AggregateState {
pub fn new(call: &WindowFuncCall) -> StreamExecutorResult<Self> {
let agg_kind = must_match!(call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind);
let arg_data_types = call.args.arg_types().to_vec();
let agg_call = AggCall {
kind: agg_kind,
args: match &call.args {
// convert args to [0] or [0, 1]
AggArgs::None => AggArgs::None,
AggArgs::Unary(data_type, _) => AggArgs::Unary(data_type.to_owned(), 0),
AggArgs::Binary(data_types, _) => AggArgs::Binary(data_types.to_owned(), [0, 1]),
},
return_type: call.return_type.clone(),
column_orders: Vec::new(), // the input is already sorted
// TODO(rc): support filter on window function call
filter: None,
// TODO(rc): support distinct on window function call? PG doesn't support it either.
distinct: false,
};
Ok(Self {
factory: AggStateFactory::new(agg_call)?,
arg_data_types,
buffer: WindowBuffer::new(call.frame.clone()),
})
}
}

impl WindowState for AggregateState {
fn append(&mut self, key: StateKey, args: SmallVec<[Datum; 2]>) {
self.buffer.append(key, args);
}

fn curr_window(&self) -> StatePos<'_> {
let window = self.buffer.curr_window();
StatePos {
key: window.key,
is_ready: window.is_ready(),
}
}

fn output(&mut self) -> StreamExecutorResult<StateOutput> {
assert!(self.buffer.curr_window().is_ready());
let wrapper = BatchAggregatorWrapper {
factory: &self.factory,
arg_data_types: &self.arg_data_types,
};
let return_value =
wrapper.aggregate(self.buffer.curr_window_values().map(SmallVec::as_slice))?;
let removed_keys: BTreeSet<_> = self.buffer.slide().collect();
Ok(StateOutput {
return_value,
evict_hint: if removed_keys.is_empty() {
StateEvictHint::CannotEvict(
self.buffer
.curr_window_left()
.expect("sliding without removing, must have window left")
.0
.clone(),
)
} else {
StateEvictHint::CanEvict(removed_keys)
},
})
}
}

struct BatchAggregatorWrapper<'a> {
factory: &'a AggStateFactory,
arg_data_types: &'a [DataType],
}

impl BatchAggregatorWrapper<'_> {
fn aggregate<'a>(
&'a self,
values: impl ExactSizeIterator<Item = &'a [Datum]>,
) -> StreamExecutorResult<Datum> {
// TODO(rc): switch to a better general version of aggregator implementation

let n_values = values.len();

let mut args_builders = self
.arg_data_types
.iter()
.map(|data_type| data_type.create_array_builder(n_values))
.collect::<Vec<_>>();
for value in values {
for (builder, datum) in args_builders.iter_mut().zip_eq_fast(value.iter()) {
builder.append_datum(datum);
}
}
let columns = args_builders
.into_iter()
.map(|builder| builder.finish().into())
.collect::<Vec<_>>();
let chunk = DataChunk::new(columns, Vis::Compact(n_values));

let mut aggregator = self.factory.create_agg_state();
aggregator
.update_multi(&chunk, 0, n_values)
.now_or_never()
.expect("we don't support UDAF currently, so the function should return immediately")?;

let mut ret_value_builder = aggregator.return_type().create_array_builder(1);
aggregator.output(&mut ret_value_builder)?;
Ok(ret_value_builder.finish().to_datum())
}
}

#[cfg(test)]
mod tests {
// TODO(rc): need to add some unit tests
}