Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(stream,over-window): implement AggregateState to allow aggregate function over window #9248

Merged
merged 16 commits into from
Apr 21, 2023
82 changes: 76 additions & 6 deletions src/expr/src/function/window/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,88 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::Ordering;

use risingwave_common::types::DataType;

use super::WindowFuncKind;
use crate::function::aggregate::AggArgs;

#[derive(Clone)]
#[derive(Clone, Eq, PartialEq, Hash)]
/// Frame specification attached to a window function call.
// NOTE(review): this view interleaves removed and added diff lines. `Offset` and the
// `Bound`-based commented variants look like the pre-change side, while the
// `FrameBound`-based lines look like the post-change side — confirm against the merged file.
pub enum Frame {
/// Frame by row offset, for `lag` and `lead`.
Offset(isize),
// Rows(Bound<usize>, Bound<usize>),
// Groups(Bound<usize>, Bound<usize>),
// Range(Bound<ScalarImpl>, Bound<ScalarImpl>),
/// Frame described by a start bound and an end bound (in that order), each a
/// [`FrameBound`] carrying a `usize` offset.
Rows(FrameBound<usize>, FrameBound<usize>),
// Groups(FrameBound<usize>, FrameBound<usize>),
// Range(FrameBound<ScalarImpl>, FrameBound<ScalarImpl>),
}

impl Frame {
    /// Returns `true` when the start bound compares less than or equal to the end bound.
    ///
    /// Bounds for which `partial_cmp` yields `None` are reported as invalid.
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Rows(lower, upper) => lower
                .partial_cmp(upper)
                .map_or(false, |ordering| ordering.is_le()),
        }
    }

    /// Returns `true` when the frame starts at `FrameBound::UnboundedPreceding`.
    pub fn start_is_unbounded(&self) -> bool {
        match self {
            Self::Rows(FrameBound::UnboundedPreceding, _) => true,
            Self::Rows(..) => false,
        }
    }

    /// Returns `true` when the frame ends at `FrameBound::UnboundedFollowing`.
    pub fn end_is_unbounded(&self) -> bool {
        match self {
            Self::Rows(_, FrameBound::UnboundedFollowing) => true,
            Self::Rows(..) => false,
        }
    }
}

/// One endpoint of a window frame, positioned relative to the current row.
///
/// `T` is the type of the distance carried by [`FrameBound::Preceding`] and
/// [`FrameBound::Following`].
#[derive(Clone, Eq, PartialEq, Hash)]
pub enum FrameBound<T> {
    /// No limit towards the preceding side.
    UnboundedPreceding,
    /// The given distance before the current row.
    Preceding(T),
    /// The current row itself.
    CurrentRow,
    /// The given distance after the current row.
    Following(T),
    /// No limit towards the following side.
    UnboundedFollowing,
}

/// Orders bounds by position: a bound that lies earlier compares as `Less`.
///
/// Two `UnboundedPreceding` bounds, or two `UnboundedFollowing` bounds, are deliberately
/// *not* ordered (`None`), even though the derived `PartialEq` reports them as equal.
/// Callers that treat `None` as "not ordered" (e.g. `Frame::is_valid`) therefore reject
/// such pairs. Do not rely on the usual `PartialOrd`/`PartialEq` consistency for them.
impl<T: Ord> PartialOrd for FrameBound<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        use FrameBound::*;
        match (self, other) {
            (UnboundedPreceding, UnboundedPreceding) => None,
            (UnboundedPreceding, _) => Some(Ordering::Less),
            (_, UnboundedPreceding) => Some(Ordering::Greater),

            (UnboundedFollowing, UnboundedFollowing) => None,
            (UnboundedFollowing, _) => Some(Ordering::Greater),
            (_, UnboundedFollowing) => Some(Ordering::Less),

            (CurrentRow, CurrentRow) => Some(Ordering::Equal),

            // it's ok to think preceding(0) < current row here
            (Preceding(_), CurrentRow) => Some(Ordering::Less),
            (CurrentRow, Preceding(_)) => Some(Ordering::Greater),

            // it's ok to think following(0) > current row here
            (Following(_), CurrentRow) => Some(Ordering::Greater),
            (CurrentRow, Following(_)) => Some(Ordering::Less),

            // A larger preceding distance lies earlier, hence the swapped operands.
            (Preceding(n1), Preceding(n2)) => n2.partial_cmp(n1),
            (Following(n1), Following(n2)) => n1.partial_cmp(n2),
            (Preceding(_), Following(_)) => Some(Ordering::Less),
            (Following(_), Preceding(_)) => Some(Ordering::Greater),
        }
    }
}

