Skip to content

Commit

Permalink
feat(stream,over-window): implement AggregateState to allow aggrega…
Browse files Browse the repository at this point in the history
…te function over window (#9248)

Signed-off-by: Richard Chien <stdrc@outlook.com>
  • Loading branch information
stdrc committed Apr 21, 2023
1 parent 14cf958 commit 64aa26d
Show file tree
Hide file tree
Showing 9 changed files with 789 additions and 42 deletions.
82 changes: 76 additions & 6 deletions src/expr/src/function/window/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,88 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::Ordering;

use risingwave_common::types::DataType;

use super::WindowFuncKind;
use crate::function::aggregate::AggArgs;

#[derive(Clone)]
#[derive(Clone, Eq, PartialEq, Hash)]
pub enum Frame {
/// Frame by row offset, for `lag` and `lead`.
Offset(isize),
// Rows(Bound<usize>, Bound<usize>),
// Groups(Bound<usize>, Bound<usize>),
// Range(Bound<ScalarImpl>, Bound<ScalarImpl>),
Rows(FrameBound<usize>, FrameBound<usize>),
// Groups(FrameBound<usize>, FrameBound<usize>),
// Range(FrameBound<ScalarImpl>, FrameBound<ScalarImpl>),
}

impl Frame {
    /// Reports whether the frame's start bound is ordered at or before its
    /// end bound.
    ///
    /// Two bounds that cannot be ordered against each other (their
    /// `partial_cmp` is `None`) make the frame invalid, because `<=` on a
    /// partial order is `false` for incomparable operands.
    pub fn is_valid(&self) -> bool {
        match self {
            Frame::Rows(lower, upper) => lower <= upper,
        }
    }

    /// Reports whether the frame starts at `FrameBound::UnboundedPreceding`.
    pub fn start_is_unbounded(&self) -> bool {
        match self {
            Frame::Rows(FrameBound::UnboundedPreceding, _) => true,
            Frame::Rows(..) => false,
        }
    }

    /// Reports whether the frame ends at `FrameBound::UnboundedFollowing`.
    pub fn end_is_unbounded(&self) -> bool {
        match self {
            Frame::Rows(_, FrameBound::UnboundedFollowing) => true,
            Frame::Rows(..) => false,
        }
    }
}

/// One endpoint of a window frame, positioned relative to the current row.
///
/// The variants are declared from the earliest position to the latest one.
/// `T` is the distance carried by `Preceding` and `Following`.
#[derive(Clone, Eq, PartialEq, Hash)]
pub enum FrameBound<T> {
    /// No limit on the preceding side.
    UnboundedPreceding,
    /// The given distance before the current row.
    Preceding(T),
    /// The current row itself.
    CurrentRow,
    /// The given distance after the current row.
    Following(T),
    /// No limit on the following side.
    UnboundedFollowing,
}

/// Positional ordering of frame bounds, from earliest to latest.
///
/// NOTE: this ordering is intentionally *not* consistent with the derived
/// `PartialEq` in two places:
///
/// * an unbounded bound compared with itself yields `None` although `==`
///   reports the two values as equal. `Frame::is_valid` maps `None` to
///   "invalid", so a frame whose start and end are the same unbounded bound
///   is rejected;
/// * `Preceding(n)` is always strictly before `CurrentRow` and `Following(n)`
///   strictly after it, even for a distance of zero.
///
/// Do not rely on this impl for sorting or for `==`-equivalent comparisons.
impl<T: Ord> PartialOrd for FrameBound<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        use FrameBound::*;
        match (self, other) {
            // Unbounded preceding sorts before everything else, but is
            // deliberately incomparable with itself.
            (UnboundedPreceding, UnboundedPreceding) => None,
            (UnboundedPreceding, _) => Some(Ordering::Less),
            (_, UnboundedPreceding) => Some(Ordering::Greater),

            // Unbounded following sorts after everything else, but is
            // deliberately incomparable with itself.
            (UnboundedFollowing, UnboundedFollowing) => None,
            (UnboundedFollowing, _) => Some(Ordering::Greater),
            (_, UnboundedFollowing) => Some(Ordering::Less),

            (CurrentRow, CurrentRow) => Some(Ordering::Equal),

            // The distance is ignored here: even `Preceding(0)` is treated as
            // strictly earlier than the current row.
            (Preceding(_), CurrentRow) => Some(Ordering::Less),
            (CurrentRow, Preceding(_)) => Some(Ordering::Greater),

            // Likewise, even `Following(0)` is treated as strictly later than
            // the current row.
            (Following(_), CurrentRow) => Some(Ordering::Greater),
            (CurrentRow, Following(_)) => Some(Ordering::Less),

            // A larger preceding distance is an earlier position, hence the
            // swapped operands.
            (Preceding(n1), Preceding(n2)) => n2.partial_cmp(n1),
            (Following(n1), Following(n2)) => n1.partial_cmp(n2),
            (Preceding(_), Following(_)) => Some(Ordering::Less),
            (Following(_), Preceding(_)) => Some(Ordering::Greater),
        }
    }
}

