diff --git a/springql-core/src/pipeline/pump_model/window_operation_parameter.rs b/springql-core/src/pipeline/pump_model/window_operation_parameter.rs index ae93da38..1369e482 100644 --- a/springql-core/src/pipeline/pump_model/window_operation_parameter.rs +++ b/springql-core/src/pipeline/pump_model/window_operation_parameter.rs @@ -3,11 +3,11 @@ pub(crate) mod aggregate; pub(crate) mod join_parameter; -use self::{aggregate::GroupAggregateParameter, join_parameter::JoinParameter}; +use self::{aggregate::AggregateParameter, join_parameter::JoinParameter}; /// Window operation parameters #[derive(Clone, PartialEq, Debug)] pub(crate) enum WindowOperationParameter { - GroupAggregation(GroupAggregateParameter), + Aggregate(AggregateParameter), Join(JoinParameter), } diff --git a/springql-core/src/pipeline/pump_model/window_operation_parameter/aggregate.rs b/springql-core/src/pipeline/pump_model/window_operation_parameter/aggregate.rs index a7796bc5..5c67519e 100644 --- a/springql-core/src/pipeline/pump_model/window_operation_parameter/aggregate.rs +++ b/springql-core/src/pipeline/pump_model/window_operation_parameter/aggregate.rs @@ -2,21 +2,32 @@ use crate::expr_resolver::expr_label::{AggrExprLabel, ValueExprLabel}; +/// [GROUP BY c1, c2, c3...] +#[derive(Clone, PartialEq, Debug, Default, new)] +pub(crate) struct GroupByLabels( + /// Empty when GROUP BY clause is not supplied. + Vec, +); +impl GroupByLabels { + pub(crate) fn as_labels(&self) -> &[ValueExprLabel] { + &self.0 + } +} + /// TODO [support complex expression with aggregations](https://gh01.base.toyota-tokyo.tech/SpringQL-internal/SpringQL/issues/152) /// /// ```sql /// SELECT group_by, aggr_expr.func(aggr_expr.aggregated) /// FROM s -/// GROUP BY group_by +/// [GROUP BY group_by] /// SLIDING WINDOW ...; /// ``` -#[derive(Copy, Clone, PartialEq, Debug, new)] -pub(crate) struct GroupAggregateParameter { +#[derive(Clone, PartialEq, Debug, new)] +pub(crate) struct AggregateParameter { // TODO multiple aggr_expr pub(crate) aggr_func: AggregateFunctionParameter, pub(crate) aggr_expr: AggrExprLabel, - - pub(crate) group_by: ValueExprLabel, // TODO multiple group by expression + pub(crate) group_by: GroupByLabels, } #[derive(Copy, Clone, Eq, PartialEq, Debug)] diff --git a/springql-core/src/sql_processor/query_planner.rs b/springql-core/src/sql_processor/query_planner.rs index b7712a16..26d07090 100644 --- a/springql-core/src/sql_processor/query_planner.rs +++ b/springql-core/src/sql_processor/query_planner.rs @@ -63,7 +63,8 @@ use crate::{ pipeline::{ pump_model::{ window_operation_parameter::{ - aggregate::GroupAggregateParameter, WindowOperationParameter, + aggregate::{AggregateParameter, GroupByLabels}, + WindowOperationParameter, }, window_parameter::WindowParameter, }, @@ -123,7 +124,7 @@ impl QueryPlanner { match (window_param, group_aggr_param) { (Some(window_param), Some(group_aggr_param)) => Ok(Some(GroupAggregateWindowOp { window_param, - op_param: WindowOperationParameter::GroupAggregation(group_aggr_param), + op_param: WindowOperationParameter::Aggregate(group_aggr_param), })), _ => Ok(None), } @@ -137,8 +138,8 @@ impl QueryPlanner { &self, expr_resolver: &mut ExprResolver, projection_op: &ProjectionOp, - ) -> Result> { - let opt_grouping_elem = self.analyzer.grouping_element(); + ) -> Result> { + let grouping_elements = self.analyzer.grouping_elements(); let aggr_labels = projection_op .expr_labels .iter() @@ -151,28 +152,31 @@ impl QueryPlanner { }) .collect::>(); - match (opt_grouping_elem, aggr_labels.len()) { - (Some(grouping_elem), 1) => { + match aggr_labels.len() { + 1 => { let aggr_label = aggr_labels.get(0).expect("len checked"); let aggr_func = expr_resolver.resolve_aggr_expr(*aggr_label).func; - let group_by_label = match grouping_elem { - GroupingElementSyntax::ValueExpr(expr) => { - expr_resolver.register_value_expr(expr) - } - GroupingElementSyntax::ValueAlias(alias) => { - expr_resolver.resolve_value_alias(alias)? - } - }; - - Ok(Some(GroupAggregateParameter::new( + let group_by_labels = grouping_elements + .iter() + .map(|grouping_elem| match grouping_elem { + GroupingElementSyntax::ValueExpr(expr) => { + Ok(expr_resolver.register_value_expr(expr.clone())) + } + GroupingElementSyntax::ValueAlias(alias) => { + expr_resolver.resolve_value_alias(alias.clone()) + } + }) + .collect::>>()?; + + Ok(Some(AggregateParameter::new( aggr_func, *aggr_label, - group_by_label, + GroupByLabels::new(group_by_labels), ))) } - (None, 0) => Ok(None), - _ => unimplemented!(), + 0 => Ok(None), + _ => unimplemented!("2 or more aggregate expressions"), } } diff --git a/springql-core/src/sql_processor/query_planner/select_syntax_analyzer/group_aggregate.rs b/springql-core/src/sql_processor/query_planner/select_syntax_analyzer/group_aggregate.rs index 3ddafb55..89e6d28e 100644 --- a/springql-core/src/sql_processor/query_planner/select_syntax_analyzer/group_aggregate.rs +++ b/springql-core/src/sql_processor/query_planner/select_syntax_analyzer/group_aggregate.rs @@ -4,8 +4,7 @@ use super::SelectSyntaxAnalyzer; use crate::sql_processor::sql_parser::syntax::GroupingElementSyntax; impl SelectSyntaxAnalyzer { - /// TODO multiple GROUP BY - pub(in super::super) fn grouping_element(&self) -> Option { - self.select_syntax.grouping_element.clone() + pub(in super::super) fn grouping_elements(&self) -> Vec { + self.select_syntax.grouping_elements.clone() } } diff --git a/springql-core/src/sql_processor/sql_parser/pest_grammar/springql.pest b/springql-core/src/sql_processor/sql_parser/pest_grammar/springql.pest index 5832ab67..8c5c7f46 100644 --- a/springql-core/src/sql_processor/sql_parser/pest_grammar/springql.pest +++ b/springql-core/src/sql_processor/sql_parser/pest_grammar/springql.pest @@ -479,7 +479,7 @@ select_stream_command = { ^"SELECT" ~ "STREAM" ~ select_field ~ ("," ~ select_field)* ~ (^"FROM" ~ from_item) - ~ (^"GROUP" ~ "BY" ~ grouping_element)? // TODO multiple grouping elements + ~ group_by_clause? ~ window_clause? } @@ -500,6 +500,10 @@ join_type = { ^"LEFT" ~ ^"OUTER" ~ ^"JOIN" } +group_by_clause = { + ^"GROUP" ~ "BY" ~ grouping_element ~ ("," ~ grouping_element)* +} + grouping_element = { value_expr | value_alias diff --git a/springql-core/src/sql_processor/sql_parser/pest_parser_impl.rs b/springql-core/src/sql_processor/sql_parser/pest_parser_impl.rs index 93a4861a..448d3b25 100644 --- a/springql-core/src/sql_processor/sql_parser/pest_parser_impl.rs +++ b/springql-core/src/sql_processor/sql_parser/pest_parser_impl.rs @@ -210,12 +210,12 @@ impl PestParserImpl { )?; let event_duration = match duration_function { - DurationFunction::Millis => { - Ok(SpringEventDuration::from_millis(integer_constant.to_i64()? as u64)) - } - DurationFunction::Secs => { - Ok(SpringEventDuration::from_secs(integer_constant.to_i64()? as u64)) - } + DurationFunction::Millis => Ok(SpringEventDuration::from_millis( + integer_constant.to_i64()? as u64, + )), + DurationFunction::Secs => Ok(SpringEventDuration::from_secs( + integer_constant.to_i64()? as u64, + )), }?; Ok(SqlValue::NotNull(NnSqlValue::Duration(event_duration))) @@ -549,10 +549,10 @@ impl PestParserImpl { Self::parse_from_item, identity, )?; - let grouping_element = try_parse_child( + let grouping_elements = try_parse_child( &mut params, - Rule::grouping_element, - Self::parse_grouping_element, + Rule::group_by_clause, + Self::parse_group_by_clause, identity, )?; let window_clause = try_parse_child( @@ -565,7 +565,7 @@ impl PestParserImpl { Ok(SelectStreamSyntax { fields, from_item, - grouping_element, + grouping_elements: grouping_elements.unwrap_or_default(), window_clause, }) } @@ -680,6 +680,15 @@ impl PestParserImpl { } } + fn parse_group_by_clause(mut params: FnParseParams) -> Result> { + parse_child_seq( + &mut params, + Rule::grouping_element, + &Self::parse_grouping_element, + &identity, + ) + } + fn parse_grouping_element(mut params: FnParseParams) -> Result { try_parse_child( &mut params, diff --git a/springql-core/src/sql_processor/sql_parser/syntax.rs b/springql-core/src/sql_processor/sql_parser/syntax.rs index b8fa7469..fdb7c558 100644 --- a/springql-core/src/sql_processor/sql_parser/syntax.rs +++ b/springql-core/src/sql_processor/sql_parser/syntax.rs @@ -26,7 +26,10 @@ pub(in crate::sql_processor) struct OptionSyntax { pub(in crate::sql_processor) struct SelectStreamSyntax { pub(in crate::sql_processor) fields: Vec, pub(in crate::sql_processor) from_item: FromItemSyntax, - pub(in crate::sql_processor) grouping_element: Option, + + /// Empty when no GROUP BY clause is supplied. + pub(in crate::sql_processor) grouping_elements: Vec, + pub(in crate::sql_processor) window_clause: Option, } diff --git a/springql-core/src/stream_engine/autonomous_executor/task/window/aggregate.rs b/springql-core/src/stream_engine/autonomous_executor/task/window/aggregate.rs index 84f3cbc5..d9b49c92 100644 --- a/springql-core/src/stream_engine/autonomous_executor/task/window/aggregate.rs +++ b/springql-core/src/stream_engine/autonomous_executor/task/window/aggregate.rs @@ -123,7 +123,7 @@ mod tests { pipeline::{ name::{AggrAlias, ColumnName, StreamName}, pump_model::window_operation_parameter::aggregate::{ - AggregateFunctionParameter, GroupAggregateParameter, + AggregateFunctionParameter, AggregateParameter, GroupByLabels, }, }, sql_processor::sql_parser::syntax::SelectFieldSyntax, @@ -218,10 +218,10 @@ mod tests { period: SpringEventDuration::from_secs(5), allowed_delay: SpringEventDuration::from_secs(1), }, - WindowOperationParameter::GroupAggregation(GroupAggregateParameter { + WindowOperationParameter::Aggregate(AggregateParameter { aggr_func: AggregateFunctionParameter::Avg, aggr_expr: aggr_label, - group_by: group_by_label, + group_by: GroupByLabels::new(vec![group_by_label]), }), ); @@ -449,10 +449,10 @@ mod tests { length: SpringEventDuration::from_secs(10), allowed_delay: SpringEventDuration::from_secs(1), }, - WindowOperationParameter::GroupAggregation(GroupAggregateParameter { + WindowOperationParameter::Aggregate(AggregateParameter { aggr_func: AggregateFunctionParameter::Avg, aggr_expr: aggr_label, - group_by: group_by_label, + group_by: GroupByLabels::new(vec![group_by_label]), }), ); diff --git a/springql-core/src/stream_engine/autonomous_executor/task/window/panes.rs b/springql-core/src/stream_engine/autonomous_executor/task/window/panes.rs index c302aaee..daf9f0fa 100644 --- a/springql-core/src/stream_engine/autonomous_executor/task/window/panes.rs +++ b/springql-core/src/stream_engine/autonomous_executor/task/window/panes.rs @@ -143,7 +143,7 @@ mod tests { expr_resolver::{expr_label::ExprLabel, ExprResolver}, expression::{AggrExpr, ValueExpr}, pipeline::pump_model::window_operation_parameter::aggregate::{ - AggregateFunctionParameter, GroupAggregateParameter, + AggregateFunctionParameter, AggregateParameter, GroupByLabels, }, sql_processor::sql_parser::syntax::SelectFieldSyntax, stream_engine::{ @@ -168,16 +168,17 @@ mod tests { }]; let (mut expr_resolver, labels) = ExprResolver::new(select_list); - let group_by_label = expr_resolver.register_value_expr(group_by_expr); + let group_by_labels = + GroupByLabels::new(vec![expr_resolver.register_value_expr(group_by_expr)]); - WindowOperationParameter::GroupAggregation(GroupAggregateParameter { + WindowOperationParameter::Aggregate(AggregateParameter { aggr_func: AggregateFunctionParameter::Avg, aggr_expr: if let ExprLabel::Aggr(l) = labels[0] { l } else { unreachable!() }, - group_by: group_by_label, + group_by: group_by_labels, }) } diff --git a/springql-core/src/stream_engine/autonomous_executor/task/window/panes/pane/aggregate_pane.rs b/springql-core/src/stream_engine/autonomous_executor/task/window/panes/pane/aggregate_pane.rs index df96cffe..1ce7f913 100644 --- a/springql-core/src/stream_engine/autonomous_executor/task/window/panes/pane/aggregate_pane.rs +++ b/springql-core/src/stream_engine/autonomous_executor/task/window/panes/pane/aggregate_pane.rs @@ -7,9 +7,10 @@ use std::collections::HashMap; use ordered_float::OrderedFloat; use crate::{ + error::Result, expr_resolver::ExprResolver, pipeline::pump_model::window_operation_parameter::{ - aggregate::{AggregateFunctionParameter, GroupAggregateParameter}, + aggregate::{AggregateFunctionParameter, AggregateParameter, GroupByLabels}, WindowOperationParameter, }, stream_engine::{ @@ -31,7 +32,7 @@ pub(in crate::stream_engine::autonomous_executor) struct AggrPane { open_at: SpringTimestamp, close_at: SpringTimestamp, - group_aggregation_parameter: GroupAggregateParameter, + aggregate_parameter: AggregateParameter, inner: AggrPaneInner, } @@ -48,8 +49,8 @@ impl Pane for AggrPane { close_at: SpringTimestamp, op_param: WindowOperationParameter, ) -> Self { - if let WindowOperationParameter::GroupAggregation(group_aggregation_parameter) = op_param { - let inner = match group_aggregation_parameter.aggr_func { + if let WindowOperationParameter::Aggregate(aggregate_parameter) = op_param { + let inner = match aggregate_parameter.aggr_func { AggregateFunctionParameter::Avg => AggrPaneInner::Avg { states: HashMap::new(), }, @@ -58,7 +59,7 @@ impl Pane for AggrPane { Self { open_at, close_at, - group_aggregation_parameter, + aggregate_parameter, inner, } } else { @@ -80,17 +81,15 @@ impl Pane for AggrPane { tuple: &Tuple, _arg: (), ) -> WindowInFlowByWindowTask { - let group_by_value = expr_resolver - .eval_value_expr(self.group_aggregation_parameter.group_by, tuple) - .expect("TODO Result"); - let group_by_value = if let SqlValue::NotNull(v) = group_by_value { - v - } else { - unimplemented!("group by NULL is not supported ") - }; + let group_by_values = GroupByValues::from_group_by_labels( + self.aggregate_parameter.group_by.clone(), + expr_resolver, + tuple, + ) + .expect("TODO handle Result"); let aggregated_value = expr_resolver - .eval_aggr_expr_inner(self.group_aggregation_parameter.aggr_expr, tuple) + .eval_aggr_expr_inner(self.aggregate_parameter.aggr_expr, tuple) .expect("TODO Result"); let aggregated_value = if let SqlValue::NotNull(v) = aggregated_value { v @@ -101,7 +100,7 @@ impl Pane for AggrPane { match &mut self.inner { AggrPaneInner::Avg { states } => { let state = states - .entry(group_by_value) + .entry(group_by_values) .or_insert_with(AvgState::default); state.next( @@ -119,20 +118,25 @@ impl Pane for AggrPane { self, _expr_resolver: &ExprResolver, ) -> (Vec, WindowInFlowByWindowTask) { - let aggr_label = self.group_aggregation_parameter.aggr_expr; - let group_by_label = self.group_aggregation_parameter.group_by; + let aggr_label = self.aggregate_parameter.aggr_expr; + let group_by_labels = self.aggregate_parameter.group_by; match self.inner { AggrPaneInner::Avg { states } => { let aggregated_and_grouping_values_seq = states .into_iter() - .map(|(group_by, state)| { + .map(|(group_by_values, state)| { let aggr_value = SqlValue::NotNull(NnSqlValue::Float(OrderedFloat(state.finalize()))); - AggregatedAndGroupingValues::new( - vec![(aggr_label, aggr_value)], - vec![(group_by_label, SqlValue::NotNull(group_by))], - ) + + let group_bys = group_by_labels + .as_labels() + .iter() + .cloned() + .zip(group_by_values.into_sql_values()) + .collect(); + + AggregatedAndGroupingValues::new(vec![(aggr_label, aggr_value)], group_bys) }) .collect(); @@ -148,6 +152,41 @@ impl Pane for AggrPane { #[derive(Debug)] pub(in crate::stream_engine::autonomous_executor) enum AggrPaneInner { Avg { - states: HashMap, + states: HashMap, }, } + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub(in crate::stream_engine::autonomous_executor) struct GroupByValues( + /// TODO support NULL in GROUP BY elements + Vec, +); + +impl GroupByValues { + /// Order of elements in GROUP BY clause is preserved. + fn from_group_by_labels( + group_by_labels: GroupByLabels, + expr_resolver: &ExprResolver, + tuple: &Tuple, + ) -> Result { + let values = group_by_labels + .as_labels() + .iter() + .map(|group_by_label| { + let group_by_value = expr_resolver.eval_value_expr(*group_by_label, tuple)?; + + if let SqlValue::NotNull(v) = group_by_value { + Ok(v) + } else { + unimplemented!("group by NULL is not supported ") + } + }) + .collect::>>()?; + + Ok(Self(values)) + } + + pub(in crate::stream_engine::autonomous_executor) fn into_sql_values(self) -> Vec { + self.0.into_iter().map(SqlValue::NotNull).collect() + } +} diff --git a/springql-core/tests/feat_aggregation.rs b/springql-core/tests/feat_aggregation.rs new file mode 100644 index 00000000..1a80d94b --- /dev/null +++ b/springql-core/tests/feat_aggregation.rs @@ -0,0 +1,132 @@ +// This file is part of https://github.com/SpringQL/SpringQL which is licensed under MIT OR Apache-2.0. See file LICENSE-MIT or LICENSE-APACHE for full license details. + +mod test_support; + +use pretty_assertions::assert_eq; +use serde_json::json; +use springql_core::error::Result; +use springql_core::low_level_rs::*; +use springql_foreign_service::sink::ForeignSink; +use springql_foreign_service::source::source_input::ForeignSourceInput; +use springql_foreign_service::source::ForeignSource; +use springql_test_logger::setup_test_logger; + +use crate::test_support::*; + +fn gen_source_input() -> Vec { + let json_00_1 = json!({ + "ts": "2020-01-01 00:00:00.000000000", + "ticker": "ORCL", + "amount": 10, + }); + let json_00_2 = json!({ + "ts": "2020-01-01 00:00:09.9999999999", + "ticker": "GOOGL", + "amount": 30, + }); + let json_10_1 = json!({ + "ts": "2020-01-01 00:00:10.0000000000", + "ticker": "IBM", + "amount": 50, + }); + let json_20_1 = json!({ + "ts": "2020-01-01 00:00:20.0000000000", + "ticker": "IBM", + "amount": 70, + }); + + vec![json_00_1, json_00_2, json_10_1, json_20_1] +} + +fn run_and_drain( + ddls: &[String], + source_input: ForeignSourceInput, + test_source: ForeignSource, + test_sink: &ForeignSink, +) -> Vec { + let _pipeline = apply_ddls(ddls, spring_config_default()); + test_source.start(source_input); + drain_from_sink(test_sink) +} + +/// See: +#[test] +fn test_feat_aggregation_without_group_by() -> Result<()> { + setup_test_logger(); + + let source_input = gen_source_input(); + + let test_source = ForeignSource::new().unwrap(); + let test_sink = ForeignSink::start().unwrap(); + + let ddls = vec![ + " + CREATE SOURCE STREAM source_trade ( + ts TIMESTAMP NOT NULL ROWTIME, + ticker TEXT NOT NULL, + amount INTEGER NOT NULL + ); + " + .to_string(), + " + CREATE SINK STREAM sink_avg_all ( + avg_amount FLOAT NOT NULL + ); + " + .to_string(), + " + CREATE PUMP avg_all AS + INSERT INTO sink_avg_all (avg_amount) + SELECT STREAM + AVG(source_trade.amount) AS avg_amount + FROM source_trade + FIXED WINDOW DURATION_SECS(10), DURATION_SECS(0); + " + .to_string(), + format!( + " + CREATE SINK WRITER tcp_sink_trade FOR sink_avg_all + TYPE NET_CLIENT OPTIONS ( + PROTOCOL 'TCP', + REMOTE_HOST '{remote_host}', + REMOTE_PORT '{remote_port}' + ); + ", + remote_host = test_sink.host_ip(), + remote_port = test_sink.port() + ), + format!( + " + CREATE SOURCE READER tcp_trade FOR source_trade + TYPE NET_CLIENT OPTIONS ( + PROTOCOL 'TCP', + REMOTE_HOST '{remote_host}', + REMOTE_PORT '{remote_port}' + ); + ", + remote_host = test_source.host_ip(), + remote_port = test_source.port() + ), + ]; + + let sink_received = run_and_drain( + &ddls, + ForeignSourceInput::new_fifo_batch(source_input), + test_source, + &test_sink, + ); + + assert_eq!(sink_received.len(), 2); + + assert_eq!( + sink_received[0]["avg_amount"].as_f64().unwrap().round() as i32, + 20, + ); + + assert_eq!( + sink_received[1]["avg_amount"].as_f64().unwrap().round() as i32, + 50, + ); + + Ok(()) +}