impl FrameBound<usize> {
    /// Converts the bound into a signed offset relative to the current row.
    ///
    /// Returns `None` for the unbounded variants, `Some(0)` for `CurrentRow`, a
    /// non-positive offset for `Preceding(n)` and a non-negative one for `Following(n)`.
    ///
    /// Distances larger than `isize::MAX` saturate at `isize::MAX` (negated for
    /// `Preceding`). A plain `as isize` cast would wrap, flipping the sign of the offset
    /// or overflowing on negation.
    pub fn to_offset(&self) -> Option<isize> {
        // The value is capped before the cast, so the cast itself can never truncate.
        let clamp = |n: usize| n.min(isize::MAX as usize) as isize;
        match self {
            FrameBound::UnboundedPreceding | FrameBound::UnboundedFollowing => None,
            FrameBound::CurrentRow => Some(0),
            FrameBound::Preceding(n) => Some(-clamp(*n)),
            FrameBound::Following(n) => Some(clamp(*n)),
        }
    }
}

#[derive(Clone)]
Expand Down
6 changes: 4 additions & 2 deletions src/expr/src/function/window/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

use parse_display::{Display, FromStr};

use crate::function::aggregate::AggKind;

/// Kind of window functions.
#[derive(Debug, Display, FromStr, Copy, Clone, PartialEq, Eq, Hash)]
#[display(style = "snake_case")]
Expand All @@ -26,6 +28,6 @@ pub enum WindowFuncKind {
// NthValue,

// Aggregate functions that are used with `OVER`.
// #[display("{0}")]
// Aggregate(AggKind),
#[display("{0}")]
Aggregate(AggKind),
}
64 changes: 48 additions & 16 deletions src/stream/src/executor/over_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::marker::PhantomData;

use futures::StreamExt;
use futures_async_stream::{for_await, try_stream};
use itertools::Itertools;
use risingwave_common::array::column::Column;
use risingwave_common::array::stream_record::Record;
use risingwave_common::array::{Op, StreamChunk};
Expand Down Expand Up @@ -221,7 +220,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
return Ok(());
}

let mut partition = Partition::new(&this.calls);
let mut partition = Partition::new(&this.calls)?;

// Recover states from state table.
let table_iter = this
Expand All @@ -246,7 +245,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
.input_pk_indices
.iter()
.map(|idx| this.col_mapping.upstream_to_state_table(*idx).unwrap())
.collect_vec(),
.collect::<Vec<_>>(),
),
&vec![OrderType::ascending(); this.input_pk_indices.len()],
)?
Expand All @@ -265,7 +264,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
.val_indices()
.iter()
.map(|idx| this.col_mapping.upstream_to_state_table(*idx).unwrap())
.collect_vec(),
.collect::<Vec<_>>(),
)
.into_owned_row()
.into_inner()
Expand All @@ -279,9 +278,9 @@ impl<S: StateStore> OverWindowExecutor<S> {

// Ignore ready windows (all ready windows were outputted before).
while partition.is_ready() {
partition.states.iter_mut().for_each(|state| {
state.slide();
});
for state in &mut partition.states {
state.output()?;
}
}

cache.put(encoded_partition_key.clone(), partition);
Expand Down Expand Up @@ -363,7 +362,9 @@ impl<S: StateStore> OverWindowExecutor<S> {
let (ret_values, evict_hints): (Vec<_>, Vec<_>) = partition
.states
.iter_mut()
.map(|state| state.slide())
.map(|state| state.output())
.try_collect::<Vec<_>>()?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of returning an error early here when .map(|o| (o.return_value, o.evict_hint)) has no side effects?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the batch aggregator may return an error

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, but the collect is unnecessary if we have rust-itertools/itertools#402

We can resolve it later :)

.into_iter()
.map(|o| (o.return_value, o.evict_hint))
.unzip();

Expand Down Expand Up @@ -402,10 +403,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
}
}