impl FrameBound<usize> {
    /// Converts the bound into a signed row offset relative to the current
    /// row: negative for `Preceding`, `0` for `CurrentRow`, positive for
    /// `Following`.
    ///
    /// Returns `None` for the two unbounded variants, which have no finite
    /// offset.
    ///
    /// Distances greater than `isize::MAX` are clamped to `isize::MAX` so the
    /// result keeps its direction. A plain `as isize` cast would wrap such a
    /// distance to a negative number (turning a following bound into a
    /// preceding one) and could overflow on negation.
    pub fn to_offset(&self) -> Option<isize> {
        // Lossless after clamping: the value is at most `isize::MAX`.
        let clamp = |distance: usize| distance.min(isize::MAX as usize) as isize;
        match self {
            FrameBound::UnboundedPreceding | FrameBound::UnboundedFollowing => None,
            FrameBound::CurrentRow => Some(0),
            // `clamp` never returns `isize::MIN`, so negation cannot overflow.
            FrameBound::Preceding(n) => Some(-clamp(*n)),
            FrameBound::Following(n) => Some(clamp(*n)),
        }
    }
}

#[derive(Clone)]
Expand Down
6 changes: 4 additions & 2 deletions src/expr/src/function/window/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

use parse_display::{Display, FromStr};

use crate::function::aggregate::AggKind;

/// Kind of window functions.
#[derive(Debug, Display, FromStr, Copy, Clone, PartialEq, Eq, Hash)]
#[display(style = "snake_case")]
Expand All @@ -26,6 +28,6 @@ pub enum WindowFuncKind {
// NthValue,

// Aggregate functions that are used with `OVER`.
// #[display("{0}")]
// Aggregate(AggKind),
#[display("{0}")]
Aggregate(AggKind),
}
64 changes: 48 additions & 16 deletions src/stream/src/executor/over_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::marker::PhantomData;

use futures::StreamExt;
use futures_async_stream::{for_await, try_stream};
use itertools::Itertools;
use risingwave_common::array::column::Column;
use risingwave_common::array::stream_record::Record;
use risingwave_common::array::{Op, StreamChunk};
Expand Down Expand Up @@ -221,7 +220,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
return Ok(());
}

let mut partition = Partition::new(&this.calls);
let mut partition = Partition::new(&this.calls)?;

// Recover states from state table.
let table_iter = this
Expand All @@ -246,7 +245,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
.input_pk_indices
.iter()
.map(|idx| this.col_mapping.upstream_to_state_table(*idx).unwrap())
.collect_vec(),
.collect::<Vec<_>>(),
),
&vec![OrderType::ascending(); this.input_pk_indices.len()],
)?
Expand All @@ -265,7 +264,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
.val_indices()
.iter()
.map(|idx| this.col_mapping.upstream_to_state_table(*idx).unwrap())
.collect_vec(),
.collect::<Vec<_>>(),
)
.into_owned_row()
.into_inner()
Expand All @@ -279,9 +278,9 @@ impl<S: StateStore> OverWindowExecutor<S> {

// Ignore ready windows (all ready windows were outputted before).
while partition.is_ready() {
partition.states.iter_mut().for_each(|state| {
state.slide();
});
for state in &mut partition.states {
state.output()?;
}
}

cache.put(encoded_partition_key.clone(), partition);
Expand Down Expand Up @@ -363,7 +362,9 @@ impl<S: StateStore> OverWindowExecutor<S> {
let (ret_values, evict_hints): (Vec<_>, Vec<_>) = partition
.states
.iter_mut()
.map(|state| state.slide())
.map(|state| state.output())
.try_collect::<Vec<_>>()?
.into_iter()
.map(|o| (o.return_value, o.evict_hint))
.unzip();

Expand Down Expand Up @@ -402,10 +403,7 @@ impl<S: StateStore> OverWindowExecutor<S> {
}
}

let columns: Vec<Column> = builders
.into_iter()
.map(|b| b.finish().into())
.collect_vec();
let columns: Vec<Column> = builders.into_iter().map(|b| b.finish().into()).collect();
let chunk_size = columns[0].len();
Ok(if chunk_size > 0 {
Some(StreamChunk::new(
Expand Down Expand Up @@ -501,8 +499,8 @@ mod tests {
use risingwave_common::test_prelude::StreamChunkTestExt;
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::function::aggregate::AggArgs;
use risingwave_expr::function::window::{Frame, WindowFuncCall, WindowFuncKind};
use risingwave_expr::function::aggregate::{AggArgs, AggKind};
use risingwave_expr::function::window::{Frame, FrameBound, WindowFuncCall, WindowFuncKind};
use risingwave_storage::memory::MemoryStateStore;
use risingwave_storage::StateStore;

Expand Down Expand Up @@ -575,13 +573,13 @@ mod tests {
kind: WindowFuncKind::Lag,
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int32,
frame: Frame::Offset(-1),
frame: Frame::Rows(FrameBound::Preceding(1), FrameBound::CurrentRow),
},
WindowFuncCall {
kind: WindowFuncKind::Lead,
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int32,
frame: Frame::Offset(1),
frame: Frame::Rows(FrameBound::CurrentRow, FrameBound::Following(1)),
},
];

Expand Down Expand Up @@ -655,4 +653,38 @@ mod tests {
over_window.expect_barrier().await;
}
}

// Checks an aggregate window function (`sum`) evaluated over a
// `ROWS` frame spanning one row before to one row after the current row.
#[tokio::test]
async fn test_over_window_aggregate() {
let store = MemoryStateStore::new();
// A single call: sum over `Rows(Preceding(1), Following(1))`.
// NOTE(review): the `3` in `AggArgs::Unary` is presumably the input column
// index of the summed value (the `i` column below) — confirm against `AggArgs`.
let calls = vec![WindowFuncCall {
kind: WindowFuncKind::Aggregate(AggKind::Sum),
args: AggArgs::Unary(DataType::Int32, 3),
return_type: DataType::Int64,
frame: Frame::Rows(FrameBound::Preceding(1), FrameBound::Following(1)),
}];

let (mut tx, mut over_window) = create_executor(calls.clone(), store.clone()).await;

tx.push_barrier(1, false);
over_window.expect_barrier().await;

// Three inserts, all with `p1` in the second column; the last column holds
// the values 10, 16 and 20.
tx.push_chunk(StreamChunk::from_pretty(
" I T I i
+ 1 p1 100 10
+ 1 p1 101 16
+ 4 p1 102 20",
));
// A watermark with value 1 is expected before the output chunk.
assert_eq!(1, over_window.expect_watermark().await.val.into_int64());
let chunk = over_window.expect_chunk().await;
// Prints the produced chunk to stdout to ease debugging of a failure.
println!("{}", chunk.to_pretty_string());
// Only the first two rows are expected in the output:
//   row 100 -> 10 + 16      = 26 (there is no preceding row)
//   row 101 -> 10 + 16 + 20 = 46
// Row 102 is absent — presumably because the row following it, which its
// frame needs, has not arrived yet.
assert_eq!(
chunk,
StreamChunk::from_pretty(
" T I I I
+ p1 1 100 26
+ p1 1 101 46"
)
);
}
}
7 changes: 4 additions & 3 deletions src/stream/src/executor/over_window/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ use risingwave_common::estimate_size::EstimateSize;
use risingwave_expr::function::window::WindowFuncCall;

use super::state::{create_window_state, StateKey, WindowState};
use crate::executor::StreamExecutorResult;

pub(super) struct Partition {
pub states: Vec<Box<dyn WindowState + Send>>,
}

impl Partition {
pub fn new(calls: &[WindowFuncCall]) -> Self {
let states = calls.iter().map(create_window_state).collect();
Self { states }
pub fn new(calls: &[WindowFuncCall]) -> StreamExecutorResult<Self> {
let states = calls.iter().map(create_window_state).try_collect()?;
Ok(Self { states })
}

pub fn is_aligned(&self) -> bool {
Expand Down

0 comments on commit 64aa26d

Please sign in to comment.