let columns: Vec<Column> = builders
.into_iter()
.map(|b| b.finish().into())
.collect_vec();
let columns: Vec<Column> = builders.into_iter().map(|b| b.finish().into()).collect();
let chunk_size = columns[0].len();
Ok(if chunk_size > 0 {
Some(StreamChunk::new(
Expand Down Expand Up @@ -501,8 +499,8 @@ mod tests {
use risingwave_common::test_prelude::StreamChunkTestExt;
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::function::aggregate::AggArgs;
use risingwave_expr::function::window::{Frame, WindowFuncCall, WindowFuncKind};
use risingwave_expr::function::aggregate::{AggArgs, AggKind};
use risingwave_expr::function::window::{Frame, FrameBound, WindowFuncCall, WindowFuncKind};
use risingwave_storage::memory::MemoryStateStore;
use risingwave_storage::StateStore;

Expand Down Expand Up @@ -575,13 +573,13 @@ mod tests {
kind: WindowFuncKind::Lag,
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int32,
frame: Frame::Offset(-1),
frame: Frame::Rows(FrameBound::Preceding(1), FrameBound::CurrentRow),
},
WindowFuncCall {
kind: WindowFuncKind::Lead,
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int32,
frame: Frame::Offset(1),
frame: Frame::Rows(FrameBound::CurrentRow, FrameBound::Following(1)),
},
];

Expand Down Expand Up @@ -655,4 +653,38 @@ mod tests {
over_window.expect_barrier().await;
}
}

// Exercises an aggregate (`sum`) window function call whose frame spans from one
// preceding row to one following row, and checks the emitted chunk.
#[tokio::test]
async fn test_over_window_aggregate() {
let store = MemoryStateStore::new();
// A single `sum` call; the `3` in `AggArgs::Unary` is presumably the input column
// index (the `i` column of the chunk below) — confirm against `AggArgs`.
let calls = vec![WindowFuncCall {
kind: WindowFuncKind::Aggregate(AggKind::Sum),
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int64,
frame: Frame::Rows(FrameBound::Preceding(1), FrameBound::Following(1)),
}];

let (mut tx, mut over_window) = create_executor(calls.clone(), store.clone()).await;

// The first barrier has to pass through before any data is pushed.
tx.push_barrier(1, false);
over_window.expect_barrier().await;

tx.push_chunk(StreamChunk::from_pretty(
" I T I i
+ 1 p1 100 10
+ 1 p1 101 16
+ 4 p1 102 20",
));
// A watermark with value 1 is expected ahead of the output chunk.
assert_eq!(1, over_window.expect_watermark().await.val.into_int64());
let chunk = over_window.expect_chunk().await;
// Debug aid: prints the received chunk.
println!("{}", chunk.to_pretty_string());
// Expected sums: 26 = 10 + 16 for the first row, 46 = 10 + 16 + 20 for the second.
// No output is expected for the third input row here, presumably because its
// following row has not arrived yet — confirm against the executor.
assert_eq!(
chunk,
StreamChunk::from_pretty(
" T I I I
+ p1 1 100 26
+ p1 1 101 46"
)
);
}
}
7 changes: 4 additions & 3 deletions src/stream/src/executor/over_window/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ use itertools::Itertools;
use risingwave_expr::function::window::WindowFuncCall;

use super::state::{create_window_state, StateKey, WindowState};
use crate::executor::StreamExecutorResult;

/// Window states kept for one partition of the over-window executor.
pub(super) struct Partition {
/// One boxed [`WindowState`] per window function call, in the order of the calls
/// passed to `Partition::new` (each built via `create_window_state`).
pub states: Vec<Box<dyn WindowState + Send>>,
}

impl Partition {
pub fn new(calls: &[WindowFuncCall]) -> Self {
let states = calls.iter().map(create_window_state).collect();
Self { states }
pub fn new(calls: &[WindowFuncCall]) -> StreamExecutorResult<Self> {
let states = calls.iter().map(create_window_state).try_collect()?;
Ok(Self { states })
}

pub fn is_aligned(&self) -> bool {
Expand Down